xiaoanyu123 commited on
Commit
e2bcd96
·
verified ·
1 Parent(s): 5e7c231

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full change set.
Files changed (50) hide show
  1. pythonProject/diffusers-main/build/lib/diffusers/pipelines/animatediff/__init__.py +57 -0
  2. pythonProject/diffusers-main/build/lib/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +1118 -0
  3. pythonProject/diffusers-main/build/lib/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1309 -0
  4. pythonProject/diffusers-main/build/lib/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1023 -0
  5. pythonProject/diffusers-main/build/lib/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +1065 -0
  6. pythonProject/diffusers-main/build/lib/diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1353 -0
  7. pythonProject/diffusers-main/build/lib/diffusers/pipelines/animatediff/pipeline_output.py +24 -0
  8. pythonProject/diffusers-main/build/lib/diffusers/pipelines/audioldm/__init__.py +51 -0
  9. pythonProject/diffusers-main/build/lib/diffusers/pipelines/audioldm/pipeline_audioldm.py +558 -0
  10. pythonProject/diffusers-main/build/lib/diffusers/pipelines/audioldm2/__init__.py +50 -0
  11. pythonProject/diffusers-main/build/lib/diffusers/pipelines/audioldm2/modeling_audioldm2.py +1475 -0
  12. pythonProject/diffusers-main/build/lib/diffusers/pipelines/audioldm2/pipeline_audioldm2.py +1104 -0
  13. pythonProject/diffusers-main/build/lib/diffusers/pipelines/aura_flow/__init__.py +48 -0
  14. pythonProject/diffusers-main/build/lib/diffusers/pipelines/aura_flow/pipeline_aura_flow.py +677 -0
  15. pythonProject/diffusers-main/build/lib/diffusers/pipelines/blip_diffusion/__init__.py +20 -0
  16. pythonProject/diffusers-main/build/lib/diffusers/pipelines/blip_diffusion/blip_image_processing.py +318 -0
  17. pythonProject/diffusers-main/build/lib/diffusers/pipelines/blip_diffusion/modeling_blip2.py +639 -0
  18. pythonProject/diffusers-main/build/lib/diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +223 -0
  19. pythonProject/diffusers-main/build/lib/diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +361 -0
  20. pythonProject/diffusers-main/build/lib/diffusers/pipelines/bria/__init__.py +48 -0
  21. pythonProject/diffusers-main/build/lib/diffusers/pipelines/bria/pipeline_bria.py +729 -0
  22. pythonProject/diffusers-main/build/lib/diffusers/pipelines/bria/pipeline_output.py +21 -0
  23. pythonProject/diffusers-main/build/lib/diffusers/pipelines/chroma/__init__.py +49 -0
  24. pythonProject/diffusers-main/build/lib/diffusers/pipelines/chroma/pipeline_chroma.py +949 -0
  25. pythonProject/diffusers-main/build/lib/diffusers/pipelines/chroma/pipeline_chroma_img2img.py +1034 -0
  26. pythonProject/diffusers-main/build/lib/diffusers/pipelines/chroma/pipeline_output.py +21 -0
  27. pythonProject/diffusers-main/build/lib/diffusers/pipelines/cogvideo/__init__.py +54 -0
  28. pythonProject/diffusers-main/build/lib/diffusers/pipelines/cogvideo/pipeline_cogvideox.py +789 -0
  29. pythonProject/diffusers-main/build/lib/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +842 -0
  30. pythonProject/diffusers-main/build/lib/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +903 -0
  31. pythonProject/diffusers-main/build/lib/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +868 -0
  32. pythonProject/diffusers-main/build/lib/diffusers/pipelines/cogvideo/pipeline_output.py +20 -0
  33. pythonProject/diffusers-main/build/lib/diffusers/pipelines/cogview3/__init__.py +47 -0
  34. pythonProject/diffusers-main/build/lib/diffusers/pipelines/cogview3/pipeline_cogview3plus.py +682 -0
  35. pythonProject/diffusers-main/build/lib/diffusers/pipelines/cogview3/pipeline_output.py +21 -0
  36. pythonProject/diffusers-main/build/lib/diffusers/pipelines/cogview4/__init__.py +49 -0
  37. pythonProject/diffusers-main/build/lib/diffusers/pipelines/cogview4/pipeline_cogview4.py +685 -0
  38. pythonProject/diffusers-main/build/lib/diffusers/pipelines/cogview4/pipeline_cogview4_control.py +732 -0
  39. pythonProject/diffusers-main/build/lib/diffusers/pipelines/cogview4/pipeline_output.py +21 -0
  40. pythonProject/diffusers-main/build/lib/diffusers/pipelines/consisid/__init__.py +49 -0
  41. pythonProject/diffusers-main/build/lib/diffusers/pipelines/consisid/consisid_utils.py +357 -0
  42. pythonProject/diffusers-main/build/lib/diffusers/pipelines/consisid/pipeline_consisid.py +974 -0
  43. pythonProject/diffusers-main/build/lib/diffusers/pipelines/consisid/pipeline_output.py +20 -0
  44. pythonProject/diffusers-main/build/lib/diffusers/pipelines/consistency_models/__init__.py +24 -0
  45. pythonProject/diffusers-main/build/lib/diffusers/pipelines/consistency_models/pipeline_consistency_models.py +286 -0
  46. pythonProject/diffusers-main/build/lib/diffusers/pipelines/controlnet/__init__.py +86 -0
  47. pythonProject/diffusers-main/build/lib/diffusers/pipelines/controlnet/multicontrolnet.py +12 -0
  48. pythonProject/diffusers-main/build/lib/diffusers/pipelines/controlnet/pipeline_controlnet.py +1366 -0
  49. pythonProject/diffusers-main/build/lib/diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +427 -0
  50. pythonProject/diffusers-main/build/lib/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +1338 -0
pythonProject/diffusers-main/build/lib/diffusers/pipelines/animatediff/__init__.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Lazy-import shim for the `diffusers.pipelines.animatediff` subpackage.
# Pipeline classes are only materialized on first attribute access (via
# `_LazyModule`) unless type checking or DIFFUSERS_SLOW_IMPORT forces
# eager imports.
from typing import TYPE_CHECKING

from ...utils import (
    DIFFUSERS_SLOW_IMPORT,
    OptionalDependencyNotAvailable,
    _LazyModule,
    get_objects_from_module,
    is_torch_available,
    is_transformers_available,
)


_dummy_objects = {}
_import_structure = {"pipeline_output": ["AnimateDiffPipelineOutput"]}

try:
    if not (is_transformers_available() and is_torch_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    # torch/transformers missing: expose dummy placeholder objects that raise a
    # helpful error message when instantiated, instead of failing at import time.
    from ...utils import dummy_torch_and_transformers_objects

    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
    _import_structure["pipeline_animatediff"] = ["AnimateDiffPipeline"]
    _import_structure["pipeline_animatediff_controlnet"] = ["AnimateDiffControlNetPipeline"]
    _import_structure["pipeline_animatediff_sdxl"] = ["AnimateDiffSDXLPipeline"]
    _import_structure["pipeline_animatediff_sparsectrl"] = ["AnimateDiffSparseControlNetPipeline"]
    _import_structure["pipeline_animatediff_video2video"] = ["AnimateDiffVideoToVideoPipeline"]
    _import_structure["pipeline_animatediff_video2video_controlnet"] = ["AnimateDiffVideoToVideoControlNetPipeline"]

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    # Eager path: import everything directly so static type checkers see real symbols.
    try:
        if not (is_transformers_available() and is_torch_available()):
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        from ...utils.dummy_torch_and_transformers_objects import *

    else:
        from .pipeline_animatediff import AnimateDiffPipeline
        from .pipeline_animatediff_controlnet import AnimateDiffControlNetPipeline
        from .pipeline_animatediff_sdxl import AnimateDiffSDXLPipeline
        from .pipeline_animatediff_sparsectrl import AnimateDiffSparseControlNetPipeline
        from .pipeline_animatediff_video2video import AnimateDiffVideoToVideoPipeline
        from .pipeline_animatediff_video2video_controlnet import AnimateDiffVideoToVideoControlNetPipeline
        from .pipeline_output import AnimateDiffPipelineOutput

else:
    # Lazy path: replace this module object with a _LazyModule that resolves the
    # names in `_import_structure` on demand.
    import sys

    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )
    # Dummy placeholders (if any) must still be reachable as module attributes.
    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)
pythonProject/diffusers-main/build/lib/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py ADDED
@@ -0,0 +1,1118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn.functional as F
20
+ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
21
+
22
+ from ...image_processor import PipelineImageInput
23
+ from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
24
+ from ...models import (
25
+ AutoencoderKL,
26
+ ControlNetModel,
27
+ ImageProjection,
28
+ MultiControlNetModel,
29
+ UNet2DConditionModel,
30
+ UNetMotionModel,
31
+ )
32
+ from ...models.lora import adjust_lora_scale_text_encoder
33
+ from ...models.unets.unet_motion_model import MotionAdapter
34
+ from ...schedulers import KarrasDiffusionSchedulers
35
+ from ...utils import USE_PEFT_BACKEND, is_torch_xla_available, logging, scale_lora_layers, unscale_lora_layers
36
+ from ...utils.torch_utils import is_compiled_module, randn_tensor
37
+ from ...video_processor import VideoProcessor
38
+ from ..free_init_utils import FreeInitMixin
39
+ from ..free_noise_utils import AnimateDiffFreeNoiseMixin
40
+ from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
41
+ from .pipeline_output import AnimateDiffPipelineOutput
42
+
43
+
44
+ if is_torch_xla_available():
45
+ import torch_xla.core.xla_model as xm
46
+
47
+ XLA_AVAILABLE = True
48
+ else:
49
+ XLA_AVAILABLE = False
50
+
51
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
52
+
53
+
54
# Usage example rendered into the pipeline class docstring by diffusers' doc
# tooling. (Typo fixed: "need a preprocess videos" -> "need to preprocess videos".)
EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> import torch
        >>> from diffusers import (
        ...     AnimateDiffControlNetPipeline,
        ...     AutoencoderKL,
        ...     ControlNetModel,
        ...     MotionAdapter,
        ...     LCMScheduler,
        ... )
        >>> from diffusers.utils import export_to_gif, load_video

        >>> # Additionally, you will need to preprocess videos before they can be used with the ControlNet
        >>> # HF maintains just the right package for it: `pip install controlnet_aux`
        >>> from controlnet_aux.processor import ZoeDetector

        >>> # Download controlnets from https://huggingface.co/lllyasviel/ControlNet-v1-1 to use .from_single_file
        >>> # Download Diffusers-format controlnets, such as https://huggingface.co/lllyasviel/sd-controlnet-depth, to use .from_pretrained()
        >>> controlnet = ControlNetModel.from_single_file("control_v11f1p_sd15_depth.pth", torch_dtype=torch.float16)

        >>> # We use AnimateLCM for this example but one can use the original motion adapters as well (for example, https://huggingface.co/guoyww/animatediff-motion-adapter-v1-5-3)
        >>> motion_adapter = MotionAdapter.from_pretrained("wangfuyun/AnimateLCM")

        >>> vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16)
        >>> pipe: AnimateDiffControlNetPipeline = AnimateDiffControlNetPipeline.from_pretrained(
        ...     "SG161222/Realistic_Vision_V5.1_noVAE",
        ...     motion_adapter=motion_adapter,
        ...     controlnet=controlnet,
        ...     vae=vae,
        ... ).to(device="cuda", dtype=torch.float16)
        >>> pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config, beta_schedule="linear")
        >>> pipe.load_lora_weights(
        ...     "wangfuyun/AnimateLCM", weight_name="AnimateLCM_sd15_t2v_lora.safetensors", adapter_name="lcm-lora"
        ... )
        >>> pipe.set_adapters(["lcm-lora"], [0.8])

        >>> depth_detector = ZoeDetector.from_pretrained("lllyasviel/Annotators").to("cuda")
        >>> video = load_video(
        ...     "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-vid2vid-input-1.gif"
        ... )
        >>> conditioning_frames = []

        >>> with pipe.progress_bar(total=len(video)) as progress_bar:
        ...     for frame in video:
        ...         conditioning_frames.append(depth_detector(frame))
        ...         progress_bar.update()

        >>> prompt = "a panda, playing a guitar, sitting in a pink boat, in the ocean, mountains in background, realistic, high quality"
        >>> negative_prompt = "bad quality, worst quality"

        >>> video = pipe(
        ...     prompt=prompt,
        ...     negative_prompt=negative_prompt,
        ...     num_frames=len(video),
        ...     num_inference_steps=10,
        ...     guidance_scale=2.0,
        ...     conditioning_frames=conditioning_frames,
        ...     generator=torch.Generator().manual_seed(42),
        ... ).frames[0]

        >>> export_to_gif(video, "animatediff_controlnet.gif", fps=8)
        ```
"""
118
+
119
+
120
+ class AnimateDiffControlNetPipeline(
121
+ DiffusionPipeline,
122
+ StableDiffusionMixin,
123
+ TextualInversionLoaderMixin,
124
+ IPAdapterMixin,
125
+ StableDiffusionLoraLoaderMixin,
126
+ FreeInitMixin,
127
+ AnimateDiffFreeNoiseMixin,
128
+ FromSingleFileMixin,
129
+ ):
130
+ r"""
131
+ Pipeline for text-to-video generation with ControlNet guidance.
132
+
133
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
134
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
135
+
136
+ The pipeline also inherits the following loading methods:
137
+ - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
138
+ - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
139
+ - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
140
+ - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
141
+
142
+ Args:
143
+ vae ([`AutoencoderKL`]):
144
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
145
+ text_encoder ([`CLIPTextModel`]):
146
+ Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
147
+ tokenizer (`CLIPTokenizer`):
148
+ A [`~transformers.CLIPTokenizer`] to tokenize text.
149
+ unet ([`UNet2DConditionModel`]):
150
+ A [`UNet2DConditionModel`] used to create a UNetMotionModel to denoise the encoded video latents.
151
+ motion_adapter ([`MotionAdapter`]):
152
+ A [`MotionAdapter`] to be used in combination with `unet` to denoise the encoded video latents.
153
+ scheduler ([`SchedulerMixin`]):
154
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
155
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
156
+ """
157
+
158
+ model_cpu_offload_seq = "text_encoder->unet->vae"
159
+ _optional_components = ["feature_extractor", "image_encoder"]
160
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
161
+
162
+ def __init__(
163
+ self,
164
+ vae: AutoencoderKL,
165
+ text_encoder: CLIPTextModel,
166
+ tokenizer: CLIPTokenizer,
167
+ unet: Union[UNet2DConditionModel, UNetMotionModel],
168
+ motion_adapter: MotionAdapter,
169
+ controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
170
+ scheduler: KarrasDiffusionSchedulers,
171
+ feature_extractor: Optional[CLIPImageProcessor] = None,
172
+ image_encoder: Optional[CLIPVisionModelWithProjection] = None,
173
+ ):
174
+ super().__init__()
175
+ if isinstance(unet, UNet2DConditionModel):
176
+ unet = UNetMotionModel.from_unet2d(unet, motion_adapter)
177
+
178
+ if isinstance(controlnet, (list, tuple)):
179
+ controlnet = MultiControlNetModel(controlnet)
180
+
181
+ self.register_modules(
182
+ vae=vae,
183
+ text_encoder=text_encoder,
184
+ tokenizer=tokenizer,
185
+ unet=unet,
186
+ motion_adapter=motion_adapter,
187
+ controlnet=controlnet,
188
+ scheduler=scheduler,
189
+ feature_extractor=feature_extractor,
190
+ image_encoder=image_encoder,
191
+ )
192
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
193
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor)
194
+ self.control_video_processor = VideoProcessor(
195
+ vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
196
+ )
197
+
198
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt
199
+ def encode_prompt(
200
+ self,
201
+ prompt,
202
+ device,
203
+ num_images_per_prompt,
204
+ do_classifier_free_guidance,
205
+ negative_prompt=None,
206
+ prompt_embeds: Optional[torch.Tensor] = None,
207
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
208
+ lora_scale: Optional[float] = None,
209
+ clip_skip: Optional[int] = None,
210
+ ):
211
+ r"""
212
+ Encodes the prompt into text encoder hidden states.
213
+
214
+ Args:
215
+ prompt (`str` or `List[str]`, *optional*):
216
+ prompt to be encoded
217
+ device: (`torch.device`):
218
+ torch device
219
+ num_images_per_prompt (`int`):
220
+ number of images that should be generated per prompt
221
+ do_classifier_free_guidance (`bool`):
222
+ whether to use classifier free guidance or not
223
+ negative_prompt (`str` or `List[str]`, *optional*):
224
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
225
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
226
+ less than `1`).
227
+ prompt_embeds (`torch.Tensor`, *optional*):
228
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
229
+ provided, text embeddings will be generated from `prompt` input argument.
230
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
231
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
232
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
233
+ argument.
234
+ lora_scale (`float`, *optional*):
235
+ A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
236
+ clip_skip (`int`, *optional*):
237
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
238
+ the output of the pre-final layer will be used for computing the prompt embeddings.
239
+ """
240
+ # set lora scale so that monkey patched LoRA
241
+ # function of text encoder can correctly access it
242
+ if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):
243
+ self._lora_scale = lora_scale
244
+
245
+ # dynamically adjust the LoRA scale
246
+ if not USE_PEFT_BACKEND:
247
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
248
+ else:
249
+ scale_lora_layers(self.text_encoder, lora_scale)
250
+
251
+ if prompt is not None and isinstance(prompt, str):
252
+ batch_size = 1
253
+ elif prompt is not None and isinstance(prompt, list):
254
+ batch_size = len(prompt)
255
+ else:
256
+ batch_size = prompt_embeds.shape[0]
257
+
258
+ if prompt_embeds is None:
259
+ # textual inversion: process multi-vector tokens if necessary
260
+ if isinstance(self, TextualInversionLoaderMixin):
261
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
262
+
263
+ text_inputs = self.tokenizer(
264
+ prompt,
265
+ padding="max_length",
266
+ max_length=self.tokenizer.model_max_length,
267
+ truncation=True,
268
+ return_tensors="pt",
269
+ )
270
+ text_input_ids = text_inputs.input_ids
271
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
272
+
273
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
274
+ text_input_ids, untruncated_ids
275
+ ):
276
+ removed_text = self.tokenizer.batch_decode(
277
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
278
+ )
279
+ logger.warning(
280
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
281
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
282
+ )
283
+
284
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
285
+ attention_mask = text_inputs.attention_mask.to(device)
286
+ else:
287
+ attention_mask = None
288
+
289
+ if clip_skip is None:
290
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
291
+ prompt_embeds = prompt_embeds[0]
292
+ else:
293
+ prompt_embeds = self.text_encoder(
294
+ text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
295
+ )
296
+ # Access the `hidden_states` first, that contains a tuple of
297
+ # all the hidden states from the encoder layers. Then index into
298
+ # the tuple to access the hidden states from the desired layer.
299
+ prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
300
+ # We also need to apply the final LayerNorm here to not mess with the
301
+ # representations. The `last_hidden_states` that we typically use for
302
+ # obtaining the final prompt representations passes through the LayerNorm
303
+ # layer.
304
+ prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
305
+
306
+ if self.text_encoder is not None:
307
+ prompt_embeds_dtype = self.text_encoder.dtype
308
+ elif self.unet is not None:
309
+ prompt_embeds_dtype = self.unet.dtype
310
+ else:
311
+ prompt_embeds_dtype = prompt_embeds.dtype
312
+
313
+ prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
314
+
315
+ bs_embed, seq_len, _ = prompt_embeds.shape
316
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
317
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
318
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
319
+
320
+ # get unconditional embeddings for classifier free guidance
321
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
322
+ uncond_tokens: List[str]
323
+ if negative_prompt is None:
324
+ uncond_tokens = [""] * batch_size
325
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
326
+ raise TypeError(
327
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
328
+ f" {type(prompt)}."
329
+ )
330
+ elif isinstance(negative_prompt, str):
331
+ uncond_tokens = [negative_prompt]
332
+ elif batch_size != len(negative_prompt):
333
+ raise ValueError(
334
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
335
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
336
+ " the batch size of `prompt`."
337
+ )
338
+ else:
339
+ uncond_tokens = negative_prompt
340
+
341
+ # textual inversion: process multi-vector tokens if necessary
342
+ if isinstance(self, TextualInversionLoaderMixin):
343
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
344
+
345
+ max_length = prompt_embeds.shape[1]
346
+ uncond_input = self.tokenizer(
347
+ uncond_tokens,
348
+ padding="max_length",
349
+ max_length=max_length,
350
+ truncation=True,
351
+ return_tensors="pt",
352
+ )
353
+
354
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
355
+ attention_mask = uncond_input.attention_mask.to(device)
356
+ else:
357
+ attention_mask = None
358
+
359
+ negative_prompt_embeds = self.text_encoder(
360
+ uncond_input.input_ids.to(device),
361
+ attention_mask=attention_mask,
362
+ )
363
+ negative_prompt_embeds = negative_prompt_embeds[0]
364
+
365
+ if do_classifier_free_guidance:
366
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
367
+ seq_len = negative_prompt_embeds.shape[1]
368
+
369
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
370
+
371
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
372
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
373
+
374
+ if self.text_encoder is not None:
375
+ if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
376
+ # Retrieve the original scale by scaling back the LoRA layers
377
+ unscale_lora_layers(self.text_encoder, lora_scale)
378
+
379
+ return prompt_embeds, negative_prompt_embeds
380
+
381
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
382
+ def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
383
+ dtype = next(self.image_encoder.parameters()).dtype
384
+
385
+ if not isinstance(image, torch.Tensor):
386
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
387
+
388
+ image = image.to(device=device, dtype=dtype)
389
+ if output_hidden_states:
390
+ image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
391
+ image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
392
+ uncond_image_enc_hidden_states = self.image_encoder(
393
+ torch.zeros_like(image), output_hidden_states=True
394
+ ).hidden_states[-2]
395
+ uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
396
+ num_images_per_prompt, dim=0
397
+ )
398
+ return image_enc_hidden_states, uncond_image_enc_hidden_states
399
+ else:
400
+ image_embeds = self.image_encoder(image).image_embeds
401
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
402
+ uncond_image_embeds = torch.zeros_like(image_embeds)
403
+
404
+ return image_embeds, uncond_image_embeds
405
+
406
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
407
+ def prepare_ip_adapter_image_embeds(
408
+ self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
409
+ ):
410
+ image_embeds = []
411
+ if do_classifier_free_guidance:
412
+ negative_image_embeds = []
413
+ if ip_adapter_image_embeds is None:
414
+ if not isinstance(ip_adapter_image, list):
415
+ ip_adapter_image = [ip_adapter_image]
416
+
417
+ if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
418
+ raise ValueError(
419
+ f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
420
+ )
421
+
422
+ for single_ip_adapter_image, image_proj_layer in zip(
423
+ ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
424
+ ):
425
+ output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
426
+ single_image_embeds, single_negative_image_embeds = self.encode_image(
427
+ single_ip_adapter_image, device, 1, output_hidden_state
428
+ )
429
+
430
+ image_embeds.append(single_image_embeds[None, :])
431
+ if do_classifier_free_guidance:
432
+ negative_image_embeds.append(single_negative_image_embeds[None, :])
433
+ else:
434
+ for single_image_embeds in ip_adapter_image_embeds:
435
+ if do_classifier_free_guidance:
436
+ single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
437
+ negative_image_embeds.append(single_negative_image_embeds)
438
+ image_embeds.append(single_image_embeds)
439
+
440
+ ip_adapter_image_embeds = []
441
+ for i, single_image_embeds in enumerate(image_embeds):
442
+ single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
443
+ if do_classifier_free_guidance:
444
+ single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
445
+ single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)
446
+
447
+ single_image_embeds = single_image_embeds.to(device=device)
448
+ ip_adapter_image_embeds.append(single_image_embeds)
449
+
450
+ return ip_adapter_image_embeds
451
+
452
# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffPipeline.decode_latents
def decode_latents(self, latents, decode_chunk_size: int = 16):
    """Decode video latents to pixel space with the VAE, a chunk of frames at a time.

    Args:
        latents: Tensor of shape `(batch, channels, num_frames, height, width)`.
        decode_chunk_size: How many frames to push through the VAE per call, to bound memory.

    Returns:
        A float32 tensor of shape `(batch, out_channels, num_frames, out_height, out_width)`.
    """
    latents = 1 / self.vae.config.scaling_factor * latents

    batch_size, channels, num_frames, height, width = latents.shape
    # Fold the frame axis into the batch axis so the image VAE decodes frames independently.
    latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)

    decoded_chunks = []
    for start in range(0, latents.shape[0], decode_chunk_size):
        chunk = self.vae.decode(latents[start : start + decode_chunk_size]).sample
        decoded_chunks.append(chunk)

    video = torch.cat(decoded_chunks)
    # Restore the (batch, channels, frames, H, W) layout.
    video = video[None, :].reshape((batch_size, num_frames, -1) + video.shape[2:]).permute(0, 2, 1, 3, 4)
    # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
    return video.float()
470
+
471
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
    """Collect the optional kwargs (`eta`, `generator`) that this scheduler's `step` accepts.

    Not all schedulers share the same `step` signature, so each kwarg is forwarded only when
    present. eta (η) is only used with the DDIMScheduler — it corresponds to η in the DDIM
    paper (https://huggingface.co/papers/2010.02502) and should lie in [0, 1]; other
    schedulers ignore it.
    """
    step_parameters = set(inspect.signature(self.scheduler.step).parameters.keys())

    extra_step_kwargs = {}
    if "eta" in step_parameters:
        extra_step_kwargs["eta"] = eta
    if "generator" in step_parameters:
        extra_step_kwargs["generator"] = generator
    return extra_step_kwargs
488
+
489
def check_inputs(
    self,
    prompt,
    height,
    width,
    num_frames,
    negative_prompt=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
    callback_on_step_end_tensor_inputs=None,
    video=None,
    controlnet_conditioning_scale=1.0,
    control_guidance_start=0.0,
    control_guidance_end=1.0,
):
    """Validate `__call__` arguments before any expensive work is done.

    Raises:
        ValueError / TypeError: when dimensions, prompt/embedding combinations,
            conditioning-frame shapes, ControlNet scales, or guidance ranges are
            inconsistent with the loaded (single or multi) ControlNet.
    """
    # Latent-space models require spatial dims divisible by the VAE downscale factor.
    if height % 8 != 0 or width % 8 != 0:
        raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

    if callback_on_step_end_tensor_inputs is not None and not all(
        k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
    ):
        raise ValueError(
            f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
        )

    # Exactly one of `prompt` / `prompt_embeds` must be provided.
    if prompt is not None and prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    elif prompt is None and prompt_embeds is None:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    elif prompt is not None and not isinstance(prompt, (str, list, dict)):
        # dict prompts are accepted for FreeNoise-style per-frame prompting.
        raise ValueError(f"`prompt` has to be of type `str`, `list` or `dict` but is {type(prompt)}")

    if negative_prompt is not None and negative_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if prompt_embeds is not None and negative_prompt_embeds is not None:
        if prompt_embeds.shape != negative_prompt_embeds.shape:
            raise ValueError(
                "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                f" {negative_prompt_embeds.shape}."
            )

    # `prompt` needs more sophisticated handling when there are multiple
    # conditionings.
    if isinstance(self.controlnet, MultiControlNetModel):
        if isinstance(prompt, list):
            logger.warning(
                f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}"
                " prompts. The conditionings will be fixed across the prompts."
            )

    # Check `image`
    is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
        self.controlnet, torch._dynamo.eval_frame.OptimizedModule
    )
    # NOTE: in these conditions `and` binds tighter than `or`, so the `is_compiled`
    # branch only matches when `_orig_mod` has the expected type.
    if (
        isinstance(self.controlnet, ControlNetModel)
        or is_compiled
        and isinstance(self.controlnet._orig_mod, ControlNetModel)
    ):
        if not isinstance(video, list):
            raise TypeError(f"For single controlnet, `image` must be of type `list` but got {type(video)}")
        if len(video) != num_frames:
            raise ValueError(f"Excepted image to have length {num_frames} but got {len(video)=}")
    elif (
        isinstance(self.controlnet, MultiControlNetModel)
        or is_compiled
        and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
    ):
        if not isinstance(video, list) or not isinstance(video[0], list):
            raise TypeError(f"For multiple controlnets: `image` must be type list of lists but got {type(video)=}")
        if len(video[0]) != num_frames:
            raise ValueError(f"Expected length of image sublist as {num_frames} but got {len(video[0])=}")
        if any(len(img) != len(video[0]) for img in video):
            raise ValueError("All conditioning frame batches for multicontrolnet must be same size")
    else:
        # Unreachable for the supported controlnet types above.
        assert False

    # Check `controlnet_conditioning_scale`
    if (
        isinstance(self.controlnet, ControlNetModel)
        or is_compiled
        and isinstance(self.controlnet._orig_mod, ControlNetModel)
    ):
        if not isinstance(controlnet_conditioning_scale, float):
            raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
    elif (
        isinstance(self.controlnet, MultiControlNetModel)
        or is_compiled
        and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
    ):
        if isinstance(controlnet_conditioning_scale, list):
            if any(isinstance(i, list) for i in controlnet_conditioning_scale):
                raise ValueError("A single batch of multiple conditionings are supported at the moment.")
        elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(
            self.controlnet.nets
        ):
            raise ValueError(
                "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
                " the same length as the number of controlnets"
            )
    else:
        assert False

    # Normalize scalar guidance bounds to one-element lists before range checks.
    if not isinstance(control_guidance_start, (tuple, list)):
        control_guidance_start = [control_guidance_start]

    if not isinstance(control_guidance_end, (tuple, list)):
        control_guidance_end = [control_guidance_end]

    if len(control_guidance_start) != len(control_guidance_end):
        raise ValueError(
            f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
        )

    if isinstance(self.controlnet, MultiControlNetModel):
        if len(control_guidance_start) != len(self.controlnet.nets):
            raise ValueError(
                f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}."
            )

    # Each (start, end) pair must describe a non-empty sub-interval of [0, 1].
    for start, end in zip(control_guidance_start, control_guidance_end):
        if start >= end:
            raise ValueError(
                f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}."
            )
        if start < 0.0:
            raise ValueError(f"control guidance start: {start} can't be smaller than 0.")
        if end > 1.0:
            raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")
628
+
629
# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffPipeline.prepare_latents
def prepare_latents(
    self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
):
    """Create (or adopt) initial latents of shape (B, C, F, H/vae_sf, W/vae_sf) scaled for the scheduler.

    When `latents` is given it is moved to `device` and used as-is; otherwise Gaussian noise
    is sampled. Either way the result is multiplied by the scheduler's `init_noise_sigma`.
    """
    # If FreeNoise is enabled, generate latents as described in Equation (7) of
    # [FreeNoise](https://huggingface.co/papers/2310.15169)
    if self.free_noise_enabled:
        latents = self._prepare_latents_free_noise(
            batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents
        )

    if isinstance(generator, list) and len(generator) != batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    if latents is not None:
        latents = latents.to(device)
    else:
        latent_height = height // self.vae_scale_factor
        latent_width = width // self.vae_scale_factor
        shape = (batch_size, num_channels_latents, num_frames, latent_height, latent_width)
        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)

    # scale the initial noise by the standard deviation required by the scheduler
    return latents * self.scheduler.init_noise_sigma
661
+
662
def prepare_video(
    self,
    video,
    width,
    height,
    batch_size,
    num_videos_per_prompt,
    device,
    dtype,
    do_classifier_free_guidance=False,
    guess_mode=False,
):
    """Preprocess ControlNet conditioning frames into a flat (B*F, C, H, W) tensor on `device`.

    When classifier-free guidance is active (and guess mode is not), the batch is duplicated
    so the unconditional half of the CFG batch sees the same conditioning.
    """
    frames = self.control_video_processor.preprocess_video(video, height=height, width=width).to(
        dtype=torch.float32
    )
    # (B, C, F, H, W) -> (B*F, C, H, W): each frame becomes one conditioning image.
    frames = frames.permute(0, 2, 1, 3, 4).flatten(0, 1)

    if frames.shape[0] == 1:
        repeat_by = batch_size
    else:
        # image batch size is the same as prompt batch size
        repeat_by = num_videos_per_prompt

    frames = frames.repeat_interleave(repeat_by, dim=0).to(device=device, dtype=dtype)

    if do_classifier_free_guidance and not guess_mode:
        frames = torch.cat([frames] * 2)

    return frames
693
+
694
@property
def guidance_scale(self):
    # Classifier-free guidance weight captured from `__call__`.
    return self._guidance_scale
697
+
698
@property
def clip_skip(self):
    # Number of final CLIP layers skipped for prompt embedding, captured from `__call__`.
    return self._clip_skip
701
+
702
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
    # CFG is active whenever the stored guidance scale exceeds 1.
    return self._guidance_scale > 1
708
+
709
@property
def cross_attention_kwargs(self):
    # Extra kwargs forwarded to the attention processors, captured from `__call__`.
    return self._cross_attention_kwargs
712
+
713
@property
def num_timesteps(self):
    # Number of denoising timesteps of the current/last run, set inside `__call__`.
    return self._num_timesteps
716
+
717
@property
def interrupt(self):
    # Cooperative cancellation flag checked each denoising step in `__call__`.
    return self._interrupt
720
+
721
@torch.no_grad()
def __call__(
    self,
    prompt: Union[str, List[str]] = None,
    num_frames: Optional[int] = 16,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 50,
    guidance_scale: float = 7.5,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_videos_per_prompt: Optional[int] = 1,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.Tensor] = None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    ip_adapter_image: Optional[PipelineImageInput] = None,
    ip_adapter_image_embeds: Optional[PipelineImageInput] = None,
    conditioning_frames: Optional[List[PipelineImageInput]] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
    guess_mode: bool = False,
    control_guidance_start: Union[float, List[float]] = 0.0,
    control_guidance_end: Union[float, List[float]] = 1.0,
    clip_skip: Optional[int] = None,
    callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
    callback_on_step_end_tensor_inputs: List[str] = ["latents"],
    decode_chunk_size: int = 16,
):
    r"""
    The call function to the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
        height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
            The height in pixels of the generated video.
        width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
            The width in pixels of the generated video.
        num_frames (`int`, *optional*, defaults to 16):
            The number of video frames that are generated. Defaults to 16 frames which at 8 frames per seconds
            amounts to 2 seconds of video.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality videos at the
            expense of slower inference.
        guidance_scale (`float`, *optional*, defaults to 7.5):
            A higher guidance scale value encourages the model to generate images closely linked to the text
            `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide what to not include in image generation. If not defined, you need to
            pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only
            applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
            generation deterministic.
        latents (`torch.Tensor`, *optional*):
            Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video
            generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
            tensor is generated by sampling using the supplied random `generator`. Latents should be of shape
            `(batch_size, num_channel, num_frames, height, width)`.
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
            provided, text embeddings are generated from the `prompt` input argument.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
            not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
        ip_adapter_image (`PipelineImageInput`, *optional*):
            Optional image input to work with IP Adapters.
        ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
            Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
            IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
            contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
            provided, embeddings are computed from the `ip_adapter_image` input argument.
        conditioning_frames (`List[PipelineImageInput]`, *optional*):
            The ControlNet input condition to provide guidance to the `unet` for generation. If multiple
            ControlNets are specified, images must be passed as a list such that each element of the list can be
            correctly batched for input to a single ControlNet.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated video. Choose between `torch.Tensor`, `PIL.Image` or `np.array`.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] instead
            of a plain tuple.
        cross_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
            [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
        controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
            The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
            to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set
            the corresponding scale as a list.
        guess_mode (`bool`, *optional*, defaults to `False`):
            The ControlNet encoder tries to recognize the content of the input image even if you remove all
            prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended.
        control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
            The percentage of total steps at which the ControlNet starts applying.
        control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
            The percentage of total steps at which the ControlNet stops applying.
        clip_skip (`int`, *optional*):
            Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
            the output of the pre-final layer will be used for computing the prompt embeddings.
        callback_on_step_end (`Callable`, *optional*):
            A function that calls at the end of each denoising steps during the inference. The function is called
            with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
            callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
            `callback_on_step_end_tensor_inputs`.
        callback_on_step_end_tensor_inputs (`List`, *optional*):
            The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
            will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
            `._callback_tensor_inputs` attribute of your pipeline class.

    Examples:

    Returns:
        [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] or `tuple`:
            If `return_dict` is `True`, [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] is
            returned, otherwise a `tuple` is returned where the first element is a list with the generated frames.
    """
    # Unwrap a torch.compile-wrapped ControlNet so config/attribute access works uniformly.
    controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet

    # align format for control guidance
    if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
        control_guidance_start = len(control_guidance_end) * [control_guidance_start]
    elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
        control_guidance_end = len(control_guidance_start) * [control_guidance_end]
    elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
        mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
        control_guidance_start, control_guidance_end = (
            mult * [control_guidance_start],
            mult * [control_guidance_end],
        )

    # 0. Default height and width to unet
    height = height or self.unet.config.sample_size * self.vae_scale_factor
    width = width or self.unet.config.sample_size * self.vae_scale_factor

    # Only a single video per prompt is supported by this pipeline.
    num_videos_per_prompt = 1

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt=prompt,
        height=height,
        width=width,
        num_frames=num_frames,
        negative_prompt=negative_prompt,
        callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        video=conditioning_frames,
        controlnet_conditioning_scale=controlnet_conditioning_scale,
        control_guidance_start=control_guidance_start,
        control_guidance_end=control_guidance_end,
    )

    self._guidance_scale = guidance_scale
    self._clip_skip = clip_skip
    self._cross_attention_kwargs = cross_attention_kwargs
    self._interrupt = False

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, (str, dict)):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device

    # Broadcast a scalar conditioning scale to every ControlNet in a MultiControlNet.
    if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
        controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)

    global_pool_conditions = (
        controlnet.config.global_pool_conditions
        if isinstance(controlnet, ControlNetModel)
        else controlnet.nets[0].config.global_pool_conditions
    )
    guess_mode = guess_mode or global_pool_conditions

    # 3. Encode input prompt
    text_encoder_lora_scale = (
        cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
    )
    if self.free_noise_enabled:
        prompt_embeds, negative_prompt_embeds = self._encode_prompt_free_noise(
            prompt=prompt,
            num_frames=num_frames,
            device=device,
            num_videos_per_prompt=num_videos_per_prompt,
            do_classifier_free_guidance=self.do_classifier_free_guidance,
            negative_prompt=negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            lora_scale=text_encoder_lora_scale,
            clip_skip=self.clip_skip,
        )
    else:
        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
            prompt,
            device,
            num_videos_per_prompt,
            self.do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            lora_scale=text_encoder_lora_scale,
            clip_skip=self.clip_skip,
        )

    # For classifier free guidance, we need to do two forward passes.
    # Here we concatenate the unconditional and text embeddings into a single batch
    # to avoid doing two forward passes
    if self.do_classifier_free_guidance:
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

    # The ControlNet consumes frames flattened into the batch dim, so the text
    # embeddings are repeated once per frame to stay aligned with that batch.
    prompt_embeds = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0)

    if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
        image_embeds = self.prepare_ip_adapter_image_embeds(
            ip_adapter_image,
            ip_adapter_image_embeds,
            device,
            batch_size * num_videos_per_prompt,
            self.do_classifier_free_guidance,
        )

    # Preprocess conditioning frames: one flat batch for a single ControlNet,
    # one per net for a MultiControlNet.
    if isinstance(controlnet, ControlNetModel):
        conditioning_frames = self.prepare_video(
            video=conditioning_frames,
            width=width,
            height=height,
            batch_size=batch_size * num_videos_per_prompt * num_frames,
            num_videos_per_prompt=num_videos_per_prompt,
            device=device,
            dtype=controlnet.dtype,
            do_classifier_free_guidance=self.do_classifier_free_guidance,
            guess_mode=guess_mode,
        )
    elif isinstance(controlnet, MultiControlNetModel):
        cond_prepared_videos = []
        for frame_ in conditioning_frames:
            prepared_video = self.prepare_video(
                video=frame_,
                width=width,
                height=height,
                batch_size=batch_size * num_videos_per_prompt * num_frames,
                num_videos_per_prompt=num_videos_per_prompt,
                device=device,
                dtype=controlnet.dtype,
                do_classifier_free_guidance=self.do_classifier_free_guidance,
                guess_mode=guess_mode,
            )
            cond_prepared_videos.append(prepared_video)
        conditioning_frames = cond_prepared_videos
    else:
        # Unreachable: `check_inputs` already rejected other controlnet types.
        assert False

    # 4. Prepare timesteps
    self.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = self.scheduler.timesteps

    # 5. Prepare latent variables
    num_channels_latents = self.unet.config.in_channels
    latents = self.prepare_latents(
        batch_size * num_videos_per_prompt,
        num_channels_latents,
        num_frames,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

    # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # 7. Add image embeds for IP-Adapter
    added_cond_kwargs = (
        {"image_embeds": image_embeds}
        if ip_adapter_image is not None or ip_adapter_image_embeds is not None
        else None
    )

    # 7.1 Create tensor stating which controlnets to keep
    # keep[i] is 1.0 while step i falls inside the [start, end) fraction of total steps.
    controlnet_keep = []
    for i in range(len(timesteps)):
        keeps = [
            1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
            for s, e in zip(control_guidance_start, control_guidance_end)
        ]
        controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)

    # FreeInit repeats the whole denoising loop with re-initialized noise; otherwise one pass.
    num_free_init_iters = self._free_init_num_iters if self.free_init_enabled else 1
    for free_init_iter in range(num_free_init_iters):
        if self.free_init_enabled:
            latents, timesteps = self._apply_free_init(
                latents, free_init_iter, num_inference_steps, device, latents.dtype, generator
            )

        self._num_timesteps = len(timesteps)
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order

        # 8. Denoising loop
        with self.progress_bar(total=self._num_timesteps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                if guess_mode and self.do_classifier_free_guidance:
                    # Infer ControlNet only for the conditional batch.
                    control_model_input = latents
                    control_model_input = self.scheduler.scale_model_input(control_model_input, t)
                    controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
                else:
                    control_model_input = latent_model_input
                    controlnet_prompt_embeds = prompt_embeds

                if isinstance(controlnet_keep[i], list):
                    cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
                else:
                    controlnet_cond_scale = controlnet_conditioning_scale
                    if isinstance(controlnet_cond_scale, list):
                        controlnet_cond_scale = controlnet_cond_scale[0]
                    cond_scale = controlnet_cond_scale * controlnet_keep[i]

                # Fold the frame axis into the batch axis: (B, C, F, H, W) -> (B*F, C, H, W)
                # so the 2D ControlNet can process every frame.
                control_model_input = torch.transpose(control_model_input, 1, 2)
                control_model_input = control_model_input.reshape(
                    (-1, control_model_input.shape[2], control_model_input.shape[3], control_model_input.shape[4])
                )

                down_block_res_samples, mid_block_res_sample = self.controlnet(
                    control_model_input,
                    t,
                    encoder_hidden_states=controlnet_prompt_embeds,
                    controlnet_cond=conditioning_frames,
                    conditioning_scale=cond_scale,
                    guess_mode=guess_mode,
                    return_dict=False,
                )

                # predict the noise residual
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    cross_attention_kwargs=self.cross_attention_kwargs,
                    added_cond_kwargs=added_cond_kwargs,
                    down_block_additional_residuals=down_block_res_samples,
                    mid_block_additional_residual=mid_block_res_sample,
                ).sample

                # perform guidance
                if self.do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    # The callback may replace tensors in-flight (e.g. for guidance tricks).
                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

                if XLA_AVAILABLE:
                    xm.mark_step()

    # 9. Post processing
    if output_type == "latent":
        video = latents
    else:
        video_tensor = self.decode_latents(latents, decode_chunk_size)
        video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type)

    # 10. Offload all models
    self.maybe_free_model_hooks()

    if not return_dict:
        return (video,)

    return AnimateDiffPipelineOutput(frames=video)
pythonProject/diffusers-main/build/lib/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py ADDED
@@ -0,0 +1,1309 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
17
+
18
+ import torch
19
+ from transformers import (
20
+ CLIPImageProcessor,
21
+ CLIPTextModel,
22
+ CLIPTextModelWithProjection,
23
+ CLIPTokenizer,
24
+ CLIPVisionModelWithProjection,
25
+ )
26
+
27
+ from ...image_processor import PipelineImageInput
28
+ from ...loaders import (
29
+ FromSingleFileMixin,
30
+ IPAdapterMixin,
31
+ StableDiffusionXLLoraLoaderMixin,
32
+ TextualInversionLoaderMixin,
33
+ )
34
+ from ...models import AutoencoderKL, ImageProjection, MotionAdapter, UNet2DConditionModel, UNetMotionModel
35
+ from ...models.attention_processor import (
36
+ AttnProcessor2_0,
37
+ FusedAttnProcessor2_0,
38
+ XFormersAttnProcessor,
39
+ )
40
+ from ...models.lora import adjust_lora_scale_text_encoder
41
+ from ...schedulers import (
42
+ DDIMScheduler,
43
+ DPMSolverMultistepScheduler,
44
+ EulerAncestralDiscreteScheduler,
45
+ EulerDiscreteScheduler,
46
+ LMSDiscreteScheduler,
47
+ PNDMScheduler,
48
+ )
49
+ from ...utils import (
50
+ USE_PEFT_BACKEND,
51
+ is_torch_xla_available,
52
+ logging,
53
+ replace_example_docstring,
54
+ scale_lora_layers,
55
+ unscale_lora_layers,
56
+ )
57
+ from ...utils.torch_utils import randn_tensor
58
+ from ...video_processor import VideoProcessor
59
+ from ..free_init_utils import FreeInitMixin
60
+ from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
61
+ from .pipeline_output import AnimateDiffPipelineOutput
62
+
63
+
64
+ if is_torch_xla_available():
65
+ import torch_xla.core.xla_model as xm
66
+
67
+ XLA_AVAILABLE = True
68
+ else:
69
+ XLA_AVAILABLE = False
70
+
71
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
72
+
73
+
74
# Usage example injected into the pipeline's __call__ docstring via `replace_example_docstring`.
EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> import torch
        >>> from diffusers.models import MotionAdapter
        >>> from diffusers import AnimateDiffSDXLPipeline, DDIMScheduler
        >>> from diffusers.utils import export_to_gif

        >>> adapter = MotionAdapter.from_pretrained(
        ...     "a-r-r-o-w/animatediff-motion-adapter-sdxl-beta", torch_dtype=torch.float16
        ... )

        >>> model_id = "stabilityai/stable-diffusion-xl-base-1.0"
        >>> scheduler = DDIMScheduler.from_pretrained(
        ...     model_id,
        ...     subfolder="scheduler",
        ...     clip_sample=False,
        ...     timestep_spacing="linspace",
        ...     beta_schedule="linear",
        ...     steps_offset=1,
        ... )
        >>> pipe = AnimateDiffSDXLPipeline.from_pretrained(
        ...     model_id,
        ...     motion_adapter=adapter,
        ...     scheduler=scheduler,
        ...     torch_dtype=torch.float16,
        ...     variant="fp16",
        ... ).to("cuda")

        >>> # enable memory savings
        >>> pipe.enable_vae_slicing()
        >>> pipe.enable_vae_tiling()

        >>> output = pipe(
        ...     prompt="a panda surfing in the ocean, realistic, high quality",
        ...     negative_prompt="low quality, worst quality",
        ...     num_inference_steps=20,
        ...     guidance_scale=8,
        ...     width=1024,
        ...     height=1024,
        ...     num_frames=16,
        ... )

        >>> frames = output.frames[0]
        >>> export_to_gif(frames, "animation.gif")
        ```
"""
121
+
122
+
123
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    r"""
    Rescale the classifier-free-guidance noise prediction to fix overexposure, following Section 3.4 of [Common
    Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891).

    Args:
        noise_cfg (`torch.Tensor`):
            The guided (CFG-combined) noise prediction.
        noise_pred_text (`torch.Tensor`):
            The text-conditioned noise prediction.
        guidance_rescale (`float`, *optional*, defaults to 0.0):
            Interpolation factor between the rescaled and the original guided prediction; 0.0 returns the input
            unchanged.

    Returns:
        `torch.Tensor`: The rescaled noise prediction.
    """
    # Per-sample standard deviation over every non-batch dimension.
    reduce_dims = list(range(1, noise_pred_text.ndim))
    std_text = noise_pred_text.std(dim=reduce_dims, keepdim=True)
    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
    # Match the guided prediction's std to the text prediction's std (fixes overexposure) ...
    rescaled = noise_cfg * (std_text / std_cfg)
    # ... then blend with the unrescaled guided prediction to avoid "plain looking" images.
    return guidance_rescale * rescaled + (1 - guidance_rescale) * noise_cfg
148
+
149
+
150
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    r"""
    Configure `scheduler` via `set_timesteps` and return the resulting timestep schedule.

    Exactly one schedule source is honored: an explicit `timesteps` list, an explicit `sigmas` list, or a plain
    `num_inference_steps` count. Custom lists are forwarded only when the scheduler's `set_timesteps` signature
    accepts them; any extra keyword arguments are passed through to `set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    if timesteps is None and sigmas is None:
        # Standard path: let the scheduler derive its own spacing for the requested step count.
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    # Custom-schedule path: verify the scheduler actually supports the requested override before calling it.
    accepted_params = set(inspect.signature(scheduler.set_timesteps).parameters)
    if timesteps is not None:
        if "timesteps" not in accepted_params:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    else:
        if "sigmas" not in accepted_params:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)

    # With a custom schedule the effective step count is the schedule length, not the requested count.
    return scheduler.timesteps, len(scheduler.timesteps)
208
+
209
+
210
+ class AnimateDiffSDXLPipeline(
211
+ DiffusionPipeline,
212
+ StableDiffusionMixin,
213
+ FromSingleFileMixin,
214
+ StableDiffusionXLLoraLoaderMixin,
215
+ TextualInversionLoaderMixin,
216
+ IPAdapterMixin,
217
+ FreeInitMixin,
218
+ ):
219
+ r"""
220
+ Pipeline for text-to-video generation using Stable Diffusion XL.
221
+
222
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
223
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
224
+
225
+ The pipeline also inherits the following loading methods:
226
+ - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
227
+ - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
228
+ - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
229
+ - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
230
+ - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
231
+
232
+ Args:
233
+ vae ([`AutoencoderKL`]):
234
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
235
+ text_encoder ([`CLIPTextModel`]):
236
+ Frozen text-encoder. Stable Diffusion XL uses the text portion of
237
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
238
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
239
+ text_encoder_2 ([` CLIPTextModelWithProjection`]):
240
+ Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of
241
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
242
+ specifically the
243
+ [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
244
+ variant.
245
+ tokenizer (`CLIPTokenizer`):
246
+ Tokenizer of class
247
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
248
+ tokenizer_2 (`CLIPTokenizer`):
249
+ Second Tokenizer of class
250
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
251
+ unet ([`UNet2DConditionModel`]):
252
+ Conditional U-Net architecture to denoise the encoded image latents.
253
+ scheduler ([`SchedulerMixin`]):
254
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
255
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
256
+ force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`):
257
+ Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of
258
+ `stabilityai/stable-diffusion-xl-base-1-0`.
259
+ """
260
+
261
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae"
262
+ _optional_components = [
263
+ "tokenizer",
264
+ "tokenizer_2",
265
+ "text_encoder",
266
+ "text_encoder_2",
267
+ "image_encoder",
268
+ "feature_extractor",
269
+ ]
270
+ _callback_tensor_inputs = [
271
+ "latents",
272
+ "prompt_embeds",
273
+ "negative_prompt_embeds",
274
+ "add_text_embeds",
275
+ "add_time_ids",
276
+ "negative_pooled_prompt_embeds",
277
+ "negative_add_time_ids",
278
+ ]
279
+
280
    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        text_encoder_2: CLIPTextModelWithProjection,
        tokenizer: CLIPTokenizer,
        tokenizer_2: CLIPTokenizer,
        unet: Union[UNet2DConditionModel, UNetMotionModel],
        motion_adapter: MotionAdapter,
        scheduler: Union[
            DDIMScheduler,
            PNDMScheduler,
            LMSDiscreteScheduler,
            EulerDiscreteScheduler,
            EulerAncestralDiscreteScheduler,
            DPMSolverMultistepScheduler,
        ],
        image_encoder: CLIPVisionModelWithProjection = None,
        feature_extractor: CLIPImageProcessor = None,
        force_zeros_for_empty_prompt: bool = True,
    ):
        """Register all sub-models and derive the VAE scale factor, video processor, and default sample size."""
        super().__init__()

        # A plain SDXL UNet is upgraded to a motion-aware UNet by injecting the motion adapter's layers.
        if isinstance(unet, UNet2DConditionModel):
            unet = UNetMotionModel.from_unet2d(unet, motion_adapter)

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            text_encoder_2=text_encoder_2,
            tokenizer=tokenizer,
            tokenizer_2=tokenizer_2,
            unet=unet,
            motion_adapter=motion_adapter,
            scheduler=scheduler,
            image_encoder=image_encoder,
            feature_extractor=feature_extractor,
        )
        self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
        # 8 is the standard SD/SDXL spatial downsampling factor, used as a fallback when no VAE is registered.
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
        self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor)

        # 128 latent pixels (1024 image pixels) is the SDXL default when the UNet config does not specify one.
        self.default_sample_size = (
            self.unet.config.sample_size
            if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size")
            else 128
        )
327
+
328
    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt with num_images_per_prompt->num_videos_per_prompt
    def encode_prompt(
        self,
        prompt: str,
        prompt_2: Optional[str] = None,
        device: Optional[torch.device] = None,
        num_videos_per_prompt: int = 1,
        do_classifier_free_guidance: bool = True,
        negative_prompt: Optional[str] = None,
        negative_prompt_2: Optional[str] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        pooled_prompt_embeds: Optional[torch.Tensor] = None,
        negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
        lora_scale: Optional[float] = None,
        clip_skip: Optional[int] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            prompt_2 (`str` or `List[str]`, *optional*):
                The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
                used in both text-encoders
            device: (`torch.device`):
                torch device
            num_videos_per_prompt (`int`):
                number of images that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            negative_prompt_2 (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
                `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            pooled_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
                If not provided, pooled text embeddings will be generated from `prompt` input argument.
            negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
                input argument.
            lora_scale (`float`, *optional*):
                A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
            clip_skip (`int`, *optional*):
                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                the output of the pre-final layer will be used for computing the prompt embeddings.
        """
        device = device or self._execution_device

        # set lora scale so that monkey patched LoRA
        # function of text encoder can correctly access it
        if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
            self._lora_scale = lora_scale

            # dynamically adjust the LoRA scale
            if self.text_encoder is not None:
                if not USE_PEFT_BACKEND:
                    adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
                else:
                    scale_lora_layers(self.text_encoder, lora_scale)

            if self.text_encoder_2 is not None:
                if not USE_PEFT_BACKEND:
                    adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
                else:
                    scale_lora_layers(self.text_encoder_2, lora_scale)

        prompt = [prompt] if isinstance(prompt, str) else prompt

        if prompt is not None:
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        # Define tokenizers and text encoders. SDXL uses two text encoders; when the first is absent
        # (e.g. the refiner), only the second pair is used.
        tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
        text_encoders = (
            [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
        )

        if prompt_embeds is None:
            prompt_2 = prompt_2 or prompt
            prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2

            # textual inversion: process multi-vector tokens if necessary
            prompt_embeds_list = []
            prompts = [prompt, prompt_2]
            for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
                if isinstance(self, TextualInversionLoaderMixin):
                    prompt = self.maybe_convert_prompt(prompt, tokenizer)

                text_inputs = tokenizer(
                    prompt,
                    padding="max_length",
                    max_length=tokenizer.model_max_length,
                    truncation=True,
                    return_tensors="pt",
                )

                text_input_ids = text_inputs.input_ids
                untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

                # Warn (don't fail) when the prompt exceeds CLIP's max sequence length and gets truncated.
                if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                    text_input_ids, untruncated_ids
                ):
                    removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
                    logger.warning(
                        "The following part of your input was truncated because CLIP can only handle sequences up to"
                        f" {tokenizer.model_max_length} tokens: {removed_text}"
                    )

                prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)

                # We are only ALWAYS interested in the pooled output of the final text encoder
                if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
                    pooled_prompt_embeds = prompt_embeds[0]

                if clip_skip is None:
                    prompt_embeds = prompt_embeds.hidden_states[-2]
                else:
                    # "2" because SDXL always indexes from the penultimate layer.
                    prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]

                prompt_embeds_list.append(prompt_embeds)

            # Concatenate both encoders' hidden states along the feature dimension.
            prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)

        # get unconditional embeddings for classifier free guidance
        zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
        if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
            # SDXL convention: an empty negative prompt may be represented as all-zero embeddings.
            negative_prompt_embeds = torch.zeros_like(prompt_embeds)
            negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
        elif do_classifier_free_guidance and negative_prompt_embeds is None:
            negative_prompt = negative_prompt or ""
            negative_prompt_2 = negative_prompt_2 or negative_prompt

            # normalize str to list
            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
            negative_prompt_2 = (
                batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
            )

            uncond_tokens: List[str]
            if prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = [negative_prompt, negative_prompt_2]

            negative_prompt_embeds_list = []
            for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
                if isinstance(self, TextualInversionLoaderMixin):
                    negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)

                # Pad the negative prompt to the same sequence length as the positive embeddings.
                max_length = prompt_embeds.shape[1]
                uncond_input = tokenizer(
                    negative_prompt,
                    padding="max_length",
                    max_length=max_length,
                    truncation=True,
                    return_tensors="pt",
                )

                negative_prompt_embeds = text_encoder(
                    uncond_input.input_ids.to(device),
                    output_hidden_states=True,
                )

                # We are only ALWAYS interested in the pooled output of the final text encoder
                if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
                    negative_pooled_prompt_embeds = negative_prompt_embeds[0]
                negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]

                negative_prompt_embeds_list.append(negative_prompt_embeds)

            negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)

        if self.text_encoder_2 is not None:
            prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
        else:
            prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)

        bs_embed, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1)

        if do_classifier_free_guidance:
            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = negative_prompt_embeds.shape[1]

            if self.text_encoder_2 is not None:
                negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
            else:
                negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)

            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)

        pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_videos_per_prompt).view(
            bs_embed * num_videos_per_prompt, -1
        )
        if do_classifier_free_guidance:
            negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_videos_per_prompt).view(
                bs_embed * num_videos_per_prompt, -1
            )

        if self.text_encoder is not None:
            if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
                # Retrieve the original scale by scaling back the LoRA layers
                unscale_lora_layers(self.text_encoder, lora_scale)

        if self.text_encoder_2 is not None:
            if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
                # Retrieve the original scale by scaling back the LoRA layers
                unscale_lora_layers(self.text_encoder_2, lora_scale)

        return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
566
+
567
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
    def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
        """Encode an IP-Adapter conditioning image with the CLIP image encoder.

        Returns an `(image_embeds, uncond_image_embeds)` pair, each repeated `num_images_per_prompt` times along
        dim 0. When `output_hidden_states` is truthy, the penultimate hidden state is returned and the
        unconditional branch encodes a zero image; otherwise the pooled `image_embeds` are returned and the
        unconditional branch is a zero tensor.
        """
        dtype = next(self.image_encoder.parameters()).dtype

        # Non-tensor inputs (e.g. PIL images) are preprocessed by the CLIP feature extractor;
        # tensor inputs are assumed to already be preprocessed pixel values.
        if not isinstance(image, torch.Tensor):
            image = self.feature_extractor(image, return_tensors="pt").pixel_values

        image = image.to(device=device, dtype=dtype)
        if output_hidden_states:
            image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
            uncond_image_enc_hidden_states = self.image_encoder(
                torch.zeros_like(image), output_hidden_states=True
            ).hidden_states[-2]
            uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
                num_images_per_prompt, dim=0
            )
            return image_enc_hidden_states, uncond_image_enc_hidden_states
        else:
            image_embeds = self.image_encoder(image).image_embeds
            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
            uncond_image_embeds = torch.zeros_like(image_embeds)

            return image_embeds, uncond_image_embeds
591
+
592
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
    def prepare_ip_adapter_image_embeds(
        self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
    ):
        """Build the per-adapter list of image embeddings consumed by the UNet's IP-Adapter layers.

        Either encodes `ip_adapter_image` (one image per loaded IP Adapter) or reuses precomputed
        `ip_adapter_image_embeds` (each entry expected to hold [negative, positive] halves along dim 0 when
        classifier-free guidance is enabled). Each entry is repeated `num_images_per_prompt` times and, under CFG,
        the negative embeddings are concatenated in front of the positive ones.
        """
        image_embeds = []
        if do_classifier_free_guidance:
            negative_image_embeds = []
        if ip_adapter_image_embeds is None:
            if not isinstance(ip_adapter_image, list):
                ip_adapter_image = [ip_adapter_image]

            if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
                raise ValueError(
                    f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
                )

            for single_ip_adapter_image, image_proj_layer in zip(
                ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
            ):
                # Plain `ImageProjection` layers consume pooled embeds; other projection layers consume hidden states.
                output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
                single_image_embeds, single_negative_image_embeds = self.encode_image(
                    single_ip_adapter_image, device, 1, output_hidden_state
                )

                image_embeds.append(single_image_embeds[None, :])
                if do_classifier_free_guidance:
                    negative_image_embeds.append(single_negative_image_embeds[None, :])
        else:
            for single_image_embeds in ip_adapter_image_embeds:
                if do_classifier_free_guidance:
                    single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
                    negative_image_embeds.append(single_negative_image_embeds)
                image_embeds.append(single_image_embeds)

        ip_adapter_image_embeds = []
        for i, single_image_embeds in enumerate(image_embeds):
            single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
            if do_classifier_free_guidance:
                single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
                single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)

            single_image_embeds = single_image_embeds.to(device=device)
            ip_adapter_image_embeds.append(single_image_embeds)

        return ip_adapter_image_embeds
637
+
638
+ # Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents
639
+ def decode_latents(self, latents):
640
+ latents = 1 / self.vae.config.scaling_factor * latents
641
+
642
+ batch_size, channels, num_frames, height, width = latents.shape
643
+ latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)
644
+
645
+ image = self.vae.decode(latents).sample
646
+ video = image[None, :].reshape((batch_size, num_frames, -1) + image.shape[2:]).permute(0, 2, 1, 3, 4)
647
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
648
+ video = video.float()
649
+ return video
650
+
651
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
652
+ def prepare_extra_step_kwargs(self, generator, eta):
653
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
654
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
655
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
656
+ # and should be between [0, 1]
657
+
658
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
659
+ extra_step_kwargs = {}
660
+ if accepts_eta:
661
+ extra_step_kwargs["eta"] = eta
662
+
663
+ # check if the scheduler accepts generator
664
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
665
+ if accepts_generator:
666
+ extra_step_kwargs["generator"] = generator
667
+ return extra_step_kwargs
668
+
669
    def check_inputs(
        self,
        prompt,
        prompt_2,
        height,
        width,
        negative_prompt=None,
        negative_prompt_2=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
        pooled_prompt_embeds=None,
        negative_pooled_prompt_embeds=None,
        callback_on_step_end_tensor_inputs=None,
    ):
        """Validate `__call__` arguments before any expensive work starts.

        Raises `ValueError` for: spatial dims not divisible by 8, unknown
        callback tensor names, passing both a prompt and precomputed embeddings
        (or neither), wrong prompt types, mismatched positive/negative
        embedding shapes, or embeddings supplied without their pooled
        counterparts.
        """
        # Spatial dims must be divisible by 8 (latent-space downsampling).
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        # Every requested callback tensor must be one the pipeline can surface.
        if callback_on_step_end_tensor_inputs is not None and not all(
            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
        ):
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
            )

        # Prompts and precomputed prompt embeddings are mutually exclusive,
        # but exactly one of the two must be given.
        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt_2 is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
        elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
            raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")

        # Same exclusivity rules for the negative prompts.
        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )
        elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        # Positive/negative embeddings are concatenated for CFG later, so their
        # shapes must agree.
        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if prompt_embeds.shape != negative_prompt_embeds.shape:
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}."
                )

        # SDXL conditioning needs the pooled embeddings alongside the sequence ones.
        if prompt_embeds is not None and pooled_prompt_embeds is None:
            raise ValueError(
                "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
            )

        if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
            raise ValueError(
                "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
            )
740
+
741
    # Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents
    def prepare_latents(
        self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
    ):
        """Create (or reuse) the initial latent video tensor.

        Samples Gaussian noise of shape
        (batch, channels, frames, height // vae_scale_factor, width // vae_scale_factor)
        unless `latents` is supplied, then scales it by the scheduler's initial
        noise sigma.
        """
        shape = (
            batch_size,
            num_channels_latents,
            num_frames,
            height // self.vae_scale_factor,
            width // self.vae_scale_factor,
        )
        # A list of generators must map one-to-one onto the batch entries.
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            # Caller-provided latents: only move to the target device.
            latents = latents.to(device)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma
        return latents
766
+
767
+ def _get_add_time_ids(
768
+ self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None
769
+ ):
770
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
771
+
772
+ passed_add_embed_dim = (
773
+ self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
774
+ )
775
+ expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
776
+
777
+ if expected_add_embed_dim != passed_add_embed_dim:
778
+ raise ValueError(
779
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
780
+ )
781
+
782
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
783
+ return add_time_ids
784
+
785
+ def upcast_vae(self):
786
+ dtype = self.vae.dtype
787
+ self.vae.to(dtype=torch.float32)
788
+ use_torch_2_0_or_xformers = isinstance(
789
+ self.vae.decoder.mid_block.attentions[0].processor,
790
+ (
791
+ AttnProcessor2_0,
792
+ XFormersAttnProcessor,
793
+ FusedAttnProcessor2_0,
794
+ ),
795
+ )
796
+ # if xformers or torch_2_0 is used attention block does not need
797
+ # to be in float32 which can save lots of memory
798
+ if use_torch_2_0_or_xformers:
799
+ self.vae.post_quant_conv.to(dtype)
800
+ self.vae.decoder.conv_in.to(dtype)
801
+ self.vae.decoder.mid_block.to(dtype)
802
+
803
+ # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
804
+ def get_guidance_scale_embedding(
805
+ self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32
806
+ ) -> torch.Tensor:
807
+ """
808
+ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
809
+
810
+ Args:
811
+ w (`torch.Tensor`):
812
+ Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings.
813
+ embedding_dim (`int`, *optional*, defaults to 512):
814
+ Dimension of the embeddings to generate.
815
+ dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
816
+ Data type of the generated embeddings.
817
+
818
+ Returns:
819
+ `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`.
820
+ """
821
+ assert len(w.shape) == 1
822
+ w = w * 1000.0
823
+
824
+ half_dim = embedding_dim // 2
825
+ emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
826
+ emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
827
+ emb = w.to(dtype)[:, None] * emb[None, :]
828
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
829
+ if embedding_dim % 2 == 1: # zero pad
830
+ emb = torch.nn.functional.pad(emb, (0, 1))
831
+ assert emb.shape == (w.shape[0], embedding_dim)
832
+ return emb
833
+
834
    @property
    def guidance_scale(self):
        # Classifier-free guidance scale captured from the latest `__call__`.
        return self._guidance_scale
837
+
838
    @property
    def guidance_rescale(self):
        # CFG rescale factor (phi) captured from the latest `__call__`.
        return self._guidance_rescale
841
+
842
    @property
    def clip_skip(self):
        # Number of final CLIP layers to skip, captured from the latest `__call__`.
        return self._clip_skip
845
+
846
    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    @property
    def do_classifier_free_guidance(self):
        # CFG is also disabled for guidance-distilled UNets (time_cond_proj_dim
        # set): the scale is then injected via `timestep_cond` instead of a
        # second unconditional forward pass.
        return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
852
+
853
    @property
    def cross_attention_kwargs(self):
        # Extra kwargs forwarded to the attention processors, from the latest `__call__`.
        return self._cross_attention_kwargs
856
+
857
    @property
    def denoising_end(self):
        # Fraction (0, 1) at which denoising is cut short, from the latest `__call__`.
        return self._denoising_end
860
+
861
    @property
    def num_timesteps(self):
        # Number of timesteps in the current denoising loop (set inside `__call__`).
        return self._num_timesteps
864
+
865
    @property
    def interrupt(self):
        # Flag checked each denoising step; callbacks may set it to stop early.
        return self._interrupt
868
+
869
+ @torch.no_grad()
870
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
871
+ def __call__(
872
+ self,
873
+ prompt: Union[str, List[str]] = None,
874
+ prompt_2: Optional[Union[str, List[str]]] = None,
875
+ num_frames: int = 16,
876
+ height: Optional[int] = None,
877
+ width: Optional[int] = None,
878
+ num_inference_steps: int = 50,
879
+ timesteps: List[int] = None,
880
+ sigmas: List[float] = None,
881
+ denoising_end: Optional[float] = None,
882
+ guidance_scale: float = 5.0,
883
+ negative_prompt: Optional[Union[str, List[str]]] = None,
884
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
885
+ num_videos_per_prompt: Optional[int] = 1,
886
+ eta: float = 0.0,
887
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
888
+ latents: Optional[torch.Tensor] = None,
889
+ prompt_embeds: Optional[torch.Tensor] = None,
890
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
891
+ pooled_prompt_embeds: Optional[torch.Tensor] = None,
892
+ negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
893
+ ip_adapter_image: Optional[PipelineImageInput] = None,
894
+ ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
895
+ output_type: Optional[str] = "pil",
896
+ return_dict: bool = True,
897
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
898
+ guidance_rescale: float = 0.0,
899
+ original_size: Optional[Tuple[int, int]] = None,
900
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
901
+ target_size: Optional[Tuple[int, int]] = None,
902
+ negative_original_size: Optional[Tuple[int, int]] = None,
903
+ negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
904
+ negative_target_size: Optional[Tuple[int, int]] = None,
905
+ clip_skip: Optional[int] = None,
906
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
907
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
908
+ ):
909
+ r"""
910
+ Function invoked when calling the pipeline for generation.
911
+
912
+ Args:
913
+ prompt (`str` or `List[str]`, *optional*):
914
+ The prompt or prompts to guide the video generation. If not defined, one has to pass `prompt_embeds`.
915
+ instead.
916
+ prompt_2 (`str` or `List[str]`, *optional*):
917
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
918
+ used in both text-encoders
919
+ num_frames:
920
+ The number of video frames that are generated. Defaults to 16 frames which at 8 frames per seconds
921
+ amounts to 2 seconds of video.
922
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
923
+ The height in pixels of the generated video. This is set to 1024 by default for the best results.
924
+ Anything below 512 pixels won't work well for
925
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
926
+ and checkpoints that are not specifically fine-tuned on low resolutions.
927
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
928
+ The width in pixels of the generated video. This is set to 1024 by default for the best results.
929
+ Anything below 512 pixels won't work well for
930
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
931
+ and checkpoints that are not specifically fine-tuned on low resolutions.
932
+ num_inference_steps (`int`, *optional*, defaults to 50):
933
+ The number of denoising steps. More denoising steps usually lead to a higher quality video at the
934
+ expense of slower inference.
935
+ timesteps (`List[int]`, *optional*):
936
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
937
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
938
+ passed will be used. Must be in descending order.
939
+ sigmas (`List[float]`, *optional*):
940
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
941
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
942
+ will be used.
943
+ denoising_end (`float`, *optional*):
944
+ When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
945
+ completed before it is intentionally prematurely terminated. As a result, the returned sample will
946
+ still retain a substantial amount of noise as determined by the discrete timesteps selected by the
947
+ scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
948
+ "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
949
+ Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
950
+ guidance_scale (`float`, *optional*, defaults to 5.0):
951
+ Guidance scale as defined in [Classifier-Free Diffusion
952
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
953
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
954
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
955
+ the text `prompt`, usually at the expense of lower video quality.
956
+ negative_prompt (`str` or `List[str]`, *optional*):
957
+ The prompt or prompts not to guide the video generation. If not defined, one has to pass
958
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
959
+ less than `1`).
960
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
961
+ The prompt or prompts not to guide the video generation to be sent to `tokenizer_2` and
962
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
963
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
964
+ The number of videos to generate per prompt.
965
+ eta (`float`, *optional*, defaults to 0.0):
966
+ Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only
967
+ applies to [`schedulers.DDIMScheduler`], will be ignored for others.
968
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
969
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
970
+ to make generation deterministic.
971
+ latents (`torch.Tensor`, *optional*):
972
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video
973
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
974
+ tensor will be generated by sampling using the supplied random `generator`.
975
+ prompt_embeds (`torch.Tensor`, *optional*):
976
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
977
+ provided, text embeddings will be generated from `prompt` input argument.
978
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
979
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
980
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
981
+ argument.
982
+ pooled_prompt_embeds (`torch.Tensor`, *optional*):
983
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
984
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
985
+ negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
986
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
987
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
988
+ input argument.
989
+ ip_adapter_image: (`PipelineImageInput`, *optional*):
990
+ Optional image input to work with IP Adapters.
991
+ ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
992
+ Pre-generated image embeddings for IP-Adapter. If not provided, embeddings are computed from the
993
+ `ip_adapter_image` input argument.
994
+ output_type (`str`, *optional*, defaults to `"pil"`):
995
+ The output format of the generated video. Choose between
996
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
997
+ return_dict (`bool`, *optional*, defaults to `True`):
998
+ Whether or not to return a [`~pipelines.stable_diffusion_xl.AnimateDiffPipelineOutput`] instead of a
999
+ plain tuple.
1000
+ cross_attention_kwargs (`dict`, *optional*):
1001
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
1002
+ `self.processor` in
1003
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
1004
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
1005
+ Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
1006
+ Flawed](https://huggingface.co/papers/2305.08891) `guidance_scale` is defined as `φ` in equation 16. of
1007
+ [Common Diffusion Noise Schedules and Sample Steps are
1008
+ Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when
1009
+ using zero terminal SNR.
1010
+ original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
1011
+ If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
1012
+ `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
1013
+ explained in section 2.2 of
1014
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
1015
+ crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
1016
+ `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
1017
+ `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
1018
+ `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
1019
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
1020
+ target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
1021
+ For most cases, `target_size` should be set to the desired height and width of the generated image. If
1022
+ not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
1023
+ section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
1024
+ negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
1025
+ To negatively condition the generation process based on a specific image resolution. Part of SDXL's
1026
+ micro-conditioning as explained in section 2.2 of
1027
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
1028
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
1029
+ negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
1030
+ To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's
1031
+ micro-conditioning as explained in section 2.2 of
1032
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
1033
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
1034
+ negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
1035
+ To negatively condition the generation process based on a target image resolution. It should be as same
1036
+ as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
1037
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
1038
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
1039
+ callback_on_step_end (`Callable`, *optional*):
1040
+ A function that calls at the end of each denoising steps during the inference. The function is called
1041
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
1042
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
1043
+ `callback_on_step_end_tensor_inputs`.
1044
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
1045
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
1046
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
1047
+ `._callback_tensor_inputs` attribute of your pipeline class.
1048
+
1049
+ Examples:
1050
+
1051
+ Returns:
1052
+ [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] or `tuple`:
1053
+ If `return_dict` is `True`, [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] is
1054
+ returned, otherwise a `tuple` is returned where the first element is a list with the generated frames.
1055
+ """
1056
+
1057
+ # 0. Default height and width to unet
1058
+ height = height or self.default_sample_size * self.vae_scale_factor
1059
+ width = width or self.default_sample_size * self.vae_scale_factor
1060
+
1061
+ num_videos_per_prompt = 1
1062
+
1063
+ original_size = original_size or (height, width)
1064
+ target_size = target_size or (height, width)
1065
+
1066
+ # 1. Check inputs. Raise error if not correct
1067
+ self.check_inputs(
1068
+ prompt,
1069
+ prompt_2,
1070
+ height,
1071
+ width,
1072
+ negative_prompt,
1073
+ negative_prompt_2,
1074
+ prompt_embeds,
1075
+ negative_prompt_embeds,
1076
+ pooled_prompt_embeds,
1077
+ negative_pooled_prompt_embeds,
1078
+ callback_on_step_end_tensor_inputs,
1079
+ )
1080
+
1081
+ self._guidance_scale = guidance_scale
1082
+ self._guidance_rescale = guidance_rescale
1083
+ self._clip_skip = clip_skip
1084
+ self._cross_attention_kwargs = cross_attention_kwargs
1085
+ self._denoising_end = denoising_end
1086
+ self._interrupt = False
1087
+
1088
+ # 2. Define call parameters
1089
+ if prompt is not None and isinstance(prompt, str):
1090
+ batch_size = 1
1091
+ elif prompt is not None and isinstance(prompt, list):
1092
+ batch_size = len(prompt)
1093
+ else:
1094
+ batch_size = prompt_embeds.shape[0]
1095
+
1096
+ device = self._execution_device
1097
+
1098
+ # 3. Encode input prompt
1099
+ lora_scale = (
1100
+ self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
1101
+ )
1102
+
1103
+ (
1104
+ prompt_embeds,
1105
+ negative_prompt_embeds,
1106
+ pooled_prompt_embeds,
1107
+ negative_pooled_prompt_embeds,
1108
+ ) = self.encode_prompt(
1109
+ prompt=prompt,
1110
+ prompt_2=prompt_2,
1111
+ device=device,
1112
+ num_videos_per_prompt=num_videos_per_prompt,
1113
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
1114
+ negative_prompt=negative_prompt,
1115
+ negative_prompt_2=negative_prompt_2,
1116
+ prompt_embeds=prompt_embeds,
1117
+ negative_prompt_embeds=negative_prompt_embeds,
1118
+ pooled_prompt_embeds=pooled_prompt_embeds,
1119
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
1120
+ lora_scale=lora_scale,
1121
+ clip_skip=self.clip_skip,
1122
+ )
1123
+
1124
+ # 4. Prepare timesteps
1125
+ timesteps, num_inference_steps = retrieve_timesteps(
1126
+ self.scheduler, num_inference_steps, device, timesteps, sigmas
1127
+ )
1128
+
1129
+ # 5. Prepare latent variables
1130
+ num_channels_latents = self.unet.config.in_channels
1131
+ latents = self.prepare_latents(
1132
+ batch_size * num_videos_per_prompt,
1133
+ num_channels_latents,
1134
+ num_frames,
1135
+ height,
1136
+ width,
1137
+ prompt_embeds.dtype,
1138
+ device,
1139
+ generator,
1140
+ latents,
1141
+ )
1142
+
1143
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
1144
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
1145
+
1146
+ # 7. Prepare added time ids & embeddings
1147
+ add_text_embeds = pooled_prompt_embeds
1148
+ if self.text_encoder_2 is None:
1149
+ text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
1150
+ else:
1151
+ text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
1152
+
1153
+ add_time_ids = self._get_add_time_ids(
1154
+ original_size,
1155
+ crops_coords_top_left,
1156
+ target_size,
1157
+ dtype=prompt_embeds.dtype,
1158
+ text_encoder_projection_dim=text_encoder_projection_dim,
1159
+ )
1160
+ if negative_original_size is not None and negative_target_size is not None:
1161
+ negative_add_time_ids = self._get_add_time_ids(
1162
+ negative_original_size,
1163
+ negative_crops_coords_top_left,
1164
+ negative_target_size,
1165
+ dtype=prompt_embeds.dtype,
1166
+ text_encoder_projection_dim=text_encoder_projection_dim,
1167
+ )
1168
+ else:
1169
+ negative_add_time_ids = add_time_ids
1170
+
1171
+ if self.do_classifier_free_guidance:
1172
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
1173
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
1174
+ add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
1175
+
1176
+ prompt_embeds = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0)
1177
+
1178
+ prompt_embeds = prompt_embeds.to(device)
1179
+ add_text_embeds = add_text_embeds.to(device)
1180
+ add_time_ids = add_time_ids.to(device).repeat(batch_size * num_videos_per_prompt, 1)
1181
+
1182
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
1183
+ image_embeds = self.prepare_ip_adapter_image_embeds(
1184
+ ip_adapter_image,
1185
+ ip_adapter_image_embeds,
1186
+ device,
1187
+ batch_size * num_videos_per_prompt,
1188
+ self.do_classifier_free_guidance,
1189
+ )
1190
+
1191
+ # 7.1 Apply denoising_end
1192
+ if (
1193
+ self.denoising_end is not None
1194
+ and isinstance(self.denoising_end, float)
1195
+ and self.denoising_end > 0
1196
+ and self.denoising_end < 1
1197
+ ):
1198
+ discrete_timestep_cutoff = int(
1199
+ round(
1200
+ self.scheduler.config.num_train_timesteps
1201
+ - (self.denoising_end * self.scheduler.config.num_train_timesteps)
1202
+ )
1203
+ )
1204
+ num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
1205
+ timesteps = timesteps[:num_inference_steps]
1206
+
1207
+ # 8. Optionally get Guidance Scale Embedding
1208
+ timestep_cond = None
1209
+ if self.unet.config.time_cond_proj_dim is not None:
1210
+ guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_videos_per_prompt)
1211
+ timestep_cond = self.get_guidance_scale_embedding(
1212
+ guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
1213
+ ).to(device=device, dtype=latents.dtype)
1214
+
1215
+ num_free_init_iters = self._free_init_num_iters if self.free_init_enabled else 1
1216
+ for free_init_iter in range(num_free_init_iters):
1217
+ if self.free_init_enabled:
1218
+ latents, timesteps = self._apply_free_init(
1219
+ latents, free_init_iter, num_inference_steps, device, latents.dtype, generator
1220
+ )
1221
+
1222
+ self._num_timesteps = len(timesteps)
1223
+
1224
+ # 9. Denoising loop
1225
+ with self.progress_bar(total=self._num_timesteps) as progress_bar:
1226
+ for i, t in enumerate(timesteps):
1227
+ if self.interrupt:
1228
+ continue
1229
+
1230
+ # expand the latents if we are doing classifier free guidance
1231
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
1232
+
1233
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
1234
+
1235
+ # predict the noise residual
1236
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
1237
+ if ip_adapter_image is not None or ip_adapter_image_embeds:
1238
+ added_cond_kwargs["image_embeds"] = image_embeds
1239
+
1240
+ noise_pred = self.unet(
1241
+ latent_model_input,
1242
+ t,
1243
+ encoder_hidden_states=prompt_embeds,
1244
+ timestep_cond=timestep_cond,
1245
+ cross_attention_kwargs=self.cross_attention_kwargs,
1246
+ added_cond_kwargs=added_cond_kwargs,
1247
+ return_dict=False,
1248
+ )[0]
1249
+
1250
+ # perform guidance
1251
+ if self.do_classifier_free_guidance:
1252
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1253
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
1254
+
1255
+ if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
1256
+ # Based on 3.4. in https://huggingface.co/papers/2305.08891
1257
+ noise_pred = rescale_noise_cfg(
1258
+ noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale
1259
+ )
1260
+
1261
+ # compute the previous noisy sample x_t -> x_t-1
1262
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
1263
+
1264
+ if callback_on_step_end is not None:
1265
+ callback_kwargs = {}
1266
+ for k in callback_on_step_end_tensor_inputs:
1267
+ callback_kwargs[k] = locals()[k]
1268
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
1269
+
1270
+ latents = callback_outputs.pop("latents", latents)
1271
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1272
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
1273
+ add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
1274
+ negative_pooled_prompt_embeds = callback_outputs.pop(
1275
+ "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
1276
+ )
1277
+ add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
1278
+ negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids)
1279
+
1280
+ progress_bar.update()
1281
+
1282
+ if XLA_AVAILABLE:
1283
+ xm.mark_step()
1284
+
1285
+ # make sure the VAE is in float32 mode, as it overflows in float16
1286
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
1287
+
1288
+ if needs_upcasting:
1289
+ self.upcast_vae()
1290
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
1291
+
1292
+ # 10. Post processing
1293
+ if output_type == "latent":
1294
+ video = latents
1295
+ else:
1296
+ video_tensor = self.decode_latents(latents)
1297
+ video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type)
1298
+
1299
+ # cast back to fp16 if needed
1300
+ if needs_upcasting:
1301
+ self.vae.to(dtype=torch.float16)
1302
+
1303
+ # 11. Offload all models
1304
+ self.maybe_free_model_hooks()
1305
+
1306
+ if not return_dict:
1307
+ return (video,)
1308
+
1309
+ return AnimateDiffPipelineOutput(frames=video)
pythonProject/diffusers-main/build/lib/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py ADDED
@@ -0,0 +1,1023 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
17
+
18
+ import numpy as np
19
+ import PIL
20
+ import torch
21
+ import torch.nn.functional as F
22
+ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
23
+
24
+ from ...image_processor import PipelineImageInput, VaeImageProcessor
25
+ from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
26
+ from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel
27
+ from ...models.controlnets.controlnet_sparsectrl import SparseControlNetModel
28
+ from ...models.lora import adjust_lora_scale_text_encoder
29
+ from ...models.unets.unet_motion_model import MotionAdapter
30
+ from ...schedulers import KarrasDiffusionSchedulers
31
+ from ...utils import (
32
+ USE_PEFT_BACKEND,
33
+ is_torch_xla_available,
34
+ logging,
35
+ replace_example_docstring,
36
+ scale_lora_layers,
37
+ unscale_lora_layers,
38
+ )
39
+ from ...utils.torch_utils import is_compiled_module, randn_tensor
40
+ from ...video_processor import VideoProcessor
41
+ from ..free_init_utils import FreeInitMixin
42
+ from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
43
+ from .pipeline_output import AnimateDiffPipelineOutput
44
+
45
+
46
+ if is_torch_xla_available():
47
+ import torch_xla.core.xla_model as xm
48
+
49
+ XLA_AVAILABLE = True
50
+ else:
51
+ XLA_AVAILABLE = False
52
+
53
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
54
+
55
+
56
+ EXAMPLE_DOC_STRING = """
57
+ Examples:
58
+ ```python
59
+ >>> import torch
60
+ >>> from diffusers import AnimateDiffSparseControlNetPipeline
61
+ >>> from diffusers.models import AutoencoderKL, MotionAdapter, SparseControlNetModel
62
+ >>> from diffusers.schedulers import DPMSolverMultistepScheduler
63
+ >>> from diffusers.utils import export_to_gif, load_image
64
+
65
+ >>> model_id = "SG161222/Realistic_Vision_V5.1_noVAE"
66
+ >>> motion_adapter_id = "guoyww/animatediff-motion-adapter-v1-5-3"
67
+ >>> controlnet_id = "guoyww/animatediff-sparsectrl-scribble"
68
+ >>> lora_adapter_id = "guoyww/animatediff-motion-lora-v1-5-3"
69
+ >>> vae_id = "stabilityai/sd-vae-ft-mse"
70
+ >>> device = "cuda"
71
+
72
+ >>> motion_adapter = MotionAdapter.from_pretrained(motion_adapter_id, torch_dtype=torch.float16).to(device)
73
+ >>> controlnet = SparseControlNetModel.from_pretrained(controlnet_id, torch_dtype=torch.float16).to(device)
74
+ >>> vae = AutoencoderKL.from_pretrained(vae_id, torch_dtype=torch.float16).to(device)
75
+ >>> scheduler = DPMSolverMultistepScheduler.from_pretrained(
76
+ ... model_id,
77
+ ... subfolder="scheduler",
78
+ ... beta_schedule="linear",
79
+ ... algorithm_type="dpmsolver++",
80
+ ... use_karras_sigmas=True,
81
+ ... )
82
+ >>> pipe = AnimateDiffSparseControlNetPipeline.from_pretrained(
83
+ ... model_id,
84
+ ... motion_adapter=motion_adapter,
85
+ ... controlnet=controlnet,
86
+ ... vae=vae,
87
+ ... scheduler=scheduler,
88
+ ... torch_dtype=torch.float16,
89
+ ... ).to(device)
90
+ >>> pipe.load_lora_weights(lora_adapter_id, adapter_name="motion_lora")
91
+ >>> pipe.fuse_lora(lora_scale=1.0)
92
+
93
+ >>> prompt = "an aerial view of a cyberpunk city, night time, neon lights, masterpiece, high quality"
94
+ >>> negative_prompt = "low quality, worst quality, letterboxed"
95
+
96
+ >>> image_files = [
97
+ ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-scribble-1.png",
98
+ ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-scribble-2.png",
99
+ ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-scribble-3.png",
100
+ ... ]
101
+ >>> condition_frame_indices = [0, 8, 15]
102
+ >>> conditioning_frames = [load_image(img_file) for img_file in image_files]
103
+
104
+ >>> video = pipe(
105
+ ... prompt=prompt,
106
+ ... negative_prompt=negative_prompt,
107
+ ... num_inference_steps=25,
108
+ ... conditioning_frames=conditioning_frames,
109
+ ... controlnet_conditioning_scale=1.0,
110
+ ... controlnet_frame_indices=condition_frame_indices,
111
+ ... generator=torch.Generator().manual_seed(1337),
112
+ ... ).frames[0]
113
+ >>> export_to_gif(video, "output.gif")
114
+ ```
115
+ """
116
+
117
+
118
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
    """Pull a latent tensor out of a VAE encoder output.

    Supports outputs exposing a `latent_dist` (sampled or argmax/mode, depending on
    `sample_mode`) as well as outputs exposing a plain `latents` attribute.
    Raises `AttributeError` when neither is available.
    """
    if hasattr(encoder_output, "latent_dist"):
        if sample_mode == "sample":
            return encoder_output.latent_dist.sample(generator)
        if sample_mode == "argmax":
            return encoder_output.latent_dist.mode()
    if hasattr(encoder_output, "latents"):
        return encoder_output.latents
    raise AttributeError("Could not access latents of provided encoder_output")
130
+
131
+
132
+ class AnimateDiffSparseControlNetPipeline(
133
+ DiffusionPipeline,
134
+ StableDiffusionMixin,
135
+ TextualInversionLoaderMixin,
136
+ IPAdapterMixin,
137
+ StableDiffusionLoraLoaderMixin,
138
+ FreeInitMixin,
139
+ FromSingleFileMixin,
140
+ ):
141
+ r"""
142
+ Pipeline for controlled text-to-video generation using the method described in [SparseCtrl: Adding Sparse Controls
143
+ to Text-to-Video Diffusion Models](https://huggingface.co/papers/2311.16933).
144
+
145
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
146
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
147
+
148
+ The pipeline also inherits the following loading methods:
149
+ - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
150
+ - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
151
+ - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
152
+ - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
153
+
154
+ Args:
155
+ vae ([`AutoencoderKL`]):
156
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
157
+ text_encoder ([`CLIPTextModel`]):
158
+ Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
159
+ tokenizer (`CLIPTokenizer`):
160
+ A [`~transformers.CLIPTokenizer`] to tokenize text.
161
+ unet ([`UNet2DConditionModel`]):
162
+ A [`UNet2DConditionModel`] used to create a UNetMotionModel to denoise the encoded video latents.
163
+ motion_adapter ([`MotionAdapter`]):
164
+ A [`MotionAdapter`] to be used in combination with `unet` to denoise the encoded video latents.
165
+ scheduler ([`SchedulerMixin`]):
166
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
167
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
168
+ """
169
+
170
+ model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
171
+ _optional_components = ["feature_extractor", "image_encoder", "motion_adapter"]
172
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
173
+
174
    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: Union[UNet2DConditionModel, UNetMotionModel],
        motion_adapter: MotionAdapter,
        controlnet: SparseControlNetModel,
        scheduler: KarrasDiffusionSchedulers,
        feature_extractor: CLIPImageProcessor = None,
        image_encoder: CLIPVisionModelWithProjection = None,
    ):
        super().__init__()
        # A plain 2D UNet is converted into a motion-aware UNet by merging in the motion adapter.
        if isinstance(unet, UNet2DConditionModel):
            unet = UNetMotionModel.from_unet2d(unet, motion_adapter)

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            motion_adapter=motion_adapter,
            controlnet=controlnet,
            scheduler=scheduler,
            feature_extractor=feature_extractor,
            image_encoder=image_encoder,
        )
        # Each VAE down block halves spatial resolution; fall back to 8 when no VAE is registered.
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
        self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor)
        # Conditioning images stay in [0, 1] (do_normalize=False) — `prepare_image` relies on this range.
        self.control_image_processor = VaeImageProcessor(
            vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
        )
206
+
207
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt
    def encode_prompt(
        self,
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt=None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        lora_scale: Optional[float] = None,
        clip_skip: Optional[int] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            device: (`torch.device`):
                torch device
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            lora_scale (`float`, *optional*):
                A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
            clip_skip (`int`, *optional*):
                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                the output of the pre-final layer will be used for computing the prompt embeddings.
        """
        # set lora scale so that monkey patched LoRA
        # function of text encoder can correctly access it
        if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):
            self._lora_scale = lora_scale

            # dynamically adjust the LoRA scale
            if not USE_PEFT_BACKEND:
                adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
            else:
                scale_lora_layers(self.text_encoder, lora_scale)

        # Derive the batch size from whichever prompt representation was supplied.
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

            # Warn (rather than fail) when the prompt exceeds the tokenizer's max length and was truncated.
            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
                removed_text = self.tokenizer.batch_decode(
                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
                )
                logger.warning(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
                )

            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = text_inputs.attention_mask.to(device)
            else:
                attention_mask = None

            if clip_skip is None:
                prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
                prompt_embeds = prompt_embeds[0]
            else:
                prompt_embeds = self.text_encoder(
                    text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
                )
                # Access the `hidden_states` first, that contains a tuple of
                # all the hidden states from the encoder layers. Then index into
                # the tuple to access the hidden states from the desired layer.
                prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
                # We also need to apply the final LayerNorm here to not mess with the
                # representations. The `last_hidden_states` that we typically use for
                # obtaining the final prompt representations passes through the LayerNorm
                # layer.
                prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)

        if self.text_encoder is not None:
            prompt_embeds_dtype = self.text_encoder.dtype
        elif self.unet is not None:
            prompt_embeds_dtype = self.unet.dtype
        else:
            prompt_embeds_dtype = prompt_embeds.dtype

        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

        bs_embed, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

            max_length = prompt_embeds.shape[1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )

            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = uncond_input.attention_mask.to(device)
            else:
                attention_mask = None

            negative_prompt_embeds = self.text_encoder(
                uncond_input.input_ids.to(device),
                attention_mask=attention_mask,
            )
            negative_prompt_embeds = negative_prompt_embeds[0]

        if do_classifier_free_guidance:
            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = negative_prompt_embeds.shape[1]

            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

        if self.text_encoder is not None:
            if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
                # Retrieve the original scale by scaling back the LoRA layers
                unscale_lora_layers(self.text_encoder, lora_scale)

        return prompt_embeds, negative_prompt_embeds
389
+
390
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
    def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
        """Encode an IP-Adapter input image into (conditional, unconditional) embeddings.

        When `output_hidden_states` is truthy, returns the penultimate hidden states of the image
        encoder (for projection layers that consume hidden states); otherwise returns pooled
        `image_embeds`. Each result is repeated `num_images_per_prompt` times along dim 0.
        """
        dtype = next(self.image_encoder.parameters()).dtype

        # Raw PIL/numpy input is converted to pixel values by the feature extractor first.
        if not isinstance(image, torch.Tensor):
            image = self.feature_extractor(image, return_tensors="pt").pixel_values

        image = image.to(device=device, dtype=dtype)
        if output_hidden_states:
            image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
            # Unconditional branch encodes an all-zero image (not zeroed embeddings) in this mode.
            uncond_image_enc_hidden_states = self.image_encoder(
                torch.zeros_like(image), output_hidden_states=True
            ).hidden_states[-2]
            uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
                num_images_per_prompt, dim=0
            )
            return image_enc_hidden_states, uncond_image_enc_hidden_states
        else:
            image_embeds = self.image_encoder(image).image_embeds
            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
            # Here the unconditional embeddings are simply zeros of the same shape.
            uncond_image_embeds = torch.zeros_like(image_embeds)

            return image_embeds, uncond_image_embeds
414
+
415
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
    def prepare_ip_adapter_image_embeds(
        self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
    ):
        """Build the per-adapter image embedding list consumed by the UNet's IP-Adapter layers.

        Either encodes raw `ip_adapter_image` inputs (one per loaded IP-Adapter) or reuses
        precomputed `ip_adapter_image_embeds`. Embeddings are duplicated for
        `num_images_per_prompt`, and under classifier-free guidance the negative half is
        concatenated in front of the positive half.
        """
        image_embeds = []
        if do_classifier_free_guidance:
            negative_image_embeds = []
        if ip_adapter_image_embeds is None:
            if not isinstance(ip_adapter_image, list):
                ip_adapter_image = [ip_adapter_image]

            # One image is expected per loaded IP-Adapter projection layer.
            if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
                raise ValueError(
                    f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
                )

            for single_ip_adapter_image, image_proj_layer in zip(
                ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
            ):
                # `ImageProjection` layers consume pooled embeds; other projection layers use hidden states.
                output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
                single_image_embeds, single_negative_image_embeds = self.encode_image(
                    single_ip_adapter_image, device, 1, output_hidden_state
                )

                image_embeds.append(single_image_embeds[None, :])
                if do_classifier_free_guidance:
                    negative_image_embeds.append(single_negative_image_embeds[None, :])
        else:
            for single_image_embeds in ip_adapter_image_embeds:
                if do_classifier_free_guidance:
                    # Precomputed embeds store the negative half first; split them back apart.
                    single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
                    negative_image_embeds.append(single_negative_image_embeds)
                image_embeds.append(single_image_embeds)

        ip_adapter_image_embeds = []
        for i, single_image_embeds in enumerate(image_embeds):
            single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
            if do_classifier_free_guidance:
                single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
                # CFG layout: [negative, positive] along the batch dimension.
                single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)

            single_image_embeds = single_image_embeds.to(device=device)
            ip_adapter_image_embeds.append(single_image_embeds)

        return ip_adapter_image_embeds
460
+
461
+ # Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents
462
+ def decode_latents(self, latents):
463
+ latents = 1 / self.vae.config.scaling_factor * latents
464
+
465
+ batch_size, channels, num_frames, height, width = latents.shape
466
+ latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)
467
+
468
+ image = self.vae.decode(latents).sample
469
+ video = image[None, :].reshape((batch_size, num_frames, -1) + image.shape[2:]).permute(0, 2, 1, 3, 4)
470
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
471
+ video = video.float()
472
+ return video
473
+
474
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
475
+ def prepare_extra_step_kwargs(self, generator, eta):
476
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
477
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
478
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
479
+ # and should be between [0, 1]
480
+
481
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
482
+ extra_step_kwargs = {}
483
+ if accepts_eta:
484
+ extra_step_kwargs["eta"] = eta
485
+
486
+ # check if the scheduler accepts generator
487
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
488
+ if accepts_generator:
489
+ extra_step_kwargs["generator"] = generator
490
+ return extra_step_kwargs
491
+
492
    def check_inputs(
        self,
        prompt,
        height,
        width,
        negative_prompt=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
        ip_adapter_image=None,
        ip_adapter_image_embeds=None,
        callback_on_step_end_tensor_inputs=None,
        image=None,
        controlnet_conditioning_scale: float = 1.0,
    ):
        """Validate `__call__` arguments before any expensive work begins.

        Raises `ValueError`/`TypeError` on invalid dimensions, inconsistent prompt/embedding
        combinations, malformed IP-Adapter inputs, bad conditioning images, or a non-float
        `controlnet_conditioning_scale`.
        """
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        if callback_on_step_end_tensor_inputs is not None and not all(
            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
        ):
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
            )

        # Exactly one of `prompt` / `prompt_embeds` must be supplied.
        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if prompt_embeds.shape != negative_prompt_embeds.shape:
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}."
                )

        if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
            raise ValueError(
                "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
            )

        if ip_adapter_image_embeds is not None:
            if not isinstance(ip_adapter_image_embeds, list):
                raise ValueError(
                    f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
                )
            elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
                raise ValueError(
                    f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
                )

        # `torch.compile` wraps the controlnet; type checks must look through `_orig_mod` in that case.
        is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
            self.controlnet, torch._dynamo.eval_frame.OptimizedModule
        )

        # check `image`
        if (
            isinstance(self.controlnet, SparseControlNetModel)
            or is_compiled
            and isinstance(self.controlnet._orig_mod, SparseControlNetModel)
        ):
            if isinstance(image, list):
                for image_ in image:
                    self.check_image(image_, prompt, prompt_embeds)
            else:
                self.check_image(image, prompt, prompt_embeds)
        else:
            # NOTE(review): unreachable when `controlnet` is a SparseControlNetModel as typed in
            # `__init__`; `assert` is stripped under `python -O`, so an explicit raise would be safer.
            assert False

        # Check `controlnet_conditioning_scale`
        if (
            isinstance(self.controlnet, SparseControlNetModel)
            or is_compiled
            and isinstance(self.controlnet._orig_mod, SparseControlNetModel)
        ):
            if not isinstance(controlnet_conditioning_scale, float):
                raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
        else:
            # NOTE(review): same unreachable-guard pattern as above.
            assert False
585
+
586
+ # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image
587
+ def check_image(self, image, prompt, prompt_embeds):
588
+ image_is_pil = isinstance(image, PIL.Image.Image)
589
+ image_is_tensor = isinstance(image, torch.Tensor)
590
+ image_is_np = isinstance(image, np.ndarray)
591
+ image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
592
+ image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)
593
+ image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray)
594
+
595
+ if (
596
+ not image_is_pil
597
+ and not image_is_tensor
598
+ and not image_is_np
599
+ and not image_is_pil_list
600
+ and not image_is_tensor_list
601
+ and not image_is_np_list
602
+ ):
603
+ raise TypeError(
604
+ f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}"
605
+ )
606
+
607
+ if image_is_pil:
608
+ image_batch_size = 1
609
+ else:
610
+ image_batch_size = len(image)
611
+
612
+ if prompt is not None and isinstance(prompt, str):
613
+ prompt_batch_size = 1
614
+ elif prompt is not None and isinstance(prompt, list):
615
+ prompt_batch_size = len(prompt)
616
+ elif prompt_embeds is not None:
617
+ prompt_batch_size = prompt_embeds.shape[0]
618
+
619
+ if image_batch_size != 1 and image_batch_size != prompt_batch_size:
620
+ raise ValueError(
621
+ f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
622
+ )
623
+
624
+ # Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents
625
+ def prepare_latents(
626
+ self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
627
+ ):
628
+ shape = (
629
+ batch_size,
630
+ num_channels_latents,
631
+ num_frames,
632
+ height // self.vae_scale_factor,
633
+ width // self.vae_scale_factor,
634
+ )
635
+ if isinstance(generator, list) and len(generator) != batch_size:
636
+ raise ValueError(
637
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
638
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
639
+ )
640
+
641
+ if latents is None:
642
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
643
+ else:
644
+ latents = latents.to(device)
645
+
646
+ # scale the initial noise by the standard deviation required by the scheduler
647
+ latents = latents * self.scheduler.init_noise_sigma
648
+ return latents
649
+
650
    def prepare_image(self, image, width, height, device, dtype):
        """Preprocess the sparse conditioning frames into the layout the controlnet expects.

        Returns a tensor of shape [batch, channels, frames, height, width]. When the controlnet
        uses a simplified condition embedding, the frames are first VAE-encoded into latent space.
        """
        image = self.control_image_processor.preprocess(image, height=height, width=width)
        # Add a batch dimension: [frames, channels, h, w] -> [1, frames, channels, h, w].
        controlnet_images = image.unsqueeze(0).to(device, dtype)
        batch_size, num_frames, channels, height, width = controlnet_images.shape

        # TODO: remove below line
        # Sanity check: `control_image_processor` was built with do_normalize=False, so inputs are [0, 1].
        assert controlnet_images.min() >= 0 and controlnet_images.max() <= 1

        if self.controlnet.use_simplified_condition_embedding:
            controlnet_images = controlnet_images.reshape(batch_size * num_frames, channels, height, width)
            # Rescale from [0, 1] to [-1, 1] before VAE encoding.
            controlnet_images = 2 * controlnet_images - 1
            conditioning_frames = retrieve_latents(self.vae.encode(controlnet_images)) * self.vae.config.scaling_factor
            conditioning_frames = conditioning_frames.reshape(
                batch_size, num_frames, 4, height // self.vae_scale_factor, width // self.vae_scale_factor
            )
        else:
            conditioning_frames = controlnet_images

        conditioning_frames = conditioning_frames.permute(0, 2, 1, 3, 4)  # [b, c, f, h, w]
        return conditioning_frames
670
+
671
+ def prepare_sparse_control_conditioning(
672
+ self,
673
+ conditioning_frames: torch.Tensor,
674
+ num_frames: int,
675
+ controlnet_frame_indices: int,
676
+ device: torch.device,
677
+ dtype: torch.dtype,
678
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
679
+ assert conditioning_frames.shape[2] >= len(controlnet_frame_indices)
680
+
681
+ batch_size, channels, _, height, width = conditioning_frames.shape
682
+ controlnet_cond = torch.zeros((batch_size, channels, num_frames, height, width), dtype=dtype, device=device)
683
+ controlnet_cond_mask = torch.zeros((batch_size, 1, num_frames, height, width), dtype=dtype, device=device)
684
+ controlnet_cond[:, :, controlnet_frame_indices] = conditioning_frames[:, :, : len(controlnet_frame_indices)]
685
+ controlnet_cond_mask[:, :, controlnet_frame_indices] = 1
686
+
687
+ return controlnet_cond, controlnet_cond_mask
688
+
689
    @property
    def guidance_scale(self):
        # Backing attribute `_guidance_scale` is set elsewhere in the pipeline (outside this view).
        return self._guidance_scale

    @property
    def clip_skip(self):
        # Number of CLIP layers to skip when encoding prompts; see `encode_prompt`.
        return self._clip_skip

    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    @property
    def do_classifier_free_guidance(self):
        return self._guidance_scale > 1

    @property
    def cross_attention_kwargs(self):
        # Optional kwargs forwarded to the attention processors during denoising.
        return self._cross_attention_kwargs

    @property
    def num_timesteps(self):
        # Backing attribute `_num_timesteps` is set elsewhere in the pipeline (outside this view).
        return self._num_timesteps
711
+
712
+ @torch.no_grad()
713
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
714
+ def __call__(
715
+ self,
716
+ prompt: Optional[Union[str, List[str]]] = None,
717
+ height: Optional[int] = None,
718
+ width: Optional[int] = None,
719
+ num_frames: int = 16,
720
+ num_inference_steps: int = 50,
721
+ guidance_scale: float = 7.5,
722
+ negative_prompt: Optional[Union[str, List[str]]] = None,
723
+ num_videos_per_prompt: int = 1,
724
+ eta: float = 0.0,
725
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
726
+ latents: Optional[torch.Tensor] = None,
727
+ prompt_embeds: Optional[torch.Tensor] = None,
728
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
729
+ ip_adapter_image: Optional[PipelineImageInput] = None,
730
+ ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
731
+ conditioning_frames: Optional[List[PipelineImageInput]] = None,
732
+ output_type: str = "pil",
733
+ return_dict: bool = True,
734
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
735
+ controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
736
+ controlnet_frame_indices: List[int] = [0],
737
+ guess_mode: bool = False,
738
+ clip_skip: Optional[int] = None,
739
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
740
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
741
+ ):
742
+ r"""
743
+ The call function to the pipeline for generation.
744
+
745
+ Args:
746
+ prompt (`str` or `List[str]`, *optional*):
747
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
748
+ height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
749
+ The height in pixels of the generated video.
750
+ width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
751
+ The width in pixels of the generated video.
752
+ num_frames (`int`, *optional*, defaults to 16):
753
+ The number of video frames that are generated. Defaults to 16 frames which at 8 frames per seconds
754
+ amounts to 2 seconds of video.
755
+ num_inference_steps (`int`, *optional*, defaults to 50):
756
+ The number of denoising steps. More denoising steps usually lead to a higher quality videos at the
757
+ expense of slower inference.
758
+ guidance_scale (`float`, *optional*, defaults to 7.5):
759
+ A higher guidance scale value encourages the model to generate images closely linked to the text
760
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
761
+ negative_prompt (`str` or `List[str]`, *optional*):
762
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
763
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
764
+ eta (`float`, *optional*, defaults to 0.0):
765
+ Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only
766
+ applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
767
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
768
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
769
+ generation deterministic.
770
+ latents (`torch.Tensor`, *optional*):
771
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video
772
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
773
+ tensor is generated by sampling using the supplied random `generator`. Latents should be of shape
774
+ `(batch_size, num_channel, num_frames, height, width)`.
775
+ prompt_embeds (`torch.Tensor`, *optional*):
776
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
777
+ provided, text embeddings are generated from the `prompt` input argument.
778
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
779
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
780
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
781
+ ip_adapter_image: (`PipelineImageInput`, *optional*):
782
+ Optional image input to work with IP Adapters.
783
+ ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
784
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
785
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
786
+ contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
787
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
788
+ conditioning_frames (`List[PipelineImageInput]`, *optional*):
789
+ The SparseControlNet input to provide guidance to the `unet` for generation.
790
+ output_type (`str`, *optional*, defaults to `"pil"`):
791
+ The output format of the generated video. Choose between `torch.Tensor`, `PIL.Image` or `np.array`.
792
+ return_dict (`bool`, *optional*, defaults to `True`):
793
+ Whether or not to return a [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] instead
794
+ of a plain tuple.
795
+ cross_attention_kwargs (`dict`, *optional*):
796
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
797
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
798
+ controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
799
+ The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
800
+ to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set
801
+ the corresponding scale as a list.
802
+ controlnet_frame_indices (`List[int]`):
803
+ The indices where the conditioning frames must be applied for generation. Multiple frames can be
804
+ provided to guide the model to generate similar structure outputs, where the `unet` can
805
+ "fill-in-the-gaps" for interpolation videos, or a single frame could be provided for general expected
806
+ structure. Must have the same length as `conditioning_frames`.
807
+ clip_skip (`int`, *optional*):
808
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
809
+ the output of the pre-final layer will be used for computing the prompt embeddings.
810
+ callback_on_step_end (`Callable`, *optional*):
811
+ A function that calls at the end of each denoising steps during the inference. The function is called
812
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
813
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
814
+ `callback_on_step_end_tensor_inputs`.
815
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
816
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
817
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
818
+ `._callback_tensor_inputs` attribute of your pipeline class.
819
+
820
+ Examples:
821
+
822
+ Returns:
823
+ [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] or `tuple`:
824
+ If `return_dict` is `True`, [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] is
825
+ returned, otherwise a `tuple` is returned where the first element is a list with the generated frames.
826
+ """
827
+ controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
828
+
829
+ # 0. Default height and width to unet
830
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
831
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
832
+ num_videos_per_prompt = 1
833
+
834
+ # 1. Check inputs. Raise error if not correct
835
+ self.check_inputs(
836
+ prompt=prompt,
837
+ height=height,
838
+ width=width,
839
+ negative_prompt=negative_prompt,
840
+ prompt_embeds=prompt_embeds,
841
+ negative_prompt_embeds=negative_prompt_embeds,
842
+ ip_adapter_image=ip_adapter_image,
843
+ ip_adapter_image_embeds=ip_adapter_image_embeds,
844
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
845
+ image=conditioning_frames,
846
+ controlnet_conditioning_scale=controlnet_conditioning_scale,
847
+ )
848
+
849
+ self._guidance_scale = guidance_scale
850
+ self._clip_skip = clip_skip
851
+ self._cross_attention_kwargs = cross_attention_kwargs
852
+
853
+ # 2. Define call parameters
854
+ if prompt is not None and isinstance(prompt, str):
855
+ batch_size = 1
856
+ elif prompt is not None and isinstance(prompt, list):
857
+ batch_size = len(prompt)
858
+ else:
859
+ batch_size = prompt_embeds.shape[0]
860
+
861
+ device = self._execution_device
862
+
863
+ global_pool_conditions = (
864
+ controlnet.config.global_pool_conditions
865
+ if isinstance(controlnet, SparseControlNetModel)
866
+ else controlnet.nets[0].config.global_pool_conditions
867
+ )
868
+ guess_mode = guess_mode or global_pool_conditions
869
+
870
+ # 3. Encode input prompt
871
+ text_encoder_lora_scale = (
872
+ self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
873
+ )
874
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
875
+ prompt,
876
+ device,
877
+ num_videos_per_prompt,
878
+ self.do_classifier_free_guidance,
879
+ negative_prompt,
880
+ prompt_embeds=prompt_embeds,
881
+ negative_prompt_embeds=negative_prompt_embeds,
882
+ lora_scale=text_encoder_lora_scale,
883
+ clip_skip=self.clip_skip,
884
+ )
885
+ # For classifier free guidance, we need to do two forward passes.
886
+ # Here we concatenate the unconditional and text embeddings into a single batch
887
+ # to avoid doing two forward passes
888
+ if self.do_classifier_free_guidance:
889
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
890
+
891
+ prompt_embeds = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0)
892
+
893
+ # 4. Prepare IP-Adapter embeddings
894
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
895
+ image_embeds = self.prepare_ip_adapter_image_embeds(
896
+ ip_adapter_image,
897
+ ip_adapter_image_embeds,
898
+ device,
899
+ batch_size * num_videos_per_prompt,
900
+ self.do_classifier_free_guidance,
901
+ )
902
+
903
+ # 5. Prepare controlnet conditioning
904
+ conditioning_frames = self.prepare_image(conditioning_frames, width, height, device, controlnet.dtype)
905
+ controlnet_cond, controlnet_cond_mask = self.prepare_sparse_control_conditioning(
906
+ conditioning_frames, num_frames, controlnet_frame_indices, device, controlnet.dtype
907
+ )
908
+
909
+ # 6. Prepare timesteps
910
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
911
+ timesteps = self.scheduler.timesteps
912
+
913
+ # 7. Prepare latent variables
914
+ num_channels_latents = self.unet.config.in_channels
915
+ latents = self.prepare_latents(
916
+ batch_size * num_videos_per_prompt,
917
+ num_channels_latents,
918
+ num_frames,
919
+ height,
920
+ width,
921
+ prompt_embeds.dtype,
922
+ device,
923
+ generator,
924
+ latents,
925
+ )
926
+
927
+ # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
928
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
929
+
930
+ # 9. Add image embeds for IP-Adapter
931
+ added_cond_kwargs = (
932
+ {"image_embeds": image_embeds}
933
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None
934
+ else None
935
+ )
936
+
937
+ num_free_init_iters = self._free_init_num_iters if self.free_init_enabled else 1
938
+ for free_init_iter in range(num_free_init_iters):
939
+ if self.free_init_enabled:
940
+ latents, timesteps = self._apply_free_init(
941
+ latents, free_init_iter, num_inference_steps, device, latents.dtype, generator
942
+ )
943
+
944
+ self._num_timesteps = len(timesteps)
945
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
946
+
947
+ # 10. Denoising loop
948
+ with self.progress_bar(total=self._num_timesteps) as progress_bar:
949
+ for i, t in enumerate(timesteps):
950
+ # expand the latents if we are doing classifier free guidance
951
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
952
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
953
+
954
+ if guess_mode and self.do_classifier_free_guidance:
955
+ # Infer SparseControlNetModel only for the conditional batch.
956
+ control_model_input = latents
957
+ control_model_input = self.scheduler.scale_model_input(control_model_input, t)
958
+ controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
959
+ else:
960
+ control_model_input = latent_model_input
961
+ controlnet_prompt_embeds = prompt_embeds
962
+
963
+ down_block_res_samples, mid_block_res_sample = self.controlnet(
964
+ control_model_input,
965
+ t,
966
+ encoder_hidden_states=controlnet_prompt_embeds,
967
+ controlnet_cond=controlnet_cond,
968
+ conditioning_mask=controlnet_cond_mask,
969
+ conditioning_scale=controlnet_conditioning_scale,
970
+ guess_mode=guess_mode,
971
+ return_dict=False,
972
+ )
973
+
974
+ # predict the noise residual
975
+ noise_pred = self.unet(
976
+ latent_model_input,
977
+ t,
978
+ encoder_hidden_states=prompt_embeds,
979
+ cross_attention_kwargs=cross_attention_kwargs,
980
+ added_cond_kwargs=added_cond_kwargs,
981
+ down_block_additional_residuals=down_block_res_samples,
982
+ mid_block_additional_residual=mid_block_res_sample,
983
+ ).sample
984
+
985
+ # perform guidance
986
+ if self.do_classifier_free_guidance:
987
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
988
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
989
+
990
+ # compute the previous noisy sample x_t -> x_t-1
991
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
992
+
993
+ if callback_on_step_end is not None:
994
+ callback_kwargs = {}
995
+ for k in callback_on_step_end_tensor_inputs:
996
+ callback_kwargs[k] = locals()[k]
997
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
998
+
999
+ latents = callback_outputs.pop("latents", latents)
1000
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1001
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
1002
+
1003
+ # call the callback, if provided
1004
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1005
+ progress_bar.update()
1006
+
1007
+ if XLA_AVAILABLE:
1008
+ xm.mark_step()
1009
+
1010
+ # 11. Post processing
1011
+ if output_type == "latent":
1012
+ video = latents
1013
+ else:
1014
+ video_tensor = self.decode_latents(latents)
1015
+ video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type)
1016
+
1017
+ # 12. Offload all models
1018
+ self.maybe_free_model_hooks()
1019
+
1020
+ if not return_dict:
1021
+ return (video,)
1022
+
1023
+ return AnimateDiffPipelineOutput(frames=video)
pythonProject/diffusers-main/build/lib/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py ADDED
@@ -0,0 +1,1065 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ from typing import Any, Callable, Dict, List, Optional, Union
17
+
18
+ import torch
19
+ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
20
+
21
+ from ...image_processor import PipelineImageInput
22
+ from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
23
+ from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel
24
+ from ...models.lora import adjust_lora_scale_text_encoder
25
+ from ...models.unets.unet_motion_model import MotionAdapter
26
+ from ...schedulers import (
27
+ DDIMScheduler,
28
+ DPMSolverMultistepScheduler,
29
+ EulerAncestralDiscreteScheduler,
30
+ EulerDiscreteScheduler,
31
+ LMSDiscreteScheduler,
32
+ PNDMScheduler,
33
+ )
34
+ from ...utils import USE_PEFT_BACKEND, is_torch_xla_available, logging, scale_lora_layers, unscale_lora_layers
35
+ from ...utils.torch_utils import randn_tensor
36
+ from ...video_processor import VideoProcessor
37
+ from ..free_init_utils import FreeInitMixin
38
+ from ..free_noise_utils import AnimateDiffFreeNoiseMixin
39
+ from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
40
+ from .pipeline_output import AnimateDiffPipelineOutput
41
+
42
+
43
+ if is_torch_xla_available():
44
+ import torch_xla.core.xla_model as xm
45
+
46
+ XLA_AVAILABLE = True
47
+ else:
48
+ XLA_AVAILABLE = False
49
+
50
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
51
+
52
+
53
+ EXAMPLE_DOC_STRING = """
54
+ Examples:
55
+ ```py
56
+ >>> import imageio
57
+ >>> import requests
58
+ >>> import torch
59
+ >>> from diffusers import AnimateDiffVideoToVideoPipeline, DDIMScheduler, MotionAdapter
60
+ >>> from diffusers.utils import export_to_gif
61
+ >>> from io import BytesIO
62
+ >>> from PIL import Image
63
+
64
+ >>> adapter = MotionAdapter.from_pretrained(
65
+ ... "guoyww/animatediff-motion-adapter-v1-5-2", torch_dtype=torch.float16
66
+ ... )
67
+ >>> pipe = AnimateDiffVideoToVideoPipeline.from_pretrained(
68
+ ... "SG161222/Realistic_Vision_V5.1_noVAE", motion_adapter=adapter
69
+ ... ).to("cuda")
70
+ >>> pipe.scheduler = DDIMScheduler(
71
+ ... beta_schedule="linear", steps_offset=1, clip_sample=False, timespace_spacing="linspace"
72
+ ... )
73
+
74
+
75
+ >>> def load_video(file_path: str):
76
+ ... images = []
77
+
78
+ ... if file_path.startswith(("http://", "https://")):
79
+ ... # If the file_path is a URL
80
+ ... response = requests.get(file_path)
81
+ ... response.raise_for_status()
82
+ ... content = BytesIO(response.content)
83
+ ... vid = imageio.get_reader(content)
84
+ ... else:
85
+ ... # Assuming it's a local file path
86
+ ... vid = imageio.get_reader(file_path)
87
+
88
+ ... for frame in vid:
89
+ ... pil_image = Image.fromarray(frame)
90
+ ... images.append(pil_image)
91
+
92
+ ... return images
93
+
94
+
95
+ >>> video = load_video(
96
+ ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-vid2vid-input-1.gif"
97
+ ... )
98
+ >>> output = pipe(
99
+ ... video=video, prompt="panda playing a guitar, on a boat, in the ocean, high quality", strength=0.5
100
+ ... )
101
+ >>> frames = output.frames[0]
102
+ >>> export_to_gif(frames, "animation.gif")
103
+ ```
104
+ """
105
+
106
+
107
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
108
+ def retrieve_latents(
109
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
110
+ ):
111
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
112
+ return encoder_output.latent_dist.sample(generator)
113
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
114
+ return encoder_output.latent_dist.mode()
115
+ elif hasattr(encoder_output, "latents"):
116
+ return encoder_output.latents
117
+ else:
118
+ raise AttributeError("Could not access latents of provided encoder_output")
119
+
120
+
121
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
122
+ def retrieve_timesteps(
123
+ scheduler,
124
+ num_inference_steps: Optional[int] = None,
125
+ device: Optional[Union[str, torch.device]] = None,
126
+ timesteps: Optional[List[int]] = None,
127
+ sigmas: Optional[List[float]] = None,
128
+ **kwargs,
129
+ ):
130
+ r"""
131
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
132
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
133
+
134
+ Args:
135
+ scheduler (`SchedulerMixin`):
136
+ The scheduler to get timesteps from.
137
+ num_inference_steps (`int`):
138
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
139
+ must be `None`.
140
+ device (`str` or `torch.device`, *optional*):
141
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
142
+ timesteps (`List[int]`, *optional*):
143
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
144
+ `num_inference_steps` and `sigmas` must be `None`.
145
+ sigmas (`List[float]`, *optional*):
146
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
147
+ `num_inference_steps` and `timesteps` must be `None`.
148
+
149
+ Returns:
150
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
151
+ second element is the number of inference steps.
152
+ """
153
+ if timesteps is not None and sigmas is not None:
154
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
155
+ if timesteps is not None:
156
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
157
+ if not accepts_timesteps:
158
+ raise ValueError(
159
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
160
+ f" timestep schedules. Please check whether you are using the correct scheduler."
161
+ )
162
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
163
+ timesteps = scheduler.timesteps
164
+ num_inference_steps = len(timesteps)
165
+ elif sigmas is not None:
166
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
167
+ if not accept_sigmas:
168
+ raise ValueError(
169
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
170
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
171
+ )
172
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
173
+ timesteps = scheduler.timesteps
174
+ num_inference_steps = len(timesteps)
175
+ else:
176
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
177
+ timesteps = scheduler.timesteps
178
+ return timesteps, num_inference_steps
179
+
180
+
181
+ class AnimateDiffVideoToVideoPipeline(
182
+ DiffusionPipeline,
183
+ StableDiffusionMixin,
184
+ TextualInversionLoaderMixin,
185
+ IPAdapterMixin,
186
+ StableDiffusionLoraLoaderMixin,
187
+ FreeInitMixin,
188
+ AnimateDiffFreeNoiseMixin,
189
+ FromSingleFileMixin,
190
+ ):
191
+ r"""
192
+ Pipeline for video-to-video generation.
193
+
194
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
195
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
196
+
197
+ The pipeline also inherits the following loading methods:
198
+ - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
199
+ - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
200
+ - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
201
+ - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
202
+
203
+ Args:
204
+ vae ([`AutoencoderKL`]):
205
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
206
+ text_encoder ([`CLIPTextModel`]):
207
+ Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
208
+ tokenizer (`CLIPTokenizer`):
209
+ A [`~transformers.CLIPTokenizer`] to tokenize text.
210
+ unet ([`UNet2DConditionModel`]):
211
+ A [`UNet2DConditionModel`] used to create a UNetMotionModel to denoise the encoded video latents.
212
+ motion_adapter ([`MotionAdapter`]):
213
+ A [`MotionAdapter`] to be used in combination with `unet` to denoise the encoded video latents.
214
+ scheduler ([`SchedulerMixin`]):
215
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
216
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
217
+ """
218
+
219
+ model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
220
+ _optional_components = ["feature_extractor", "image_encoder", "motion_adapter"]
221
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
222
+
223
+ def __init__(
224
+ self,
225
+ vae: AutoencoderKL,
226
+ text_encoder: CLIPTextModel,
227
+ tokenizer: CLIPTokenizer,
228
+ unet: Union[UNet2DConditionModel, UNetMotionModel],
229
+ motion_adapter: MotionAdapter,
230
+ scheduler: Union[
231
+ DDIMScheduler,
232
+ PNDMScheduler,
233
+ LMSDiscreteScheduler,
234
+ EulerDiscreteScheduler,
235
+ EulerAncestralDiscreteScheduler,
236
+ DPMSolverMultistepScheduler,
237
+ ],
238
+ feature_extractor: CLIPImageProcessor = None,
239
+ image_encoder: CLIPVisionModelWithProjection = None,
240
+ ):
241
+ super().__init__()
242
+ if isinstance(unet, UNet2DConditionModel):
243
+ unet = UNetMotionModel.from_unet2d(unet, motion_adapter)
244
+
245
+ self.register_modules(
246
+ vae=vae,
247
+ text_encoder=text_encoder,
248
+ tokenizer=tokenizer,
249
+ unet=unet,
250
+ motion_adapter=motion_adapter,
251
+ scheduler=scheduler,
252
+ feature_extractor=feature_extractor,
253
+ image_encoder=image_encoder,
254
+ )
255
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
256
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor)
257
+
258
+ def encode_prompt(
259
+ self,
260
+ prompt,
261
+ device,
262
+ num_images_per_prompt,
263
+ do_classifier_free_guidance,
264
+ negative_prompt=None,
265
+ prompt_embeds: Optional[torch.Tensor] = None,
266
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
267
+ lora_scale: Optional[float] = None,
268
+ clip_skip: Optional[int] = None,
269
+ ):
270
+ r"""
271
+ Encodes the prompt into text encoder hidden states.
272
+
273
+ Args:
274
+ prompt (`str` or `List[str]`, *optional*):
275
+ prompt to be encoded
276
+ device: (`torch.device`):
277
+ torch device
278
+ num_images_per_prompt (`int`):
279
+ number of images that should be generated per prompt
280
+ do_classifier_free_guidance (`bool`):
281
+ whether to use classifier free guidance or not
282
+ negative_prompt (`str` or `List[str]`, *optional*):
283
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
284
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
285
+ less than `1`).
286
+ prompt_embeds (`torch.Tensor`, *optional*):
287
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
288
+ provided, text embeddings will be generated from `prompt` input argument.
289
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
290
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
291
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
292
+ argument.
293
+ lora_scale (`float`, *optional*):
294
+ A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
295
+ clip_skip (`int`, *optional*):
296
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
297
+ the output of the pre-final layer will be used for computing the prompt embeddings.
298
+ """
299
+ # set lora scale so that monkey patched LoRA
300
+ # function of text encoder can correctly access it
301
+ if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):
302
+ self._lora_scale = lora_scale
303
+
304
+ # dynamically adjust the LoRA scale
305
+ if not USE_PEFT_BACKEND:
306
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
307
+ else:
308
+ scale_lora_layers(self.text_encoder, lora_scale)
309
+
310
+ if prompt is not None and isinstance(prompt, (str, dict)):
311
+ batch_size = 1
312
+ elif prompt is not None and isinstance(prompt, list):
313
+ batch_size = len(prompt)
314
+ else:
315
+ batch_size = prompt_embeds.shape[0]
316
+
317
+ if prompt_embeds is None:
318
+ # textual inversion: process multi-vector tokens if necessary
319
+ if isinstance(self, TextualInversionLoaderMixin):
320
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
321
+
322
+ text_inputs = self.tokenizer(
323
+ prompt,
324
+ padding="max_length",
325
+ max_length=self.tokenizer.model_max_length,
326
+ truncation=True,
327
+ return_tensors="pt",
328
+ )
329
+ text_input_ids = text_inputs.input_ids
330
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
331
+
332
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
333
+ text_input_ids, untruncated_ids
334
+ ):
335
+ removed_text = self.tokenizer.batch_decode(
336
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
337
+ )
338
+ logger.warning(
339
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
340
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
341
+ )
342
+
343
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
344
+ attention_mask = text_inputs.attention_mask.to(device)
345
+ else:
346
+ attention_mask = None
347
+
348
+ if clip_skip is None:
349
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
350
+ prompt_embeds = prompt_embeds[0]
351
+ else:
352
+ prompt_embeds = self.text_encoder(
353
+ text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
354
+ )
355
+ # Access the `hidden_states` first, that contains a tuple of
356
+ # all the hidden states from the encoder layers. Then index into
357
+ # the tuple to access the hidden states from the desired layer.
358
+ prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
359
+ # We also need to apply the final LayerNorm here to not mess with the
360
+ # representations. The `last_hidden_states` that we typically use for
361
+ # obtaining the final prompt representations passes through the LayerNorm
362
+ # layer.
363
+ prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
364
+
365
+ if self.text_encoder is not None:
366
+ prompt_embeds_dtype = self.text_encoder.dtype
367
+ elif self.unet is not None:
368
+ prompt_embeds_dtype = self.unet.dtype
369
+ else:
370
+ prompt_embeds_dtype = prompt_embeds.dtype
371
+
372
+ prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
373
+
374
+ bs_embed, seq_len, _ = prompt_embeds.shape
375
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
376
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
377
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
378
+
379
+ # get unconditional embeddings for classifier free guidance
380
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
381
+ uncond_tokens: List[str]
382
+ if negative_prompt is None:
383
+ uncond_tokens = [""] * batch_size
384
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
385
+ raise TypeError(
386
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
387
+ f" {type(prompt)}."
388
+ )
389
+ elif isinstance(negative_prompt, str):
390
+ uncond_tokens = [negative_prompt]
391
+ elif batch_size != len(negative_prompt):
392
+ raise ValueError(
393
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
394
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
395
+ " the batch size of `prompt`."
396
+ )
397
+ else:
398
+ uncond_tokens = negative_prompt
399
+
400
+ # textual inversion: process multi-vector tokens if necessary
401
+ if isinstance(self, TextualInversionLoaderMixin):
402
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
403
+
404
+ max_length = prompt_embeds.shape[1]
405
+ uncond_input = self.tokenizer(
406
+ uncond_tokens,
407
+ padding="max_length",
408
+ max_length=max_length,
409
+ truncation=True,
410
+ return_tensors="pt",
411
+ )
412
+
413
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
414
+ attention_mask = uncond_input.attention_mask.to(device)
415
+ else:
416
+ attention_mask = None
417
+
418
+ negative_prompt_embeds = self.text_encoder(
419
+ uncond_input.input_ids.to(device),
420
+ attention_mask=attention_mask,
421
+ )
422
+ negative_prompt_embeds = negative_prompt_embeds[0]
423
+
424
+ if do_classifier_free_guidance:
425
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
426
+ seq_len = negative_prompt_embeds.shape[1]
427
+
428
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
429
+
430
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
431
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
432
+
433
+ if self.text_encoder is not None:
434
+ if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
435
+ # Retrieve the original scale by scaling back the LoRA layers
436
+ unscale_lora_layers(self.text_encoder, lora_scale)
437
+
438
+ return prompt_embeds, negative_prompt_embeds
439
+
440
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
441
+ def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
442
+ dtype = next(self.image_encoder.parameters()).dtype
443
+
444
+ if not isinstance(image, torch.Tensor):
445
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
446
+
447
+ image = image.to(device=device, dtype=dtype)
448
+ if output_hidden_states:
449
+ image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
450
+ image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
451
+ uncond_image_enc_hidden_states = self.image_encoder(
452
+ torch.zeros_like(image), output_hidden_states=True
453
+ ).hidden_states[-2]
454
+ uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
455
+ num_images_per_prompt, dim=0
456
+ )
457
+ return image_enc_hidden_states, uncond_image_enc_hidden_states
458
+ else:
459
+ image_embeds = self.image_encoder(image).image_embeds
460
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
461
+ uncond_image_embeds = torch.zeros_like(image_embeds)
462
+
463
+ return image_embeds, uncond_image_embeds
464
+
465
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
466
+ def prepare_ip_adapter_image_embeds(
467
+ self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
468
+ ):
469
+ image_embeds = []
470
+ if do_classifier_free_guidance:
471
+ negative_image_embeds = []
472
+ if ip_adapter_image_embeds is None:
473
+ if not isinstance(ip_adapter_image, list):
474
+ ip_adapter_image = [ip_adapter_image]
475
+
476
+ if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
477
+ raise ValueError(
478
+ f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
479
+ )
480
+
481
+ for single_ip_adapter_image, image_proj_layer in zip(
482
+ ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
483
+ ):
484
+ output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
485
+ single_image_embeds, single_negative_image_embeds = self.encode_image(
486
+ single_ip_adapter_image, device, 1, output_hidden_state
487
+ )
488
+
489
+ image_embeds.append(single_image_embeds[None, :])
490
+ if do_classifier_free_guidance:
491
+ negative_image_embeds.append(single_negative_image_embeds[None, :])
492
+ else:
493
+ for single_image_embeds in ip_adapter_image_embeds:
494
+ if do_classifier_free_guidance:
495
+ single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
496
+ negative_image_embeds.append(single_negative_image_embeds)
497
+ image_embeds.append(single_image_embeds)
498
+
499
+ ip_adapter_image_embeds = []
500
+ for i, single_image_embeds in enumerate(image_embeds):
501
+ single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
502
+ if do_classifier_free_guidance:
503
+ single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
504
+ single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)
505
+
506
+ single_image_embeds = single_image_embeds.to(device=device)
507
+ ip_adapter_image_embeds.append(single_image_embeds)
508
+
509
+ return ip_adapter_image_embeds
510
+
511
+ def encode_video(self, video, generator, decode_chunk_size: int = 16) -> torch.Tensor:
512
+ latents = []
513
+ for i in range(0, len(video), decode_chunk_size):
514
+ batch_video = video[i : i + decode_chunk_size]
515
+ batch_video = retrieve_latents(self.vae.encode(batch_video), generator=generator)
516
+ latents.append(batch_video)
517
+ return torch.cat(latents)
518
+
519
+ # Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffPipeline.decode_latents
520
+ def decode_latents(self, latents, decode_chunk_size: int = 16):
521
+ latents = 1 / self.vae.config.scaling_factor * latents
522
+
523
+ batch_size, channels, num_frames, height, width = latents.shape
524
+ latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)
525
+
526
+ video = []
527
+ for i in range(0, latents.shape[0], decode_chunk_size):
528
+ batch_latents = latents[i : i + decode_chunk_size]
529
+ batch_latents = self.vae.decode(batch_latents).sample
530
+ video.append(batch_latents)
531
+
532
+ video = torch.cat(video)
533
+ video = video[None, :].reshape((batch_size, num_frames, -1) + video.shape[2:]).permute(0, 2, 1, 3, 4)
534
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
535
+ video = video.float()
536
+ return video
537
+
538
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
539
+ def prepare_extra_step_kwargs(self, generator, eta):
540
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
541
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
542
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
543
+ # and should be between [0, 1]
544
+
545
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
546
+ extra_step_kwargs = {}
547
+ if accepts_eta:
548
+ extra_step_kwargs["eta"] = eta
549
+
550
+ # check if the scheduler accepts generator
551
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
552
+ if accepts_generator:
553
+ extra_step_kwargs["generator"] = generator
554
+ return extra_step_kwargs
555
+
556
+ def check_inputs(
557
+ self,
558
+ prompt,
559
+ strength,
560
+ height,
561
+ width,
562
+ video=None,
563
+ latents=None,
564
+ negative_prompt=None,
565
+ prompt_embeds=None,
566
+ negative_prompt_embeds=None,
567
+ ip_adapter_image=None,
568
+ ip_adapter_image_embeds=None,
569
+ callback_on_step_end_tensor_inputs=None,
570
+ ):
571
+ if strength < 0 or strength > 1:
572
+ raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
573
+
574
+ if height % 8 != 0 or width % 8 != 0:
575
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
576
+
577
+ if callback_on_step_end_tensor_inputs is not None and not all(
578
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
579
+ ):
580
+ raise ValueError(
581
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
582
+ )
583
+
584
+ if prompt is not None and prompt_embeds is not None:
585
+ raise ValueError(
586
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
587
+ " only forward one of the two."
588
+ )
589
+ elif prompt is None and prompt_embeds is None:
590
+ raise ValueError(
591
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
592
+ )
593
+ elif prompt is not None and not isinstance(prompt, (str, list, dict)):
594
+ raise ValueError(f"`prompt` has to be of type `str`, `list` or `dict` but is {type(prompt)}")
595
+
596
+ if negative_prompt is not None and negative_prompt_embeds is not None:
597
+ raise ValueError(
598
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
599
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
600
+ )
601
+
602
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
603
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
604
+ raise ValueError(
605
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
606
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
607
+ f" {negative_prompt_embeds.shape}."
608
+ )
609
+
610
+ if video is not None and latents is not None:
611
+ raise ValueError("Only one of `video` or `latents` should be provided")
612
+
613
+ if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
614
+ raise ValueError(
615
+ "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
616
+ )
617
+
618
+ if ip_adapter_image_embeds is not None:
619
+ if not isinstance(ip_adapter_image_embeds, list):
620
+ raise ValueError(
621
+ f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
622
+ )
623
+ elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
624
+ raise ValueError(
625
+ f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
626
+ )
627
+
628
+ def get_timesteps(self, num_inference_steps, timesteps, strength, device):
629
+ # get the original timestep using init_timestep
630
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
631
+
632
+ t_start = max(num_inference_steps - init_timestep, 0)
633
+ timesteps = timesteps[t_start * self.scheduler.order :]
634
+
635
+ return timesteps, num_inference_steps - t_start
636
+
637
+ def prepare_latents(
638
+ self,
639
+ video: Optional[torch.Tensor] = None,
640
+ height: int = 64,
641
+ width: int = 64,
642
+ num_channels_latents: int = 4,
643
+ batch_size: int = 1,
644
+ timestep: Optional[int] = None,
645
+ dtype: Optional[torch.dtype] = None,
646
+ device: Optional[torch.device] = None,
647
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
648
+ latents: Optional[torch.Tensor] = None,
649
+ decode_chunk_size: int = 16,
650
+ add_noise: bool = False,
651
+ ) -> torch.Tensor:
652
+ num_frames = video.shape[1] if latents is None else latents.shape[2]
653
+ shape = (
654
+ batch_size,
655
+ num_channels_latents,
656
+ num_frames,
657
+ height // self.vae_scale_factor,
658
+ width // self.vae_scale_factor,
659
+ )
660
+
661
+ if isinstance(generator, list) and len(generator) != batch_size:
662
+ raise ValueError(
663
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
664
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
665
+ )
666
+
667
+ if latents is None:
668
+ # make sure the VAE is in float32 mode, as it overflows in float16
669
+ if self.vae.config.force_upcast:
670
+ video = video.float()
671
+ self.vae.to(dtype=torch.float32)
672
+
673
+ if isinstance(generator, list):
674
+ init_latents = [
675
+ self.encode_video(video[i], generator[i], decode_chunk_size).unsqueeze(0)
676
+ for i in range(batch_size)
677
+ ]
678
+ else:
679
+ init_latents = [self.encode_video(vid, generator, decode_chunk_size).unsqueeze(0) for vid in video]
680
+
681
+ init_latents = torch.cat(init_latents, dim=0)
682
+
683
+ # restore vae to original dtype
684
+ if self.vae.config.force_upcast:
685
+ self.vae.to(dtype)
686
+
687
+ init_latents = init_latents.to(dtype)
688
+ init_latents = self.vae.config.scaling_factor * init_latents
689
+
690
+ if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
691
+ # expand init_latents for batch_size
692
+ error_message = (
693
+ f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
694
+ " images (`image`). Please make sure to update your script to pass as many initial images as text prompts"
695
+ )
696
+ raise ValueError(error_message)
697
+ elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
698
+ raise ValueError(
699
+ f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
700
+ )
701
+ else:
702
+ init_latents = torch.cat([init_latents], dim=0)
703
+
704
+ noise = randn_tensor(init_latents.shape, generator=generator, device=device, dtype=dtype)
705
+ latents = self.scheduler.add_noise(init_latents, noise, timestep).permute(0, 2, 1, 3, 4)
706
+ else:
707
+ if shape != latents.shape:
708
+ # [B, C, F, H, W]
709
+ raise ValueError(f"`latents` expected to have {shape=}, but found {latents.shape=}")
710
+
711
+ latents = latents.to(device, dtype=dtype)
712
+
713
+ if add_noise:
714
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
715
+ latents = self.scheduler.add_noise(latents, noise, timestep)
716
+
717
+ return latents
718
+
719
+ @property
720
+ def guidance_scale(self):
721
+ return self._guidance_scale
722
+
723
+ @property
724
+ def clip_skip(self):
725
+ return self._clip_skip
726
+
727
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
728
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
729
+ # corresponds to doing no classifier free guidance.
730
+ @property
731
+ def do_classifier_free_guidance(self):
732
+ return self._guidance_scale > 1
733
+
734
+ @property
735
+ def cross_attention_kwargs(self):
736
+ return self._cross_attention_kwargs
737
+
738
+ @property
739
+ def num_timesteps(self):
740
+ return self._num_timesteps
741
+
742
+ @property
743
+ def interrupt(self):
744
+ return self._interrupt
745
+
746
    @torch.no_grad()
    def __call__(
        self,
        video: List[List[PipelineImageInput]] = None,
        prompt: Optional[Union[str, List[str]]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        enforce_inference_steps: bool = False,
        timesteps: Optional[List[int]] = None,
        sigmas: Optional[List[float]] = None,
        guidance_scale: float = 7.5,
        strength: float = 0.8,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_videos_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        ip_adapter_image: Optional[PipelineImageInput] = None,
        ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        clip_skip: Optional[int] = None,
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        decode_chunk_size: int = 16,
    ):
        r"""
        The call function to the pipeline for generation.

        Args:
            video (`List[PipelineImageInput]`):
                The input video to condition the generation on. Must be a list of images/frames of the video.
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
            height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                The height in pixels of the generated video.
            width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                The width in pixels of the generated video.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality videos at the
                expense of slower inference.
            enforce_inference_steps (`bool`, *optional*, defaults to `False`):
                When `True`, exactly `num_inference_steps` denoising steps are executed: the scheduler is set up
                with `num_inference_steps / strength` steps and only the final `num_inference_steps` of them are
                run, with noise added to the latents accordingly.
            timesteps (`List[int]`, *optional*):
                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
                passed will be used. Must be in descending order.
            sigmas (`List[float]`, *optional*):
                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
                their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
                will be used.
            strength (`float`, *optional*, defaults to 0.8):
                Higher strength leads to more differences between original video and generated video.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                A higher guidance scale value encourages the model to generate images closely linked to the text
                `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide what to not include in image generation. If not defined, you need to
                pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only
                applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
            latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor is generated by sampling using the supplied random `generator`. Latents should be of shape
                `(batch_size, num_channel, num_frames, height, width)`.
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
                provided, text embeddings are generated from the `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
                not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
            ip_adapter_image: (`PipelineImageInput`, *optional*):
                Optional image input to work with IP Adapters.
            ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
                Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
                IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
                contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
                provided, embeddings are computed from the `ip_adapter_image` input argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated video. Choose between `torch.Tensor`, `PIL.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`AnimateDiffPipelineOutput`] instead of a plain tuple.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            clip_skip (`int`, *optional*):
                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                the output of the pre-final layer will be used for computing the prompt embeddings.
            callback_on_step_end (`Callable`, *optional*):
                A function that calls at the end of each denoising steps during the inference. The function is called
                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
                `callback_on_step_end_tensor_inputs`.
            callback_on_step_end_tensor_inputs (`List`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                `._callback_tensor_inputs` attribute of your pipeline class.
            decode_chunk_size (`int`, defaults to `16`):
                The number of frames to decode at a time when calling `decode_latents` method.

        Examples:

        Returns:
            [`pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] is
                returned, otherwise a `tuple` is returned where the first element is a list with the generated frames.
        """

        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        # NOTE(review): only one video per prompt is supported; the public
        # `num_videos_per_prompt` argument is overridden here.
        num_videos_per_prompt = 1

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt=prompt,
            strength=strength,
            height=height,
            width=width,
            negative_prompt=negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            video=video,
            latents=latents,
            ip_adapter_image=ip_adapter_image,
            ip_adapter_image_embeds=ip_adapter_image_embeds,
            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
        )

        self._guidance_scale = guidance_scale
        self._clip_skip = clip_skip
        self._cross_attention_kwargs = cross_attention_kwargs
        self._interrupt = False

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, (str, dict)):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device
        dtype = self.dtype

        # 3. Prepare timesteps
        if not enforce_inference_steps:
            timesteps, num_inference_steps = retrieve_timesteps(
                self.scheduler, num_inference_steps, device, timesteps, sigmas
            )
            # trim the schedule so only the final `strength` fraction is denoised
            timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, timesteps, strength, device)
            latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt)
        else:
            # stretch the schedule so exactly `num_inference_steps` of its tail are run
            denoising_inference_steps = int(num_inference_steps / strength)
            timesteps, denoising_inference_steps = retrieve_timesteps(
                self.scheduler, denoising_inference_steps, device, timesteps, sigmas
            )
            timesteps = timesteps[-num_inference_steps:]
            latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt)

        # 4. Prepare latent variables
        if latents is None:
            video = self.video_processor.preprocess_video(video, height=height, width=width)
            # Move the number of frames before the number of channels.
            video = video.permute(0, 2, 1, 3, 4)
            video = video.to(device=device, dtype=dtype)
        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            video=video,
            height=height,
            width=width,
            num_channels_latents=num_channels_latents,
            batch_size=batch_size * num_videos_per_prompt,
            timestep=latent_timestep,
            dtype=dtype,
            device=device,
            generator=generator,
            latents=latents,
            decode_chunk_size=decode_chunk_size,
            add_noise=enforce_inference_steps,
        )

        # 5. Encode input prompt
        text_encoder_lora_scale = (
            self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
        )
        num_frames = latents.shape[2]
        if self.free_noise_enabled:
            # FreeNoise produces per-frame embeddings and handles CFG internally.
            prompt_embeds, negative_prompt_embeds = self._encode_prompt_free_noise(
                prompt=prompt,
                num_frames=num_frames,
                device=device,
                num_videos_per_prompt=num_videos_per_prompt,
                do_classifier_free_guidance=self.do_classifier_free_guidance,
                negative_prompt=negative_prompt,
                prompt_embeds=prompt_embeds,
                negative_prompt_embeds=negative_prompt_embeds,
                lora_scale=text_encoder_lora_scale,
                clip_skip=self.clip_skip,
            )
        else:
            prompt_embeds, negative_prompt_embeds = self.encode_prompt(
                prompt,
                device,
                num_videos_per_prompt,
                self.do_classifier_free_guidance,
                negative_prompt,
                prompt_embeds=prompt_embeds,
                negative_prompt_embeds=negative_prompt_embeds,
                lora_scale=text_encoder_lora_scale,
                clip_skip=self.clip_skip,
            )

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            if self.do_classifier_free_guidance:
                prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

            # repeat embeddings once per frame so the motion modules see them
            prompt_embeds = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0)

        # 6. Prepare IP-Adapter embeddings
        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
            image_embeds = self.prepare_ip_adapter_image_embeds(
                ip_adapter_image,
                ip_adapter_image_embeds,
                device,
                batch_size * num_videos_per_prompt,
                self.do_classifier_free_guidance,
            )

        # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 8. Add image embeds for IP-Adapter
        added_cond_kwargs = (
            {"image_embeds": image_embeds}
            if ip_adapter_image is not None or ip_adapter_image_embeds is not None
            else None
        )

        num_free_init_iters = self._free_init_num_iters if self.free_init_enabled else 1
        for free_init_iter in range(num_free_init_iters):
            if self.free_init_enabled:
                latents, timesteps = self._apply_free_init(
                    latents, free_init_iter, num_inference_steps, device, latents.dtype, generator
                )
                num_inference_steps = len(timesteps)
                # make sure to readjust timesteps based on strength
                timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, timesteps, strength, device)

            self._num_timesteps = len(timesteps)
            num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order

            # 9. Denoising loop
            with self.progress_bar(total=self._num_timesteps) as progress_bar:
                for i, t in enumerate(timesteps):
                    if self.interrupt:
                        continue

                    # expand the latents if we are doing classifier free guidance
                    latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                    latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                    # predict the noise residual
                    noise_pred = self.unet(
                        latent_model_input,
                        t,
                        encoder_hidden_states=prompt_embeds,
                        cross_attention_kwargs=self.cross_attention_kwargs,
                        added_cond_kwargs=added_cond_kwargs,
                    ).sample

                    # perform guidance
                    if self.do_classifier_free_guidance:
                        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                    # compute the previous noisy sample x_t -> x_t-1
                    latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

                    if callback_on_step_end is not None:
                        callback_kwargs = {}
                        for k in callback_on_step_end_tensor_inputs:
                            callback_kwargs[k] = locals()[k]
                        callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                        # the callback may replace tensors in-flight
                        latents = callback_outputs.pop("latents", latents)
                        prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                        negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

                    # call the callback, if provided
                    if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                        progress_bar.update()

                    if XLA_AVAILABLE:
                        xm.mark_step()

        # 10. Post-processing
        if output_type == "latent":
            video = latents
        else:
            video_tensor = self.decode_latents(latents, decode_chunk_size)
            video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type)

        # 11. Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (video,)

        return AnimateDiffPipelineOutput(frames=video)
pythonProject/diffusers-main/build/lib/diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py ADDED
@@ -0,0 +1,1353 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn.functional as F
20
+ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
21
+
22
+ from ...image_processor import PipelineImageInput
23
+ from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
24
+ from ...models import (
25
+ AutoencoderKL,
26
+ ControlNetModel,
27
+ ImageProjection,
28
+ MultiControlNetModel,
29
+ UNet2DConditionModel,
30
+ UNetMotionModel,
31
+ )
32
+ from ...models.lora import adjust_lora_scale_text_encoder
33
+ from ...models.unets.unet_motion_model import MotionAdapter
34
+ from ...schedulers import (
35
+ DDIMScheduler,
36
+ DPMSolverMultistepScheduler,
37
+ EulerAncestralDiscreteScheduler,
38
+ EulerDiscreteScheduler,
39
+ LMSDiscreteScheduler,
40
+ PNDMScheduler,
41
+ )
42
+ from ...utils import USE_PEFT_BACKEND, is_torch_xla_available, logging, scale_lora_layers, unscale_lora_layers
43
+ from ...utils.torch_utils import is_compiled_module, randn_tensor
44
+ from ...video_processor import VideoProcessor
45
+ from ..free_init_utils import FreeInitMixin
46
+ from ..free_noise_utils import AnimateDiffFreeNoiseMixin
47
+ from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
48
+ from .pipeline_output import AnimateDiffPipelineOutput
49
+
50
+
51
+ if is_torch_xla_available():
52
+ import torch_xla.core.xla_model as xm
53
+
54
+ XLA_AVAILABLE = True
55
+ else:
56
+ XLA_AVAILABLE = False
57
+
58
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
59
+
60
+
61
+ EXAMPLE_DOC_STRING = """
62
+ Examples:
63
+ ```py
64
+ >>> import torch
65
+ >>> from PIL import Image
66
+ >>> from tqdm.auto import tqdm
67
+
68
+ >>> from diffusers import AnimateDiffVideoToVideoControlNetPipeline
69
+ >>> from diffusers.utils import export_to_gif, load_video
70
+ >>> from diffusers import AutoencoderKL, ControlNetModel, MotionAdapter, LCMScheduler
71
+
72
+ >>> controlnet = ControlNetModel.from_pretrained(
73
+ ... "lllyasviel/sd-controlnet-openpose", torch_dtype=torch.float16
74
+ ... )
75
+ >>> motion_adapter = MotionAdapter.from_pretrained("wangfuyun/AnimateLCM")
76
+ >>> vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16)
77
+
78
+ >>> pipe = AnimateDiffVideoToVideoControlNetPipeline.from_pretrained(
79
+ ... "SG161222/Realistic_Vision_V5.1_noVAE",
80
+ ... motion_adapter=motion_adapter,
81
+ ... controlnet=controlnet,
82
+ ... vae=vae,
83
+ ... ).to(device="cuda", dtype=torch.float16)
84
+
85
+ >>> pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config, beta_schedule="linear")
86
+ >>> pipe.load_lora_weights(
87
+ ... "wangfuyun/AnimateLCM", weight_name="AnimateLCM_sd15_t2v_lora.safetensors", adapter_name="lcm-lora"
88
+ ... )
89
+ >>> pipe.set_adapters(["lcm-lora"], [0.8])
90
+
91
+ >>> video = load_video(
92
+ ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/dance.gif"
93
+ ... )
94
+ >>> video = [frame.convert("RGB") for frame in video]
95
+
96
+ >>> from controlnet_aux.processor import OpenposeDetector
97
+
98
+ >>> open_pose = OpenposeDetector.from_pretrained("lllyasviel/Annotators").to("cuda")
99
+ >>> for frame in tqdm(video):
100
+ ... conditioning_frames.append(open_pose(frame))
101
+
102
+ >>> prompt = "astronaut in space, dancing"
103
+ >>> negative_prompt = "bad quality, worst quality, jpeg artifacts, ugly"
104
+
105
+ >>> strength = 0.8
106
+ >>> with torch.inference_mode():
107
+ ... video = pipe(
108
+ ... video=video,
109
+ ... prompt=prompt,
110
+ ... negative_prompt=negative_prompt,
111
+ ... num_inference_steps=10,
112
+ ... guidance_scale=2.0,
113
+ ... controlnet_conditioning_scale=0.75,
114
+ ... conditioning_frames=conditioning_frames,
115
+ ... strength=strength,
116
+ ... generator=torch.Generator().manual_seed(42),
117
+ ... ).frames[0]
118
+
119
+ >>> video = [frame.resize(conditioning_frames[0].size) for frame in video]
120
+ >>> export_to_gif(video, f"animatediff_vid2vid_controlnet.gif", fps=8)
121
+ ```
122
+ """
123
+
124
+
125
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
126
+ def retrieve_latents(
127
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
128
+ ):
129
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
130
+ return encoder_output.latent_dist.sample(generator)
131
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
132
+ return encoder_output.latent_dist.mode()
133
+ elif hasattr(encoder_output, "latents"):
134
+ return encoder_output.latents
135
+ else:
136
+ raise AttributeError("Could not access latents of provided encoder_output")
137
+
138
+
139
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
140
+ def retrieve_timesteps(
141
+ scheduler,
142
+ num_inference_steps: Optional[int] = None,
143
+ device: Optional[Union[str, torch.device]] = None,
144
+ timesteps: Optional[List[int]] = None,
145
+ sigmas: Optional[List[float]] = None,
146
+ **kwargs,
147
+ ):
148
+ r"""
149
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
150
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
151
+
152
+ Args:
153
+ scheduler (`SchedulerMixin`):
154
+ The scheduler to get timesteps from.
155
+ num_inference_steps (`int`):
156
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
157
+ must be `None`.
158
+ device (`str` or `torch.device`, *optional*):
159
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
160
+ timesteps (`List[int]`, *optional*):
161
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
162
+ `num_inference_steps` and `sigmas` must be `None`.
163
+ sigmas (`List[float]`, *optional*):
164
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
165
+ `num_inference_steps` and `timesteps` must be `None`.
166
+
167
+ Returns:
168
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
169
+ second element is the number of inference steps.
170
+ """
171
+ if timesteps is not None and sigmas is not None:
172
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
173
+ if timesteps is not None:
174
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
175
+ if not accepts_timesteps:
176
+ raise ValueError(
177
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
178
+ f" timestep schedules. Please check whether you are using the correct scheduler."
179
+ )
180
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
181
+ timesteps = scheduler.timesteps
182
+ num_inference_steps = len(timesteps)
183
+ elif sigmas is not None:
184
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
185
+ if not accept_sigmas:
186
+ raise ValueError(
187
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
188
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
189
+ )
190
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
191
+ timesteps = scheduler.timesteps
192
+ num_inference_steps = len(timesteps)
193
+ else:
194
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
195
+ timesteps = scheduler.timesteps
196
+ return timesteps, num_inference_steps
197
+
198
+
199
+ class AnimateDiffVideoToVideoControlNetPipeline(
200
+ DiffusionPipeline,
201
+ StableDiffusionMixin,
202
+ TextualInversionLoaderMixin,
203
+ IPAdapterMixin,
204
+ StableDiffusionLoraLoaderMixin,
205
+ FreeInitMixin,
206
+ AnimateDiffFreeNoiseMixin,
207
+ FromSingleFileMixin,
208
+ ):
209
+ r"""
210
+ Pipeline for video-to-video generation with ControlNet guidance.
211
+
212
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
213
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
214
+
215
+ The pipeline also inherits the following loading methods:
216
+ - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
217
+ - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
218
+ - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
219
+ - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
220
+
221
+ Args:
222
+ vae ([`AutoencoderKL`]):
223
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
224
+ text_encoder ([`CLIPTextModel`]):
225
+ Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
226
+ tokenizer (`CLIPTokenizer`):
227
+ A [`~transformers.CLIPTokenizer`] to tokenize text.
228
+ unet ([`UNet2DConditionModel`]):
229
+ A [`UNet2DConditionModel`] used to create a UNetMotionModel to denoise the encoded video latents.
230
+ motion_adapter ([`MotionAdapter`]):
231
+ A [`MotionAdapter`] to be used in combination with `unet` to denoise the encoded video latents.
232
+ controlnet ([`ControlNetModel`] or `List[ControlNetModel]` or `Tuple[ControlNetModel]` or `MultiControlNetModel`):
233
+ Provides additional conditioning to the `unet` during the denoising process. If you set multiple
234
+ ControlNets as a list, the outputs from each ControlNet are added together to create one combined
235
+ additional conditioning.
236
+ scheduler ([`SchedulerMixin`]):
237
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
238
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
239
+ """
240
+
241
+ model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
242
+ _optional_components = ["feature_extractor", "image_encoder", "motion_adapter"]
243
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
244
+
245
+ def __init__(
246
+ self,
247
+ vae: AutoencoderKL,
248
+ text_encoder: CLIPTextModel,
249
+ tokenizer: CLIPTokenizer,
250
+ unet: Union[UNet2DConditionModel, UNetMotionModel],
251
+ motion_adapter: MotionAdapter,
252
+ controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
253
+ scheduler: Union[
254
+ DDIMScheduler,
255
+ PNDMScheduler,
256
+ LMSDiscreteScheduler,
257
+ EulerDiscreteScheduler,
258
+ EulerAncestralDiscreteScheduler,
259
+ DPMSolverMultistepScheduler,
260
+ ],
261
+ feature_extractor: CLIPImageProcessor = None,
262
+ image_encoder: CLIPVisionModelWithProjection = None,
263
+ ):
264
+ super().__init__()
265
+ if isinstance(unet, UNet2DConditionModel):
266
+ unet = UNetMotionModel.from_unet2d(unet, motion_adapter)
267
+
268
+ if isinstance(controlnet, (list, tuple)):
269
+ controlnet = MultiControlNetModel(controlnet)
270
+
271
+ self.register_modules(
272
+ vae=vae,
273
+ text_encoder=text_encoder,
274
+ tokenizer=tokenizer,
275
+ unet=unet,
276
+ motion_adapter=motion_adapter,
277
+ controlnet=controlnet,
278
+ scheduler=scheduler,
279
+ feature_extractor=feature_extractor,
280
+ image_encoder=image_encoder,
281
+ )
282
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
283
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor)
284
+ self.control_video_processor = VideoProcessor(
285
+ vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
286
+ )
287
+
288
+ # Copied from diffusers.pipelines.animatediff.pipeline_animatediff_video2video.AnimateDiffVideoToVideoPipeline.encode_prompt
289
+ def encode_prompt(
290
+ self,
291
+ prompt,
292
+ device,
293
+ num_images_per_prompt,
294
+ do_classifier_free_guidance,
295
+ negative_prompt=None,
296
+ prompt_embeds: Optional[torch.Tensor] = None,
297
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
298
+ lora_scale: Optional[float] = None,
299
+ clip_skip: Optional[int] = None,
300
+ ):
301
+ r"""
302
+ Encodes the prompt into text encoder hidden states.
303
+
304
+ Args:
305
+ prompt (`str` or `List[str]`, *optional*):
306
+ prompt to be encoded
307
+ device: (`torch.device`):
308
+ torch device
309
+ num_images_per_prompt (`int`):
310
+ number of images that should be generated per prompt
311
+ do_classifier_free_guidance (`bool`):
312
+ whether to use classifier free guidance or not
313
+ negative_prompt (`str` or `List[str]`, *optional*):
314
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
315
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
316
+ less than `1`).
317
+ prompt_embeds (`torch.Tensor`, *optional*):
318
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
319
+ provided, text embeddings will be generated from `prompt` input argument.
320
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
321
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
322
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
323
+ argument.
324
+ lora_scale (`float`, *optional*):
325
+ A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
326
+ clip_skip (`int`, *optional*):
327
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
328
+ the output of the pre-final layer will be used for computing the prompt embeddings.
329
+ """
330
+ # set lora scale so that monkey patched LoRA
331
+ # function of text encoder can correctly access it
332
+ if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):
333
+ self._lora_scale = lora_scale
334
+
335
+ # dynamically adjust the LoRA scale
336
+ if not USE_PEFT_BACKEND:
337
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
338
+ else:
339
+ scale_lora_layers(self.text_encoder, lora_scale)
340
+
341
+ if prompt is not None and isinstance(prompt, (str, dict)):
342
+ batch_size = 1
343
+ elif prompt is not None and isinstance(prompt, list):
344
+ batch_size = len(prompt)
345
+ else:
346
+ batch_size = prompt_embeds.shape[0]
347
+
348
+ if prompt_embeds is None:
349
+ # textual inversion: process multi-vector tokens if necessary
350
+ if isinstance(self, TextualInversionLoaderMixin):
351
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
352
+
353
+ text_inputs = self.tokenizer(
354
+ prompt,
355
+ padding="max_length",
356
+ max_length=self.tokenizer.model_max_length,
357
+ truncation=True,
358
+ return_tensors="pt",
359
+ )
360
+ text_input_ids = text_inputs.input_ids
361
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
362
+
363
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
364
+ text_input_ids, untruncated_ids
365
+ ):
366
+ removed_text = self.tokenizer.batch_decode(
367
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
368
+ )
369
+ logger.warning(
370
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
371
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
372
+ )
373
+
374
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
375
+ attention_mask = text_inputs.attention_mask.to(device)
376
+ else:
377
+ attention_mask = None
378
+
379
+ if clip_skip is None:
380
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
381
+ prompt_embeds = prompt_embeds[0]
382
+ else:
383
+ prompt_embeds = self.text_encoder(
384
+ text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
385
+ )
386
+ # Access the `hidden_states` first, that contains a tuple of
387
+ # all the hidden states from the encoder layers. Then index into
388
+ # the tuple to access the hidden states from the desired layer.
389
+ prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
390
+ # We also need to apply the final LayerNorm here to not mess with the
391
+ # representations. The `last_hidden_states` that we typically use for
392
+ # obtaining the final prompt representations passes through the LayerNorm
393
+ # layer.
394
+ prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
395
+
396
+ if self.text_encoder is not None:
397
+ prompt_embeds_dtype = self.text_encoder.dtype
398
+ elif self.unet is not None:
399
+ prompt_embeds_dtype = self.unet.dtype
400
+ else:
401
+ prompt_embeds_dtype = prompt_embeds.dtype
402
+
403
+ prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
404
+
405
+ bs_embed, seq_len, _ = prompt_embeds.shape
406
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
407
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
408
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
409
+
410
+ # get unconditional embeddings for classifier free guidance
411
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
412
+ uncond_tokens: List[str]
413
+ if negative_prompt is None:
414
+ uncond_tokens = [""] * batch_size
415
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
416
+ raise TypeError(
417
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
418
+ f" {type(prompt)}."
419
+ )
420
+ elif isinstance(negative_prompt, str):
421
+ uncond_tokens = [negative_prompt]
422
+ elif batch_size != len(negative_prompt):
423
+ raise ValueError(
424
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
425
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
426
+ " the batch size of `prompt`."
427
+ )
428
+ else:
429
+ uncond_tokens = negative_prompt
430
+
431
+ # textual inversion: process multi-vector tokens if necessary
432
+ if isinstance(self, TextualInversionLoaderMixin):
433
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
434
+
435
+ max_length = prompt_embeds.shape[1]
436
+ uncond_input = self.tokenizer(
437
+ uncond_tokens,
438
+ padding="max_length",
439
+ max_length=max_length,
440
+ truncation=True,
441
+ return_tensors="pt",
442
+ )
443
+
444
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
445
+ attention_mask = uncond_input.attention_mask.to(device)
446
+ else:
447
+ attention_mask = None
448
+
449
+ negative_prompt_embeds = self.text_encoder(
450
+ uncond_input.input_ids.to(device),
451
+ attention_mask=attention_mask,
452
+ )
453
+ negative_prompt_embeds = negative_prompt_embeds[0]
454
+
455
+ if do_classifier_free_guidance:
456
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
457
+ seq_len = negative_prompt_embeds.shape[1]
458
+
459
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
460
+
461
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
462
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
463
+
464
+ if self.text_encoder is not None:
465
+ if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
466
+ # Retrieve the original scale by scaling back the LoRA layers
467
+ unscale_lora_layers(self.text_encoder, lora_scale)
468
+
469
+ return prompt_embeds, negative_prompt_embeds
470
+
471
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
472
+ def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
473
+ dtype = next(self.image_encoder.parameters()).dtype
474
+
475
+ if not isinstance(image, torch.Tensor):
476
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
477
+
478
+ image = image.to(device=device, dtype=dtype)
479
+ if output_hidden_states:
480
+ image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
481
+ image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
482
+ uncond_image_enc_hidden_states = self.image_encoder(
483
+ torch.zeros_like(image), output_hidden_states=True
484
+ ).hidden_states[-2]
485
+ uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
486
+ num_images_per_prompt, dim=0
487
+ )
488
+ return image_enc_hidden_states, uncond_image_enc_hidden_states
489
+ else:
490
+ image_embeds = self.image_encoder(image).image_embeds
491
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
492
+ uncond_image_embeds = torch.zeros_like(image_embeds)
493
+
494
+ return image_embeds, uncond_image_embeds
495
+
496
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
497
+ def prepare_ip_adapter_image_embeds(
498
+ self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
499
+ ):
500
+ image_embeds = []
501
+ if do_classifier_free_guidance:
502
+ negative_image_embeds = []
503
+ if ip_adapter_image_embeds is None:
504
+ if not isinstance(ip_adapter_image, list):
505
+ ip_adapter_image = [ip_adapter_image]
506
+
507
+ if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
508
+ raise ValueError(
509
+ f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
510
+ )
511
+
512
+ for single_ip_adapter_image, image_proj_layer in zip(
513
+ ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
514
+ ):
515
+ output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
516
+ single_image_embeds, single_negative_image_embeds = self.encode_image(
517
+ single_ip_adapter_image, device, 1, output_hidden_state
518
+ )
519
+
520
+ image_embeds.append(single_image_embeds[None, :])
521
+ if do_classifier_free_guidance:
522
+ negative_image_embeds.append(single_negative_image_embeds[None, :])
523
+ else:
524
+ for single_image_embeds in ip_adapter_image_embeds:
525
+ if do_classifier_free_guidance:
526
+ single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
527
+ negative_image_embeds.append(single_negative_image_embeds)
528
+ image_embeds.append(single_image_embeds)
529
+
530
+ ip_adapter_image_embeds = []
531
+ for i, single_image_embeds in enumerate(image_embeds):
532
+ single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
533
+ if do_classifier_free_guidance:
534
+ single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
535
+ single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)
536
+
537
+ single_image_embeds = single_image_embeds.to(device=device)
538
+ ip_adapter_image_embeds.append(single_image_embeds)
539
+
540
+ return ip_adapter_image_embeds
541
+
542
+ # Copied from diffusers.pipelines.animatediff.pipeline_animatediff_video2video.AnimateDiffVideoToVideoPipeline.encode_video
543
+ def encode_video(self, video, generator, decode_chunk_size: int = 16) -> torch.Tensor:
544
+ latents = []
545
+ for i in range(0, len(video), decode_chunk_size):
546
+ batch_video = video[i : i + decode_chunk_size]
547
+ batch_video = retrieve_latents(self.vae.encode(batch_video), generator=generator)
548
+ latents.append(batch_video)
549
+ return torch.cat(latents)
550
+
551
+ # Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffPipeline.decode_latents
552
+ def decode_latents(self, latents, decode_chunk_size: int = 16):
553
+ latents = 1 / self.vae.config.scaling_factor * latents
554
+
555
+ batch_size, channels, num_frames, height, width = latents.shape
556
+ latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)
557
+
558
+ video = []
559
+ for i in range(0, latents.shape[0], decode_chunk_size):
560
+ batch_latents = latents[i : i + decode_chunk_size]
561
+ batch_latents = self.vae.decode(batch_latents).sample
562
+ video.append(batch_latents)
563
+
564
+ video = torch.cat(video)
565
+ video = video[None, :].reshape((batch_size, num_frames, -1) + video.shape[2:]).permute(0, 2, 1, 3, 4)
566
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
567
+ video = video.float()
568
+ return video
569
+
570
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
571
+ def prepare_extra_step_kwargs(self, generator, eta):
572
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
573
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
574
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
575
+ # and should be between [0, 1]
576
+
577
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
578
+ extra_step_kwargs = {}
579
+ if accepts_eta:
580
+ extra_step_kwargs["eta"] = eta
581
+
582
+ # check if the scheduler accepts generator
583
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
584
+ if accepts_generator:
585
+ extra_step_kwargs["generator"] = generator
586
+ return extra_step_kwargs
587
+
588
+ def check_inputs(
589
+ self,
590
+ prompt,
591
+ strength,
592
+ height,
593
+ width,
594
+ video=None,
595
+ conditioning_frames=None,
596
+ latents=None,
597
+ negative_prompt=None,
598
+ prompt_embeds=None,
599
+ negative_prompt_embeds=None,
600
+ ip_adapter_image=None,
601
+ ip_adapter_image_embeds=None,
602
+ callback_on_step_end_tensor_inputs=None,
603
+ controlnet_conditioning_scale=1.0,
604
+ control_guidance_start=0.0,
605
+ control_guidance_end=1.0,
606
+ ):
607
+ if strength < 0 or strength > 1:
608
+ raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
609
+
610
+ if height % 8 != 0 or width % 8 != 0:
611
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
612
+
613
+ if callback_on_step_end_tensor_inputs is not None and not all(
614
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
615
+ ):
616
+ raise ValueError(
617
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
618
+ )
619
+
620
+ if prompt is not None and prompt_embeds is not None:
621
+ raise ValueError(
622
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
623
+ " only forward one of the two."
624
+ )
625
+ elif prompt is None and prompt_embeds is None:
626
+ raise ValueError(
627
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
628
+ )
629
+ elif prompt is not None and not isinstance(prompt, (str, list, dict)):
630
+ raise ValueError(f"`prompt` has to be of type `str`, `list` or `dict` but is {type(prompt)}")
631
+
632
+ if negative_prompt is not None and negative_prompt_embeds is not None:
633
+ raise ValueError(
634
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
635
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
636
+ )
637
+
638
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
639
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
640
+ raise ValueError(
641
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
642
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
643
+ f" {negative_prompt_embeds.shape}."
644
+ )
645
+
646
+ if video is not None and latents is not None:
647
+ raise ValueError("Only one of `video` or `latents` should be provided")
648
+
649
+ if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
650
+ raise ValueError(
651
+ "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
652
+ )
653
+
654
+ if ip_adapter_image_embeds is not None:
655
+ if not isinstance(ip_adapter_image_embeds, list):
656
+ raise ValueError(
657
+ f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
658
+ )
659
+ elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
660
+ raise ValueError(
661
+ f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
662
+ )
663
+
664
+ if isinstance(self.controlnet, MultiControlNetModel):
665
+ if isinstance(prompt, list):
666
+ logger.warning(
667
+ f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}"
668
+ " prompts. The conditionings will be fixed across the prompts."
669
+ )
670
+ is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
671
+ self.controlnet, torch._dynamo.eval_frame.OptimizedModule
672
+ )
673
+
674
+ num_frames = len(video) if latents is None else latents.shape[2]
675
+
676
+ if (
677
+ isinstance(self.controlnet, ControlNetModel)
678
+ or is_compiled
679
+ and isinstance(self.controlnet._orig_mod, ControlNetModel)
680
+ ):
681
+ if not isinstance(conditioning_frames, list):
682
+ raise TypeError(
683
+ f"For single controlnet, `image` must be of type `list` but got {type(conditioning_frames)}"
684
+ )
685
+ if len(conditioning_frames) != num_frames:
686
+ raise ValueError(f"Excepted image to have length {num_frames} but got {len(conditioning_frames)=}")
687
+ elif (
688
+ isinstance(self.controlnet, MultiControlNetModel)
689
+ or is_compiled
690
+ and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
691
+ ):
692
+ if not isinstance(conditioning_frames, list) or not isinstance(conditioning_frames[0], list):
693
+ raise TypeError(
694
+ f"For multiple controlnets: `image` must be type list of lists but got {type(conditioning_frames)=}"
695
+ )
696
+ if len(conditioning_frames[0]) != num_frames:
697
+ raise ValueError(
698
+ f"Expected length of image sublist as {num_frames} but got {len(conditioning_frames)=}"
699
+ )
700
+ if any(len(img) != len(conditioning_frames[0]) for img in conditioning_frames):
701
+ raise ValueError("All conditioning frame batches for multicontrolnet must be same size")
702
+ else:
703
+ assert False
704
+
705
+ # Check `controlnet_conditioning_scale`
706
+ if (
707
+ isinstance(self.controlnet, ControlNetModel)
708
+ or is_compiled
709
+ and isinstance(self.controlnet._orig_mod, ControlNetModel)
710
+ ):
711
+ if not isinstance(controlnet_conditioning_scale, float):
712
+ raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
713
+ elif (
714
+ isinstance(self.controlnet, MultiControlNetModel)
715
+ or is_compiled
716
+ and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
717
+ ):
718
+ if isinstance(controlnet_conditioning_scale, list):
719
+ if any(isinstance(i, list) for i in controlnet_conditioning_scale):
720
+ raise ValueError("A single batch of multiple conditionings are supported at the moment.")
721
+ elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(
722
+ self.controlnet.nets
723
+ ):
724
+ raise ValueError(
725
+ "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
726
+ " the same length as the number of controlnets"
727
+ )
728
+ else:
729
+ assert False
730
+
731
+ if not isinstance(control_guidance_start, (tuple, list)):
732
+ control_guidance_start = [control_guidance_start]
733
+
734
+ if not isinstance(control_guidance_end, (tuple, list)):
735
+ control_guidance_end = [control_guidance_end]
736
+
737
+ if len(control_guidance_start) != len(control_guidance_end):
738
+ raise ValueError(
739
+ f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
740
+ )
741
+
742
+ if isinstance(self.controlnet, MultiControlNetModel):
743
+ if len(control_guidance_start) != len(self.controlnet.nets):
744
+ raise ValueError(
745
+ f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}."
746
+ )
747
+
748
+ for start, end in zip(control_guidance_start, control_guidance_end):
749
+ if start >= end:
750
+ raise ValueError(
751
+ f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}."
752
+ )
753
+ if start < 0.0:
754
+ raise ValueError(f"control guidance start: {start} can't be smaller than 0.")
755
+ if end > 1.0:
756
+ raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")
757
+
758
def get_timesteps(self, num_inference_steps, timesteps, strength, device):
    """Trim the timestep schedule according to the denoising ``strength``.

    Only the tail of the schedule is actually run for video-to-video:
    ``strength`` controls what fraction of the steps denoise the input.

    Returns:
        A tuple ``(timesteps, num_inference_steps)`` with the remaining
        schedule and the effective number of inference steps.
    """
    # Steps that will actually denoise, capped at the full schedule length.
    denoising_steps = min(int(num_inference_steps * strength), num_inference_steps)
    # Steps skipped at the head of the schedule (scaled by scheduler order).
    skipped_steps = max(num_inference_steps - denoising_steps, 0)
    remaining = timesteps[skipped_steps * self.scheduler.order :]
    return remaining, num_inference_steps - skipped_steps
+
768
# Copied from diffusers.pipelines.animatediff.pipeline_animatediff_video2video.AnimateDiffVideoToVideoPipeline.prepare_latents
def prepare_latents(
    self,
    video: Optional[torch.Tensor] = None,
    height: int = 64,
    width: int = 64,
    num_channels_latents: int = 4,
    batch_size: int = 1,
    timestep: Optional[int] = None,
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.Tensor] = None,
    decode_chunk_size: int = 16,
    add_noise: bool = False,
) -> torch.Tensor:
    """Encode the input video into VAE latents and noise them to `timestep`.

    Either `video` (pixel frames, assumed [B, F, C, H, W] after the caller's
    permute — TODO confirm) or pre-made `latents` ([B, C, F, H, W]) is used.
    Returns a latent tensor of shape [B, C, F, H, W].
    """
    num_frames = video.shape[1] if latents is None else latents.shape[2]
    # Expected latent shape in [B, C, F, h, w] layout.
    shape = (
        batch_size,
        num_channels_latents,
        num_frames,
        height // self.vae_scale_factor,
        width // self.vae_scale_factor,
    )

    if isinstance(generator, list) and len(generator) != batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    if latents is None:
        # make sure the VAE is in float32 mode, as it overflows in float16
        if self.vae.config.force_upcast:
            video = video.float()
            self.vae.to(dtype=torch.float32)

        # One generator per batch item, or one shared generator for all videos.
        if isinstance(generator, list):
            init_latents = [
                self.encode_video(video[i], generator[i], decode_chunk_size).unsqueeze(0)
                for i in range(batch_size)
            ]
        else:
            init_latents = [self.encode_video(vid, generator, decode_chunk_size).unsqueeze(0) for vid in video]

        init_latents = torch.cat(init_latents, dim=0)

        # restore vae to original dtype
        if self.vae.config.force_upcast:
            self.vae.to(dtype)

        init_latents = init_latents.to(dtype)
        init_latents = self.vae.config.scaling_factor * init_latents

        if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
            # expand init_latents for batch_size
            # NOTE(review): despite the "expand" comment, this branch raises —
            # prompt/video batch duplication is intentionally unsupported here;
            # verify against upstream diffusers before "fixing".
            error_message = (
                f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
                " images (`image`). Please make sure to update your script to pass as many initial images as text prompts"
            )
            raise ValueError(error_message)
        elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
            raise ValueError(
                f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
            )
        else:
            init_latents = torch.cat([init_latents], dim=0)

        # Noise the clean latents to the starting timestep, then swap the
        # frame/channel axes into the [B, C, F, h, w] layout the UNet expects.
        noise = randn_tensor(init_latents.shape, generator=generator, device=device, dtype=dtype)
        latents = self.scheduler.add_noise(init_latents, noise, timestep).permute(0, 2, 1, 3, 4)
    else:
        if shape != latents.shape:
            # [B, C, F, H, W]
            raise ValueError(f"`latents` expected to have {shape=}, but found {latents.shape=}")

        latents = latents.to(device, dtype=dtype)

        # User-supplied latents are only noised when explicitly requested
        # (used by the `enforce_inference_steps` path in `__call__`).
        if add_noise:
            noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
            latents = self.scheduler.add_noise(latents, noise, timestep)

    return latents
851
def prepare_conditioning_frames(
    self,
    video,
    width,
    height,
    batch_size,
    num_videos_per_prompt,
    device,
    dtype,
    do_classifier_free_guidance=False,
    guess_mode=False,
):
    """Preprocess ControlNet conditioning frames into a flat image batch.

    Resizes/normalizes the frames, folds the frame axis into the batch axis,
    repeats to match the requested batch size, and doubles the batch for
    classifier-free guidance (unless guess mode is active).
    """
    frames = self.control_video_processor.preprocess_video(video, height=height, width=width)
    frames = frames.to(dtype=torch.float32)
    # [B, C, F, H, W] -> [B, F, C, H, W] -> [B * F, C, H, W]
    frames = frames.permute(0, 2, 1, 3, 4).flatten(0, 1)

    # A single conditioning image is broadcast across the whole batch;
    # otherwise the batch is assumed to match the prompt batch already.
    repeat_by = batch_size if frames.shape[0] == 1 else num_videos_per_prompt
    frames = frames.repeat_interleave(repeat_by, dim=0)
    frames = frames.to(device=device, dtype=dtype)

    # In guess mode the ControlNet only sees the conditional half, so no doubling.
    if do_classifier_free_guidance and not guess_mode:
        frames = torch.cat([frames, frames])

    return frames
883
+
884
@property
def guidance_scale(self):
    """Classifier-free guidance scale set by the current `__call__` invocation."""
    return self._guidance_scale

@property
def clip_skip(self):
    """Number of CLIP layers skipped when encoding prompts (set in `__call__`)."""
    return self._clip_skip

# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
    """Whether classifier-free guidance is active (guidance scale above 1)."""
    return self._guidance_scale > 1

@property
def cross_attention_kwargs(self):
    """Extra kwargs forwarded to the attention processors (set in `__call__`)."""
    return self._cross_attention_kwargs

@property
def num_timesteps(self):
    """Number of denoising timesteps used by the current `__call__` invocation."""
    return self._num_timesteps

@property
def interrupt(self):
    """Whether the denoising loop has been asked to stop early."""
    return self._interrupt
910
+
911
@torch.no_grad()
def __call__(
    self,
    video: List[List[PipelineImageInput]] = None,
    prompt: Optional[Union[str, List[str]]] = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 50,
    enforce_inference_steps: bool = False,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    guidance_scale: float = 7.5,
    strength: float = 0.8,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_videos_per_prompt: Optional[int] = 1,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.Tensor] = None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    ip_adapter_image: Optional[PipelineImageInput] = None,
    ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
    conditioning_frames: Optional[List[PipelineImageInput]] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
    guess_mode: bool = False,
    control_guidance_start: Union[float, List[float]] = 0.0,
    control_guidance_end: Union[float, List[float]] = 1.0,
    clip_skip: Optional[int] = None,
    callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
    callback_on_step_end_tensor_inputs: List[str] = ["latents"],
    decode_chunk_size: int = 16,
):
    r"""
    The call function to the pipeline for generation.

    Args:
        video (`List[PipelineImageInput]`):
            The input video to condition the generation on. Must be a list of images/frames of the video.
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
        height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
            The height in pixels of the generated video.
        width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
            The width in pixels of the generated video.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality videos at the
            expense of slower inference.
        timesteps (`List[int]`, *optional*):
            Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
            in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
            passed will be used. Must be in descending order.
        sigmas (`List[float]`, *optional*):
            Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
            their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
            will be used.
        strength (`float`, *optional*, defaults to 0.8):
            Higher strength leads to more differences between original video and generated video.
        guidance_scale (`float`, *optional*, defaults to 7.5):
            A higher guidance scale value encourages the model to generate images closely linked to the text
            `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide what to not include in image generation. If not defined, you need to
            pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only
            applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
            generation deterministic.
        latents (`torch.Tensor`, *optional*):
            Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video
            generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
            tensor is generated by sampling using the supplied random `generator`. Latents should be of shape
            `(batch_size, num_channel, num_frames, height, width)`.
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
            provided, text embeddings are generated from the `prompt` input argument.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
            not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
        ip_adapter_image: (`PipelineImageInput`, *optional*):
            Optional image input to work with IP Adapters.
        ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
            Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
            IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
            contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
            provided, embeddings are computed from the `ip_adapter_image` input argument.
        conditioning_frames (`List[PipelineImageInput]`, *optional*):
            The ControlNet input condition to provide guidance to the `unet` for generation. If multiple
            ControlNets are specified, images must be passed as a list such that each element of the list can be
            correctly batched for input to a single ControlNet.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated video. Choose between `torch.Tensor`, `PIL.Image` or `np.array`.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`AnimateDiffPipelineOutput`] instead of a plain tuple.
        cross_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
            [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
        controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
            The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
            to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set
            the corresponding scale as a list.
        guess_mode (`bool`, *optional*, defaults to `False`):
            The ControlNet encoder tries to recognize the content of the input image even if you remove all
            prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended.
        control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
            The percentage of total steps at which the ControlNet starts applying.
        control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
            The percentage of total steps at which the ControlNet stops applying.
        clip_skip (`int`, *optional*):
            Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
            the output of the pre-final layer will be used for computing the prompt embeddings.
        callback_on_step_end (`Callable`, *optional*):
            A function that calls at the end of each denoising steps during the inference. The function is called
            with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
            callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
            `callback_on_step_end_tensor_inputs`.
        callback_on_step_end_tensor_inputs (`List`, *optional*):
            The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
            will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
            `._callback_tensor_inputs` attribute of your pipeline class.
        decode_chunk_size (`int`, defaults to `16`):
            The number of frames to decode at a time when calling `decode_latents` method.

    Examples:

    Returns:
        [`pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] or `tuple`:
            If `return_dict` is `True`, [`pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] is
            returned, otherwise a `tuple` is returned where the first element is a list with the generated frames.
    """

    # Unwrap a torch.compile()-ed ControlNet so config/attribute access works.
    controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet

    # align format for control guidance: broadcast scalars so start/end are
    # always lists of equal length (one entry per ControlNet).
    if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
        control_guidance_start = len(control_guidance_end) * [control_guidance_start]
    elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
        control_guidance_end = len(control_guidance_start) * [control_guidance_end]
    elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
        mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
        control_guidance_start, control_guidance_end = (
            mult * [control_guidance_start],
            mult * [control_guidance_end],
        )

    # 0. Default height and width to unet
    height = height or self.unet.config.sample_size * self.vae_scale_factor
    width = width or self.unet.config.sample_size * self.vae_scale_factor

    num_videos_per_prompt = 1

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt=prompt,
        strength=strength,
        height=height,
        width=width,
        negative_prompt=negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        video=video,
        conditioning_frames=conditioning_frames,
        latents=latents,
        ip_adapter_image=ip_adapter_image,
        ip_adapter_image_embeds=ip_adapter_image_embeds,
        callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
        controlnet_conditioning_scale=controlnet_conditioning_scale,
        control_guidance_start=control_guidance_start,
        control_guidance_end=control_guidance_end,
    )

    self._guidance_scale = guidance_scale
    self._clip_skip = clip_skip
    self._cross_attention_kwargs = cross_attention_kwargs
    self._interrupt = False

    # 2. Define call parameters
    # (a dict prompt is a FreeNoise per-frame prompt mapping and counts as batch 1)
    if prompt is not None and isinstance(prompt, (str, dict)):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device
    dtype = self.dtype

    # 3. Prepare timesteps
    if not enforce_inference_steps:
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler, num_inference_steps, device, timesteps, sigmas
        )
        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, timesteps, strength, device)
        latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt)
    else:
        # Schedule over an inflated step count so that exactly
        # `num_inference_steps` tail steps are actually run at this strength.
        denoising_inference_steps = int(num_inference_steps / strength)
        timesteps, denoising_inference_steps = retrieve_timesteps(
            self.scheduler, denoising_inference_steps, device, timesteps, sigmas
        )
        timesteps = timesteps[-num_inference_steps:]
        latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt)

    # 4. Prepare latent variables
    if latents is None:
        video = self.video_processor.preprocess_video(video, height=height, width=width)
        # Move the number of frames before the number of channels.
        video = video.permute(0, 2, 1, 3, 4)
        video = video.to(device=device, dtype=dtype)

    num_channels_latents = self.unet.config.in_channels
    latents = self.prepare_latents(
        video=video,
        height=height,
        width=width,
        num_channels_latents=num_channels_latents,
        batch_size=batch_size * num_videos_per_prompt,
        timestep=latent_timestep,
        dtype=dtype,
        device=device,
        generator=generator,
        latents=latents,
        decode_chunk_size=decode_chunk_size,
        add_noise=enforce_inference_steps,
    )

    # 5. Encode input prompt
    text_encoder_lora_scale = (
        self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
    )
    num_frames = latents.shape[2]
    if self.free_noise_enabled:
        prompt_embeds, negative_prompt_embeds = self._encode_prompt_free_noise(
            prompt=prompt,
            num_frames=num_frames,
            device=device,
            num_videos_per_prompt=num_videos_per_prompt,
            do_classifier_free_guidance=self.do_classifier_free_guidance,
            negative_prompt=negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            lora_scale=text_encoder_lora_scale,
            clip_skip=self.clip_skip,
        )
    else:
        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
            prompt,
            device,
            num_videos_per_prompt,
            self.do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            lora_scale=text_encoder_lora_scale,
            clip_skip=self.clip_skip,
        )

        # For classifier free guidance, we need to do two forward passes.
        # Here we concatenate the unconditional and text embeddings into a single batch
        # to avoid doing two forward passes
        if self.do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

    # Duplicate the text embeddings per frame: the ControlNet below consumes a
    # flattened (batch * frames) image batch.
    prompt_embeds = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0)

    # 6. Prepare IP-Adapter embeddings
    if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
        image_embeds = self.prepare_ip_adapter_image_embeds(
            ip_adapter_image,
            ip_adapter_image_embeds,
            device,
            batch_size * num_videos_per_prompt,
            self.do_classifier_free_guidance,
        )

    # 7. Prepare ControlNet conditions
    if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
        controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)

    global_pool_conditions = (
        controlnet.config.global_pool_conditions
        if isinstance(controlnet, ControlNetModel)
        else controlnet.nets[0].config.global_pool_conditions
    )
    guess_mode = guess_mode or global_pool_conditions

    # Per-step keep mask: 1.0 while the step falls inside [start, end), else 0.0.
    controlnet_keep = []
    for i in range(len(timesteps)):
        keeps = [
            1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
            for s, e in zip(control_guidance_start, control_guidance_end)
        ]
        controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)

    if isinstance(controlnet, ControlNetModel):
        conditioning_frames = self.prepare_conditioning_frames(
            video=conditioning_frames,
            width=width,
            height=height,
            batch_size=batch_size * num_videos_per_prompt * num_frames,
            num_videos_per_prompt=num_videos_per_prompt,
            device=device,
            dtype=controlnet.dtype,
            do_classifier_free_guidance=self.do_classifier_free_guidance,
            guess_mode=guess_mode,
        )
    elif isinstance(controlnet, MultiControlNetModel):
        cond_prepared_videos = []
        for frame_ in conditioning_frames:
            prepared_video = self.prepare_conditioning_frames(
                video=frame_,
                width=width,
                height=height,
                batch_size=batch_size * num_videos_per_prompt * num_frames,
                num_videos_per_prompt=num_videos_per_prompt,
                device=device,
                dtype=controlnet.dtype,
                do_classifier_free_guidance=self.do_classifier_free_guidance,
                guess_mode=guess_mode,
            )
            cond_prepared_videos.append(prepared_video)
        conditioning_frames = cond_prepared_videos
    else:
        # Unreachable: check_inputs already rejected other controlnet types.
        assert False

    # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # 9. Add image embeds for IP-Adapter
    added_cond_kwargs = (
        {"image_embeds": image_embeds}
        if ip_adapter_image is not None or ip_adapter_image_embeds is not None
        else None
    )

    num_free_init_iters = self._free_init_num_iters if self.free_init_enabled else 1
    for free_init_iter in range(num_free_init_iters):
        if self.free_init_enabled:
            latents, timesteps = self._apply_free_init(
                latents, free_init_iter, num_inference_steps, device, latents.dtype, generator
            )
            num_inference_steps = len(timesteps)
            # make sure to readjust timesteps based on strength
            timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, timesteps, strength, device)

        self._num_timesteps = len(timesteps)
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order

        # 10. Denoising loop
        with self.progress_bar(total=self._num_timesteps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                if guess_mode and self.do_classifier_free_guidance:
                    # Infer ControlNet only for the conditional batch.
                    control_model_input = latents
                    control_model_input = self.scheduler.scale_model_input(control_model_input, t)
                    controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
                else:
                    control_model_input = latent_model_input
                    controlnet_prompt_embeds = prompt_embeds

                if isinstance(controlnet_keep[i], list):
                    cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
                else:
                    controlnet_cond_scale = controlnet_conditioning_scale
                    if isinstance(controlnet_cond_scale, list):
                        controlnet_cond_scale = controlnet_cond_scale[0]
                    cond_scale = controlnet_cond_scale * controlnet_keep[i]

                # Fold the frame axis into the batch axis for the 2D ControlNet:
                # [B, C, F, H, W] -> [B, F, C, H, W] -> [B * F, C, H, W]
                control_model_input = torch.transpose(control_model_input, 1, 2)
                control_model_input = control_model_input.reshape(
                    (-1, control_model_input.shape[2], control_model_input.shape[3], control_model_input.shape[4])
                )

                down_block_res_samples, mid_block_res_sample = self.controlnet(
                    control_model_input,
                    t,
                    encoder_hidden_states=controlnet_prompt_embeds,
                    controlnet_cond=conditioning_frames,
                    conditioning_scale=cond_scale,
                    guess_mode=guess_mode,
                    return_dict=False,
                )

                # predict the noise residual
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    cross_attention_kwargs=self.cross_attention_kwargs,
                    added_cond_kwargs=added_cond_kwargs,
                    down_block_additional_residuals=down_block_res_samples,
                    mid_block_additional_residual=mid_block_res_sample,
                ).sample

                # perform guidance
                if self.do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

                if XLA_AVAILABLE:
                    xm.mark_step()

    # 11. Post-processing
    if output_type == "latent":
        video = latents
    else:
        video_tensor = self.decode_latents(latents, decode_chunk_size)
        video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type)

    # 12. Offload all models
    self.maybe_free_model_hooks()

    if not return_dict:
        return (video,)

    return AnimateDiffPipelineOutput(frames=video)
pythonProject/diffusers-main/build/lib/diffusers/pipelines/animatediff/pipeline_output.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import List, Union
3
+
4
+ import numpy as np
5
+ import PIL.Image
6
+ import torch
7
+
8
+ from ...utils import BaseOutput
9
+
10
+
11
@dataclass
class AnimateDiffPipelineOutput(BaseOutput):
    r"""
    Output class for AnimateDiff pipelines.

    Args:
        frames (`torch.Tensor`, `np.ndarray`, or `List[List[PIL.Image.Image]]`):
            Video outputs. Either a nested list of length `batch_size`, each sub-list holding a denoised PIL image
            sequence of length `num_frames`, or a NumPy array / Torch tensor of shape
            `(batch_size, num_frames, channels, height, width)`.
    """

    # Denoised video frames; the concrete container type depends on which of the
    # documented forms (nested PIL lists, ndarray, or tensor) the pipeline produced.
    frames: Union[torch.Tensor, np.ndarray, List[List[PIL.Image.Image]]]
pythonProject/diffusers-main/build/lib/diffusers/pipelines/audioldm/__init__.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from typing import TYPE_CHECKING

from ...utils import (
    DIFFUSERS_SLOW_IMPORT,
    OptionalDependencyNotAvailable,
    _LazyModule,
    is_torch_available,
    is_transformers_available,
    is_transformers_version,
)


# Fallback (dummy) objects to register when optional dependencies are missing.
_dummy_objects = {}
# Mapping of submodule name -> list of public names; consumed by _LazyModule below.
_import_structure = {}

try:
    # AudioLDM requires both torch and transformers >= 4.27.0.
    if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.27.0")):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    # Dependencies missing: expose a dummy object that raises a helpful error on use.
    from ...utils.dummy_torch_and_transformers_objects import (
        AudioLDMPipeline,
    )

    _dummy_objects.update({"AudioLDMPipeline": AudioLDMPipeline})
else:
    _import_structure["pipeline_audioldm"] = ["AudioLDMPipeline"]


if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    # Eager-import branch: used by static type checkers or when lazy imports are disabled.
    try:
        if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.27.0")):
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        from ...utils.dummy_torch_and_transformers_objects import (
            AudioLDMPipeline,
        )

    else:
        from .pipeline_audioldm import AudioLDMPipeline
else:
    import sys

    # Replace this module object with a lazy proxy that imports submodules on
    # first attribute access.
    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )

    # Attach any dummy objects so attribute access still resolves without the deps.
    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)
pythonProject/diffusers-main/build/lib/diffusers/pipelines/audioldm/pipeline_audioldm.py ADDED
@@ -0,0 +1,558 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ from typing import Any, Callable, Dict, List, Optional, Union
17
+
18
+ import numpy as np
19
+ import torch
20
+ import torch.nn.functional as F
21
+ from transformers import ClapTextModelWithProjection, RobertaTokenizer, RobertaTokenizerFast, SpeechT5HifiGan
22
+
23
+ from ...models import AutoencoderKL, UNet2DConditionModel
24
+ from ...schedulers import KarrasDiffusionSchedulers
25
+ from ...utils import is_torch_xla_available, logging, replace_example_docstring
26
+ from ...utils.torch_utils import randn_tensor
27
+ from ..pipeline_utils import AudioPipelineOutput, DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin
28
+
29
+
30
+ if is_torch_xla_available():
31
+ import torch_xla.core.xla_model as xm
32
+
33
+ XLA_AVAILABLE = True
34
+ else:
35
+ XLA_AVAILABLE = False
36
+
37
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
38
+
39
+
40
+ EXAMPLE_DOC_STRING = """
41
+ Examples:
42
+ ```py
43
+ >>> from diffusers import AudioLDMPipeline
44
+ >>> import torch
45
+ >>> import scipy
46
+
47
+ >>> repo_id = "cvssp/audioldm-s-full-v2"
48
+ >>> pipe = AudioLDMPipeline.from_pretrained(repo_id, torch_dtype=torch.float16)
49
+ >>> pipe = pipe.to("cuda")
50
+
51
+ >>> prompt = "Techno music with a strong, upbeat tempo and high melodic riffs"
52
+ >>> audio = pipe(prompt, num_inference_steps=10, audio_length_in_s=5.0).audios[0]
53
+
54
+ >>> # save the audio sample as a .wav file
55
+ >>> scipy.io.wavfile.write("techno.wav", rate=16000, data=audio)
56
+ ```
57
+ """
58
+
59
+
60
class AudioLDMPipeline(DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin):
    r"""
    Pipeline for text-to-audio generation using AudioLDM.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
    implemented for all pipelines (downloading, saving, running on a particular device, etc.).

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
        text_encoder ([`~transformers.ClapTextModelWithProjection`]):
            Frozen text-encoder (`ClapTextModelWithProjection`, specifically the
            [laion/clap-htsat-unfused](https://huggingface.co/laion/clap-htsat-unfused) variant).
        tokenizer ([`PreTrainedTokenizer`]):
            A [`~transformers.RobertaTokenizer`] to tokenize text.
        unet ([`UNet2DConditionModel`]):
            A `UNet2DConditionModel` to denoise the encoded audio latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded audio latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
        vocoder ([`~transformers.SpeechT5HifiGan`]):
            Vocoder of class `SpeechT5HifiGan`.
    """

    # NOTE(review): presumably read by DeprecatedPipelineMixin to report the last
    # diffusers release this pipeline is supported in — confirm against the mixin.
    _last_supported_version = "0.33.1"
    # Order in which sub-models are shuttled between CPU and accelerator when
    # model CPU offloading is enabled.
    model_cpu_offload_seq = "text_encoder->unet->vae"
86
+
87
    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: ClapTextModelWithProjection,
        tokenizer: Union[RobertaTokenizer, RobertaTokenizerFast],
        unet: UNet2DConditionModel,
        scheduler: KarrasDiffusionSchedulers,
        vocoder: SpeechT5HifiGan,
    ):
        """Register all sub-models and derive the VAE spatial downsampling factor."""
        super().__init__()

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            vocoder=vocoder,
        )
        # Each VAE down block halves the resolution, hence 2**(n_blocks - 1);
        # falls back to 8 when no VAE is registered.
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
107
+
108
    def _encode_prompt(
        self,
        prompt,
        device,
        num_waveforms_per_prompt,
        do_classifier_free_guidance,
        negative_prompt=None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            device (`torch.device`):
                torch device
            num_waveforms_per_prompt (`int`):
                number of waveforms that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the audio generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
        """
        # Derive the batch size from the prompt if given, otherwise from the embeds.
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            attention_mask = text_inputs.attention_mask
            # Re-tokenize without truncation so we can warn about anything we dropped.
            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
                removed_text = self.tokenizer.batch_decode(
                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
                )
                logger.warning(
                    "The following part of your input was truncated because CLAP can only handle sequences up to"
                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
                )

            prompt_embeds = self.text_encoder(
                text_input_ids.to(device),
                attention_mask=attention_mask.to(device),
            )
            # NOTE(review): `text_embeds` appears to be CLAP's pooled projection, so
            # the second dim unpacked as `seq_len` below is the embedding dim — confirm.
            prompt_embeds = prompt_embeds.text_embeds
            # additional L_2 normalization over each hidden-state
            prompt_embeds = F.normalize(prompt_embeds, dim=-1)

        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

        (
            bs_embed,
            seq_len,
        ) = prompt_embeds.shape
        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_waveforms_per_prompt)
        prompt_embeds = prompt_embeds.view(bs_embed * num_waveforms_per_prompt, seq_len)

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            max_length = prompt_embeds.shape[1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )

            uncond_input_ids = uncond_input.input_ids.to(device)
            attention_mask = uncond_input.attention_mask.to(device)

            negative_prompt_embeds = self.text_encoder(
                uncond_input_ids,
                attention_mask=attention_mask,
            )
            negative_prompt_embeds = negative_prompt_embeds.text_embeds
            # additional L_2 normalization over each hidden-state
            negative_prompt_embeds = F.normalize(negative_prompt_embeds, dim=-1)

        if do_classifier_free_guidance:
            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = negative_prompt_embeds.shape[1]

            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_waveforms_per_prompt)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_waveforms_per_prompt, seq_len)

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

        return prompt_embeds
246
+
247
+ def decode_latents(self, latents):
248
+ latents = 1 / self.vae.config.scaling_factor * latents
249
+ mel_spectrogram = self.vae.decode(latents).sample
250
+ return mel_spectrogram
251
+
252
+ def mel_spectrogram_to_waveform(self, mel_spectrogram):
253
+ if mel_spectrogram.dim() == 4:
254
+ mel_spectrogram = mel_spectrogram.squeeze(1)
255
+
256
+ waveform = self.vocoder(mel_spectrogram)
257
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
258
+ waveform = waveform.cpu().float()
259
+ return waveform
260
+
261
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
262
+ def prepare_extra_step_kwargs(self, generator, eta):
263
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
264
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
265
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
266
+ # and should be between [0, 1]
267
+
268
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
269
+ extra_step_kwargs = {}
270
+ if accepts_eta:
271
+ extra_step_kwargs["eta"] = eta
272
+
273
+ # check if the scheduler accepts generator
274
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
275
+ if accepts_generator:
276
+ extra_step_kwargs["generator"] = generator
277
+ return extra_step_kwargs
278
+
279
    def check_inputs(
        self,
        prompt,
        audio_length_in_s,
        vocoder_upsample_factor,
        callback_steps,
        negative_prompt=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
    ):
        """Validate `__call__` arguments, raising `ValueError` on any inconsistent input."""
        # The shortest generatable clip is one latent row's worth of vocoder output.
        min_audio_length_in_s = vocoder_upsample_factor * self.vae_scale_factor
        if audio_length_in_s < min_audio_length_in_s:
            raise ValueError(
                f"`audio_length_in_s` has to be a positive value greater than or equal to {min_audio_length_in_s}, but "
                f"is {audio_length_in_s}."
            )

        # The mel frequency axis is downsampled by the VAE, so it must divide evenly.
        if self.vocoder.config.model_in_dim % self.vae_scale_factor != 0:
            raise ValueError(
                f"The number of frequency bins in the vocoder's log-mel spectrogram has to be divisible by the "
                f"VAE scale factor, but got {self.vocoder.config.model_in_dim} bins and a scale factor of "
                f"{self.vae_scale_factor}."
            )

        if (callback_steps is None) or (
            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
        ):
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

        # Exactly one of `prompt` / `prompt_embeds` must be supplied.
        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if prompt_embeds.shape != negative_prompt_embeds.shape:
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}."
                )
336
+
337
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents with width->self.vocoder.config.model_in_dim
    def prepare_latents(self, batch_size, num_channels_latents, height, dtype, device, generator, latents=None):
        """Create initial noise latents (or move user-provided ones to `device`) and
        scale them by the scheduler's initial noise sigma.

        Latent shape: (batch, channels, height // vae_scale_factor,
        model_in_dim // vae_scale_factor) — the mel width comes from the vocoder config.
        """
        shape = (
            batch_size,
            num_channels_latents,
            int(height) // self.vae_scale_factor,
            int(self.vocoder.config.model_in_dim) // self.vae_scale_factor,
        )
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            latents = latents.to(device)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma
        return latents
359
+
360
    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        audio_length_in_s: Optional[float] = None,
        num_inference_steps: int = 10,
        guidance_scale: float = 2.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_waveforms_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
        callback_steps: Optional[int] = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        output_type: Optional[str] = "np",
    ):
        r"""
        The call function to the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide audio generation. If not defined, you need to pass `prompt_embeds`.
            audio_length_in_s (`int`, *optional*, defaults to 5.12):
                The length of the generated audio sample in seconds.
            num_inference_steps (`int`, *optional*, defaults to 10):
                The number of denoising steps. More denoising steps usually lead to a higher quality audio at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 2.5):
                A higher guidance scale value encourages the model to generate audio that is closely linked to the text
                `prompt` at the expense of lower sound quality. Guidance scale is enabled when `guidance_scale > 1`.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide what to not include in audio generation. If not defined, you need to
                pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
            num_waveforms_per_prompt (`int`, *optional*, defaults to 1):
                The number of waveforms to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only
                applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
            latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor is generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
                provided, text embeddings are generated from the `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
                not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.AudioPipelineOutput`] instead of a plain tuple.
            callback (`Callable`, *optional*):
                A function that calls every `callback_steps` steps during inference. The function is called with the
                following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function is called. If not specified, the callback is called at
                every step.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            output_type (`str`, *optional*, defaults to `"np"`):
                The output format of the generated image. Choose between `"np"` to return a NumPy `np.ndarray` or
                `"pt"` to return a PyTorch `torch.Tensor` object.

        Examples:

        Returns:
            [`~pipelines.AudioPipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.AudioPipelineOutput`] is returned, otherwise a `tuple` is
                returned where the first element is a list with the generated audio.
        """
        # 0. Convert audio input length from seconds to spectrogram height
        # (seconds per mel frame = product of HiFi-GAN upsample rates / sampling rate).
        vocoder_upsample_factor = np.prod(self.vocoder.config.upsample_rates) / self.vocoder.config.sampling_rate

        if audio_length_in_s is None:
            audio_length_in_s = self.unet.config.sample_size * self.vae_scale_factor * vocoder_upsample_factor

        height = int(audio_length_in_s / vocoder_upsample_factor)

        original_waveform_length = int(audio_length_in_s * self.vocoder.config.sampling_rate)
        # Round the spectrogram height up to a multiple of the VAE scale factor;
        # the waveform is trimmed back to the requested length after decoding.
        if height % self.vae_scale_factor != 0:
            height = int(np.ceil(height / self.vae_scale_factor)) * self.vae_scale_factor
            logger.info(
                f"Audio length in seconds {audio_length_in_s} is increased to {height * vocoder_upsample_factor} "
                f"so that it can be handled by the model. It will be cut to {audio_length_in_s} after the "
                f"denoising process."
            )

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            audio_length_in_s,
            vocoder_upsample_factor,
            callback_steps,
            negative_prompt,
            prompt_embeds,
            negative_prompt_embeds,
        )

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        prompt_embeds = self._encode_prompt(
            prompt,
            device,
            num_waveforms_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
        )

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 5. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_waveforms_per_prompt,
            num_channels_latents,
            height,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 6. Prepare extra step kwargs
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # predict the noise residual
                # AudioLDM conditions the UNet through `class_labels` (the pooled CLAP
                # embedding) rather than `encoder_hidden_states`.
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=None,
                    class_labels=prompt_embeds,
                    cross_attention_kwargs=cross_attention_kwargs,
                ).sample

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        step_idx = i // getattr(self.scheduler, "order", 1)
                        callback(step_idx, t, latents)

                if XLA_AVAILABLE:
                    # Force XLA to materialize the graph once per step.
                    xm.mark_step()

        # 8. Post-processing
        mel_spectrogram = self.decode_latents(latents)

        audio = self.mel_spectrogram_to_waveform(mel_spectrogram)

        # Trim the padding introduced by rounding `height` up to the VAE scale factor.
        audio = audio[:, :original_waveform_length]

        if output_type == "np":
            audio = audio.numpy()

        if not return_dict:
            return (audio,)

        return AudioPipelineOutput(audios=audio)
pythonProject/diffusers-main/build/lib/diffusers/pipelines/audioldm2/__init__.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from typing import TYPE_CHECKING

from ...utils import (
    DIFFUSERS_SLOW_IMPORT,
    OptionalDependencyNotAvailable,
    _LazyModule,
    get_objects_from_module,
    is_torch_available,
    is_transformers_available,
    is_transformers_version,
)


# Fallback (dummy) objects to register when optional dependencies are missing.
_dummy_objects = {}
# Mapping of submodule name -> list of public names; consumed by _LazyModule below.
_import_structure = {}

try:
    # AudioLDM2 requires both torch and transformers >= 4.27.0.
    if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.27.0")):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ...utils import dummy_torch_and_transformers_objects

    # Register every dummy from the module so all public names stay importable.
    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
    _import_structure["modeling_audioldm2"] = ["AudioLDM2ProjectionModel", "AudioLDM2UNet2DConditionModel"]
    _import_structure["pipeline_audioldm2"] = ["AudioLDM2Pipeline"]


if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    # Eager-import branch: used by static type checkers or when lazy imports are disabled.
    try:
        if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.27.0")):
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        from ...utils.dummy_torch_and_transformers_objects import *

    else:
        from .modeling_audioldm2 import AudioLDM2ProjectionModel, AudioLDM2UNet2DConditionModel
        from .pipeline_audioldm2 import AudioLDM2Pipeline

else:
    import sys

    # Replace this module object with a lazy proxy that imports submodules on
    # first attribute access.
    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )
    # Attach any dummy objects so attribute access still resolves without the deps.
    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)
pythonProject/diffusers-main/build/lib/diffusers/pipelines/audioldm2/modeling_audioldm2.py ADDED
@@ -0,0 +1,1475 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from dataclasses import dataclass
16
+ from typing import Any, Dict, List, Optional, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.utils.checkpoint
21
+
22
+ from ...configuration_utils import ConfigMixin, register_to_config
23
+ from ...loaders import UNet2DConditionLoadersMixin
24
+ from ...models.activations import get_activation
25
+ from ...models.attention_processor import (
26
+ ADDED_KV_ATTENTION_PROCESSORS,
27
+ CROSS_ATTENTION_PROCESSORS,
28
+ AttentionProcessor,
29
+ AttnAddedKVProcessor,
30
+ AttnProcessor,
31
+ )
32
+ from ...models.embeddings import (
33
+ TimestepEmbedding,
34
+ Timesteps,
35
+ )
36
+ from ...models.modeling_utils import ModelMixin
37
+ from ...models.resnet import Downsample2D, ResnetBlock2D, Upsample2D
38
+ from ...models.transformers.transformer_2d import Transformer2DModel
39
+ from ...models.unets.unet_2d_blocks import DownBlock2D, UpBlock2D
40
+ from ...models.unets.unet_2d_condition import UNet2DConditionOutput
41
+ from ...utils import BaseOutput, logging
42
+
43
+
44
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
45
+
46
+
47
def add_special_tokens(hidden_states, attention_mask, sos_token, eos_token):
    """
    Wrap a batch of embedding sequences with learned start/end-of-sequence embeddings.

    Args:
        hidden_states (`torch.Tensor` of shape `(batch_size, seq_len, hidden_size)`):
            The embedding sequences to wrap.
        attention_mask (`torch.Tensor` of shape `(batch_size, seq_len)`, *optional*):
            Padding mask for `hidden_states`. If given, it is extended with a column of
            ones on each side, since the inserted tokens are always attended to.
        sos_token (`torch.Tensor` of shape `(hidden_size,)`):
            Learned embedding prepended to every sequence.
        eos_token (`torch.Tensor` of shape `(hidden_size,)`):
            Learned embedding appended to every sequence.

    Returns:
        `Tuple[torch.Tensor, Optional[torch.Tensor]]`: the wrapped hidden states of shape
        `(batch_size, seq_len + 2, hidden_size)` and the (possibly `None`) extended mask.
    """
    num_sequences = hidden_states.shape[0]

    if attention_mask is not None:
        # The SOS/EOS positions are real tokens, so mark them as attended (1).
        ones_column = attention_mask.new_ones((num_sequences, 1))
        attention_mask = torch.concat([ones_column, attention_mask, ones_column], dim=-1)

    # Broadcast the single learned embeddings across the batch and stitch them
    # onto either end of each sequence.
    expanded_sos = sos_token.expand(num_sequences, 1, -1)
    expanded_eos = eos_token.expand(num_sequences, 1, -1)
    hidden_states = torch.concat([expanded_sos, hidden_states, expanded_eos], dim=1)
    return hidden_states, attention_mask
60
+
61
+
62
@dataclass
class AudioLDM2ProjectionModelOutput(BaseOutput):
    """
    Class for AudioLDM2 projection layer's outputs.

    Args:
        hidden_states (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states obtained by linearly projecting the hidden-states for each of the text
            encoders and subsequently concatenating them together.
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices, formed by concatenating the attention masks
            for the two text encoders together. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
    """

    hidden_states: torch.Tensor
    attention_mask: Optional[torch.LongTensor] = None
80
+
81
+
82
class AudioLDM2ProjectionModel(ModelMixin, ConfigMixin):
    """
    A simple linear projection model to map two text embeddings to a shared latent space. It also inserts learned
    embedding vectors at the start and end of each text embedding sequence respectively. Each variable appended with
    `_1` refers to that corresponding to the second text encoder. Otherwise, it is from the first.

    Args:
        text_encoder_dim (`int`):
            Dimensionality of the text embeddings from the first text encoder (CLAP).
        text_encoder_1_dim (`int`):
            Dimensionality of the text embeddings from the second text encoder (T5 or VITS).
        langauge_model_dim (`int`):
            Dimensionality of the text embeddings from the language model (GPT2). (NOTE: the parameter name
            carries a historical typo; it is a registered config key, so it is kept for checkpoint compatibility.)
        use_learned_position_embedding (*optional*):
            If not `None`, a learned positional embedding is added to the second encoder's hidden states
            (used for the VITS encoder).
        max_seq_length (`int`, *optional*):
            Maximum sequence length of the second encoder; required when
            `use_learned_position_embedding` is set.
    """

    @register_to_config
    def __init__(
        self,
        text_encoder_dim,
        text_encoder_1_dim,
        langauge_model_dim,
        use_learned_position_embedding=None,
        max_seq_length=None,
    ):
        super().__init__()
        # additional projection layers for each text encoder
        self.projection = nn.Linear(text_encoder_dim, langauge_model_dim)
        self.projection_1 = nn.Linear(text_encoder_1_dim, langauge_model_dim)

        # learnable SOS / EOS token embeddings for each text encoder
        self.sos_embed = nn.Parameter(torch.ones(langauge_model_dim))
        self.eos_embed = nn.Parameter(torch.ones(langauge_model_dim))

        self.sos_embed_1 = nn.Parameter(torch.ones(langauge_model_dim))
        self.eos_embed_1 = nn.Parameter(torch.ones(langauge_model_dim))

        self.use_learned_position_embedding = use_learned_position_embedding

        # learnable positional embedding for the VITS encoder
        if self.use_learned_position_embedding is not None:
            self.learnable_positional_embedding = torch.nn.Parameter(
                torch.zeros((1, text_encoder_1_dim, max_seq_length))
            )

    def forward(
        self,
        hidden_states: Optional[torch.Tensor] = None,
        hidden_states_1: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        attention_mask_1: Optional[torch.LongTensor] = None,
    ):
        """
        Project both encoders' hidden states into the language-model space, wrap each with its learned
        SOS/EOS embeddings, and concatenate sequences (and attention masks) along the sequence dimension.

        Returns:
            [`AudioLDM2ProjectionModelOutput`]: the concatenated hidden states and (possibly `None`) mask.
        """
        hidden_states = self.projection(hidden_states)
        hidden_states, attention_mask = add_special_tokens(
            hidden_states, attention_mask, sos_token=self.sos_embed, eos_token=self.eos_embed
        )

        # Add positional embedding for VITS hidden state
        if self.use_learned_position_embedding is not None:
            hidden_states_1 = (hidden_states_1.permute(0, 2, 1) + self.learnable_positional_embedding).permute(0, 2, 1)

        hidden_states_1 = self.projection_1(hidden_states_1)
        hidden_states_1, attention_mask_1 = add_special_tokens(
            hidden_states_1, attention_mask_1, sos_token=self.sos_embed_1, eos_token=self.eos_embed_1
        )

        # If only one attention mask was provided, treat the other sequence as fully
        # attended (all ones). This must use the *pre-concatenation* shapes, and must
        # pass integer sizes to `new_ones` (the previous code passed a sliced tensor,
        # `hidden_states[:2]`, which raises a TypeError).
        if attention_mask is None and attention_mask_1 is not None:
            attention_mask = attention_mask_1.new_ones(hidden_states.shape[:2])
        elif attention_mask is not None and attention_mask_1 is None:
            attention_mask_1 = attention_mask.new_ones(hidden_states_1.shape[:2])

        # concatenate attention masks
        if attention_mask is not None and attention_mask_1 is not None:
            attention_mask = torch.cat([attention_mask, attention_mask_1], dim=-1)
        else:
            attention_mask = None

        # concatenate clap and t5 text encoding
        hidden_states = torch.cat([hidden_states, hidden_states_1], dim=1)

        return AudioLDM2ProjectionModelOutput(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
        )
165
+
166
+
167
+ class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
168
+ r"""
169
+ A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
170
+ shaped output. Compared to the vanilla [`UNet2DConditionModel`], this variant optionally includes an additional
171
+ self-attention layer in each Transformer block, as well as multiple cross-attention layers. It also allows for up
172
+ to two cross-attention embeddings, `encoder_hidden_states` and `encoder_hidden_states_1`.
173
+
174
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
175
+ for all models (such as downloading or saving).
176
+
177
+ Parameters:
178
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
179
+ Height and width of input/output sample.
180
+ in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
181
+ out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
182
+ flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
183
+ Whether to flip the sin to cos in the time embedding.
184
+ freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
185
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
186
+ The tuple of downsample blocks to use.
187
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
188
+ Block type for middle of UNet, it can only be `UNetMidBlock2DCrossAttn` for AudioLDM2.
189
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
190
+ The tuple of upsample blocks to use.
191
+ only_cross_attention (`bool` or `Tuple[bool]`, *optional*, default to `False`):
192
+ Whether to include self-attention in the basic transformer blocks, see
193
+ [`~models.attention.BasicTransformerBlock`].
194
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
195
+ The tuple of output channels for each block.
196
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
197
+ downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
198
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
199
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
200
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
201
+ If `None`, normalization and activation layers is skipped in post-processing.
202
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
203
+ cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
204
+ The dimension of the cross attention features.
205
+ transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
206
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
207
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
208
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
209
+ attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
210
+ num_attention_heads (`int`, *optional*):
211
+ The number of attention heads. If not defined, defaults to `attention_head_dim`
212
+ resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
213
+ for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
214
+ class_embed_type (`str`, *optional*, defaults to `None`):
215
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
216
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
217
+ num_class_embeds (`int`, *optional*, defaults to `None`):
218
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
219
+ class conditioning with `class_embed_type` equal to `None`.
220
+ time_embedding_type (`str`, *optional*, defaults to `positional`):
221
+ The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
222
+ time_embedding_dim (`int`, *optional*, defaults to `None`):
223
+ An optional override for the dimension of the projected time embedding.
224
+ time_embedding_act_fn (`str`, *optional*, defaults to `None`):
225
+ Optional activation function to use only once on the time embeddings before they are passed to the rest of
226
+ the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
227
+ timestep_post_act (`str`, *optional*, defaults to `None`):
228
+ The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
229
+ time_cond_proj_dim (`int`, *optional*, defaults to `None`):
230
+ The dimension of `cond_proj` layer in the timestep embedding.
231
+ conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.
232
+ conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
233
+ projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
234
+ `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
235
+ class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
236
+ embeddings with the class embeddings.
237
+ """
238
+
239
+ _supports_gradient_checkpointing = True
240
+
241
+ @register_to_config
242
+ def __init__(
243
+ self,
244
+ sample_size: Optional[int] = None,
245
+ in_channels: int = 4,
246
+ out_channels: int = 4,
247
+ flip_sin_to_cos: bool = True,
248
+ freq_shift: int = 0,
249
+ down_block_types: Tuple[str] = (
250
+ "CrossAttnDownBlock2D",
251
+ "CrossAttnDownBlock2D",
252
+ "CrossAttnDownBlock2D",
253
+ "DownBlock2D",
254
+ ),
255
+ mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
256
+ up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
257
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
258
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
259
+ layers_per_block: Union[int, Tuple[int]] = 2,
260
+ downsample_padding: int = 1,
261
+ mid_block_scale_factor: float = 1,
262
+ act_fn: str = "silu",
263
+ norm_num_groups: Optional[int] = 32,
264
+ norm_eps: float = 1e-5,
265
+ cross_attention_dim: Union[int, Tuple[int]] = 1280,
266
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
267
+ attention_head_dim: Union[int, Tuple[int]] = 8,
268
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
269
+ use_linear_projection: bool = False,
270
+ class_embed_type: Optional[str] = None,
271
+ num_class_embeds: Optional[int] = None,
272
+ upcast_attention: bool = False,
273
+ resnet_time_scale_shift: str = "default",
274
+ time_embedding_type: str = "positional",
275
+ time_embedding_dim: Optional[int] = None,
276
+ time_embedding_act_fn: Optional[str] = None,
277
+ timestep_post_act: Optional[str] = None,
278
+ time_cond_proj_dim: Optional[int] = None,
279
+ conv_in_kernel: int = 3,
280
+ conv_out_kernel: int = 3,
281
+ projection_class_embeddings_input_dim: Optional[int] = None,
282
+ class_embeddings_concat: bool = False,
283
+ ):
284
+ super().__init__()
285
+
286
+ self.sample_size = sample_size
287
+
288
+ if num_attention_heads is not None:
289
+ raise ValueError(
290
+ "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
291
+ )
292
+
293
+ # If `num_attention_heads` is not defined (which is the case for most models)
294
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
295
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
296
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
297
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
298
+ # which is why we correct for the naming here.
299
+ num_attention_heads = num_attention_heads or attention_head_dim
300
+
301
+ # Check inputs
302
+ if len(down_block_types) != len(up_block_types):
303
+ raise ValueError(
304
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
305
+ )
306
+
307
+ if len(block_out_channels) != len(down_block_types):
308
+ raise ValueError(
309
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
310
+ )
311
+
312
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
313
+ raise ValueError(
314
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
315
+ )
316
+
317
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
318
+ raise ValueError(
319
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
320
+ )
321
+
322
+ if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
323
+ raise ValueError(
324
+ f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
325
+ )
326
+
327
+ if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
328
+ raise ValueError(
329
+ f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
330
+ )
331
+
332
+ if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
333
+ raise ValueError(
334
+ f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
335
+ )
336
+
337
+ # input
338
+ conv_in_padding = (conv_in_kernel - 1) // 2
339
+ self.conv_in = nn.Conv2d(
340
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
341
+ )
342
+
343
+ # time
344
+ if time_embedding_type == "positional":
345
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
346
+
347
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
348
+ timestep_input_dim = block_out_channels[0]
349
+ else:
350
+ raise ValueError(f"{time_embedding_type} does not exist. Please make sure to use `positional`.")
351
+
352
+ self.time_embedding = TimestepEmbedding(
353
+ timestep_input_dim,
354
+ time_embed_dim,
355
+ act_fn=act_fn,
356
+ post_act_fn=timestep_post_act,
357
+ cond_proj_dim=time_cond_proj_dim,
358
+ )
359
+
360
+ # class embedding
361
+ if class_embed_type is None and num_class_embeds is not None:
362
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
363
+ elif class_embed_type == "timestep":
364
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
365
+ elif class_embed_type == "identity":
366
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
367
+ elif class_embed_type == "projection":
368
+ if projection_class_embeddings_input_dim is None:
369
+ raise ValueError(
370
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
371
+ )
372
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
373
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
374
+ # 2. it projects from an arbitrary input dimension.
375
+ #
376
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
377
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
378
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
379
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
380
+ elif class_embed_type == "simple_projection":
381
+ if projection_class_embeddings_input_dim is None:
382
+ raise ValueError(
383
+ "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
384
+ )
385
+ self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
386
+ else:
387
+ self.class_embedding = None
388
+
389
+ if time_embedding_act_fn is None:
390
+ self.time_embed_act = None
391
+ else:
392
+ self.time_embed_act = get_activation(time_embedding_act_fn)
393
+
394
+ self.down_blocks = nn.ModuleList([])
395
+ self.up_blocks = nn.ModuleList([])
396
+
397
+ if isinstance(only_cross_attention, bool):
398
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
399
+
400
+ if isinstance(num_attention_heads, int):
401
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
402
+
403
+ if isinstance(cross_attention_dim, int):
404
+ cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
405
+
406
+ if isinstance(layers_per_block, int):
407
+ layers_per_block = [layers_per_block] * len(down_block_types)
408
+
409
+ if isinstance(transformer_layers_per_block, int):
410
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
411
+
412
+ if class_embeddings_concat:
413
+ # The time embeddings are concatenated with the class embeddings. The dimension of the
414
+ # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
415
+ # regular time embeddings
416
+ blocks_time_embed_dim = time_embed_dim * 2
417
+ else:
418
+ blocks_time_embed_dim = time_embed_dim
419
+
420
+ # down
421
+ output_channel = block_out_channels[0]
422
+ for i, down_block_type in enumerate(down_block_types):
423
+ input_channel = output_channel
424
+ output_channel = block_out_channels[i]
425
+ is_final_block = i == len(block_out_channels) - 1
426
+
427
+ down_block = get_down_block(
428
+ down_block_type,
429
+ num_layers=layers_per_block[i],
430
+ transformer_layers_per_block=transformer_layers_per_block[i],
431
+ in_channels=input_channel,
432
+ out_channels=output_channel,
433
+ temb_channels=blocks_time_embed_dim,
434
+ add_downsample=not is_final_block,
435
+ resnet_eps=norm_eps,
436
+ resnet_act_fn=act_fn,
437
+ resnet_groups=norm_num_groups,
438
+ cross_attention_dim=cross_attention_dim[i],
439
+ num_attention_heads=num_attention_heads[i],
440
+ downsample_padding=downsample_padding,
441
+ use_linear_projection=use_linear_projection,
442
+ only_cross_attention=only_cross_attention[i],
443
+ upcast_attention=upcast_attention,
444
+ resnet_time_scale_shift=resnet_time_scale_shift,
445
+ )
446
+ self.down_blocks.append(down_block)
447
+
448
+ # mid
449
+ if mid_block_type == "UNetMidBlock2DCrossAttn":
450
+ self.mid_block = UNetMidBlock2DCrossAttn(
451
+ transformer_layers_per_block=transformer_layers_per_block[-1],
452
+ in_channels=block_out_channels[-1],
453
+ temb_channels=blocks_time_embed_dim,
454
+ resnet_eps=norm_eps,
455
+ resnet_act_fn=act_fn,
456
+ output_scale_factor=mid_block_scale_factor,
457
+ resnet_time_scale_shift=resnet_time_scale_shift,
458
+ cross_attention_dim=cross_attention_dim[-1],
459
+ num_attention_heads=num_attention_heads[-1],
460
+ resnet_groups=norm_num_groups,
461
+ use_linear_projection=use_linear_projection,
462
+ upcast_attention=upcast_attention,
463
+ )
464
+ else:
465
+ raise ValueError(
466
+ f"unknown mid_block_type : {mid_block_type}. Should be `UNetMidBlock2DCrossAttn` for AudioLDM2."
467
+ )
468
+
469
+ # count how many layers upsample the images
470
+ self.num_upsamplers = 0
471
+
472
+ # up
473
+ reversed_block_out_channels = list(reversed(block_out_channels))
474
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
475
+ reversed_layers_per_block = list(reversed(layers_per_block))
476
+ reversed_cross_attention_dim = list(reversed(cross_attention_dim))
477
+ reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
478
+ only_cross_attention = list(reversed(only_cross_attention))
479
+
480
+ output_channel = reversed_block_out_channels[0]
481
+ for i, up_block_type in enumerate(up_block_types):
482
+ is_final_block = i == len(block_out_channels) - 1
483
+
484
+ prev_output_channel = output_channel
485
+ output_channel = reversed_block_out_channels[i]
486
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
487
+
488
+ # add upsample block for all BUT final layer
489
+ if not is_final_block:
490
+ add_upsample = True
491
+ self.num_upsamplers += 1
492
+ else:
493
+ add_upsample = False
494
+
495
+ up_block = get_up_block(
496
+ up_block_type,
497
+ num_layers=reversed_layers_per_block[i] + 1,
498
+ transformer_layers_per_block=reversed_transformer_layers_per_block[i],
499
+ in_channels=input_channel,
500
+ out_channels=output_channel,
501
+ prev_output_channel=prev_output_channel,
502
+ temb_channels=blocks_time_embed_dim,
503
+ add_upsample=add_upsample,
504
+ resnet_eps=norm_eps,
505
+ resnet_act_fn=act_fn,
506
+ resnet_groups=norm_num_groups,
507
+ cross_attention_dim=reversed_cross_attention_dim[i],
508
+ num_attention_heads=reversed_num_attention_heads[i],
509
+ use_linear_projection=use_linear_projection,
510
+ only_cross_attention=only_cross_attention[i],
511
+ upcast_attention=upcast_attention,
512
+ resnet_time_scale_shift=resnet_time_scale_shift,
513
+ )
514
+ self.up_blocks.append(up_block)
515
+ prev_output_channel = output_channel
516
+
517
+ # out
518
+ if norm_num_groups is not None:
519
+ self.conv_norm_out = nn.GroupNorm(
520
+ num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
521
+ )
522
+
523
+ self.conv_act = get_activation(act_fn)
524
+
525
+ else:
526
+ self.conv_norm_out = None
527
+ self.conv_act = None
528
+
529
+ conv_out_padding = (conv_out_kernel - 1) // 2
530
+ self.conv_out = nn.Conv2d(
531
+ block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
532
+ )
533
+
534
+ @property
535
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
536
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
537
+ r"""
538
+ Returns:
539
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
540
+ indexed by its weight name.
541
+ """
542
+ # set recursively
543
+ processors = {}
544
+
545
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
546
+ if hasattr(module, "get_processor"):
547
+ processors[f"{name}.processor"] = module.get_processor()
548
+
549
+ for sub_name, child in module.named_children():
550
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
551
+
552
+ return processors
553
+
554
+ for name, module in self.named_children():
555
+ fn_recursive_add_processors(name, module, processors)
556
+
557
+ return processors
558
+
559
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
560
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
561
+ r"""
562
+ Sets the attention processor to use to compute attention.
563
+
564
+ Parameters:
565
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
566
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
567
+ for **all** `Attention` layers.
568
+
569
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
570
+ processor. This is strongly recommended when setting trainable attention processors.
571
+
572
+ """
573
+ count = len(self.attn_processors.keys())
574
+
575
+ if isinstance(processor, dict) and len(processor) != count:
576
+ raise ValueError(
577
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
578
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
579
+ )
580
+
581
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
582
+ if hasattr(module, "set_processor"):
583
+ if not isinstance(processor, dict):
584
+ module.set_processor(processor)
585
+ else:
586
+ module.set_processor(processor.pop(f"{name}.processor"))
587
+
588
+ for sub_name, child in module.named_children():
589
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
590
+
591
+ for name, module in self.named_children():
592
+ fn_recursive_attn_processor(name, module, processor)
593
+
594
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
595
+ def set_default_attn_processor(self):
596
+ """
597
+ Disables custom attention processors and sets the default attention implementation.
598
+ """
599
+ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
600
+ processor = AttnAddedKVProcessor()
601
+ elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
602
+ processor = AttnProcessor()
603
+ else:
604
+ raise ValueError(
605
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
606
+ )
607
+
608
+ self.set_attn_processor(processor)
609
+
610
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attention_slice
611
+ def set_attention_slice(self, slice_size):
612
+ r"""
613
+ Enable sliced attention computation.
614
+
615
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
616
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
617
+
618
+ Args:
619
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
620
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
621
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
622
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
623
+ must be a multiple of `slice_size`.
624
+ """
625
+ sliceable_head_dims = []
626
+
627
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
628
+ if hasattr(module, "set_attention_slice"):
629
+ sliceable_head_dims.append(module.sliceable_head_dim)
630
+
631
+ for child in module.children():
632
+ fn_recursive_retrieve_sliceable_dims(child)
633
+
634
+ # retrieve number of attention layers
635
+ for module in self.children():
636
+ fn_recursive_retrieve_sliceable_dims(module)
637
+
638
+ num_sliceable_layers = len(sliceable_head_dims)
639
+
640
+ if slice_size == "auto":
641
+ # half the attention head size is usually a good trade-off between
642
+ # speed and memory
643
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
644
+ elif slice_size == "max":
645
+ # make smallest slice possible
646
+ slice_size = num_sliceable_layers * [1]
647
+
648
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
649
+
650
+ if len(slice_size) != len(sliceable_head_dims):
651
+ raise ValueError(
652
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
653
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
654
+ )
655
+
656
+ for i in range(len(slice_size)):
657
+ size = slice_size[i]
658
+ dim = sliceable_head_dims[i]
659
+ if size is not None and size > dim:
660
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
661
+
662
+ # Recursively walk through all the children.
663
+ # Any children which exposes the set_attention_slice method
664
+ # gets the message
665
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
666
+ if hasattr(module, "set_attention_slice"):
667
+ module.set_attention_slice(slice_size.pop())
668
+
669
+ for child in module.children():
670
+ fn_recursive_set_attention_slice(child, slice_size)
671
+
672
+ reversed_slice_size = list(reversed(slice_size))
673
+ for module in self.children():
674
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
675
+
676
    def forward(
        self,
        sample: torch.Tensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        class_labels: Optional[torch.Tensor] = None,
        timestep_cond: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        return_dict: bool = True,
        encoder_hidden_states_1: Optional[torch.Tensor] = None,
        encoder_attention_mask_1: Optional[torch.Tensor] = None,
    ) -> Union[UNet2DConditionOutput, Tuple]:
        r"""
        The [`AudioLDM2UNet2DConditionModel`] forward method.

        Args:
            sample (`torch.Tensor`):
                The noisy input tensor with the following shape `(batch, channel, height, width)`.
            timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input.
            encoder_hidden_states (`torch.Tensor`):
                The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
            encoder_attention_mask (`torch.Tensor`):
                A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
                `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
                which adds large negative values to the attention scores corresponding to "discard" tokens.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
                tuple.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
            encoder_hidden_states_1 (`torch.Tensor`, *optional*):
                A second set of encoder hidden states with shape `(batch, sequence_length_2, feature_dim_2)`. Can be
                used to condition the model on a different set of embeddings to `encoder_hidden_states`.
            encoder_attention_mask_1 (`torch.Tensor`, *optional*):
                A cross-attention mask of shape `(batch, sequence_length_2)` is applied to `encoder_hidden_states_1`.
                If `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
                which adds large negative values to the attention scores corresponding to "discard" tokens.

        Returns:
            [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
                If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned,
                otherwise a `tuple` is returned where the first element is the sample tensor.
        """
        # By default samples have to be AT least a multiple of the overall upsampling factor.
        # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
        # However, the upsampling interpolation output size can be forced to fit any upsampling size
        # on the fly if necessary.
        default_overall_up_factor = 2**self.num_upsamplers

        # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
        forward_upsample_size = False
        upsample_size = None

        if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
            logger.info("Forward upsample size to force interpolation output size.")
            forward_upsample_size = True

        # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
        # expects mask of shape:
        #   [batch, key_tokens]
        # adds singleton query_tokens dimension:
        #   [batch, 1, key_tokens]
        # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
        #   [batch,  heads, query_tokens, key_tokens] (e.g. torch sdp attn)
        #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
        if attention_mask is not None:
            # assume that mask is expressed as:
            #   (1 = keep,      0 = discard)
            # convert mask into a bias that can be added to attention scores:
            #       (keep = +0,     discard = -10000.0)
            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        # convert encoder_attention_mask to a bias the same way we do for attention_mask
        if encoder_attention_mask is not None:
            encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

        # the secondary cross-attention mask gets the same (1 -> 0, 0 -> -10000) bias treatment
        if encoder_attention_mask_1 is not None:
            encoder_attention_mask_1 = (1 - encoder_attention_mask_1.to(sample.dtype)) * -10000.0
            encoder_attention_mask_1 = encoder_attention_mask_1.unsqueeze(1)

        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
            # This would be a good case for the `match` statement (Python 3.10+)
            is_mps = sample.device.type == "mps"
            is_npu = sample.device.type == "npu"
            # mps/npu do not support float64/int64, so fall back to the 32-bit dtypes there
            if isinstance(timestep, float):
                dtype = torch.float32 if (is_mps or is_npu) else torch.float64
            else:
                dtype = torch.int32 if (is_mps or is_npu) else torch.int64
            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps.expand(sample.shape[0])

        t_emb = self.time_proj(timesteps)

        # `Timesteps` does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=sample.dtype)

        emb = self.time_embedding(t_emb, timestep_cond)
        # NOTE(review): `aug_emb` is never reassigned in this method, so the
        # `emb = emb + aug_emb` branch below is a no-op kept for structural
        # parity with UNet2DConditionModel.
        aug_emb = None

        if self.class_embedding is not None:
            if class_labels is None:
                raise ValueError("class_labels should be provided when num_class_embeds > 0")

            if self.config.class_embed_type == "timestep":
                class_labels = self.time_proj(class_labels)

                # `Timesteps` does not contain any weights and will always return f32 tensors
                # there might be better ways to encapsulate this.
                class_labels = class_labels.to(dtype=sample.dtype)

            class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)

            # class conditioning is either concatenated to or summed into the time embedding
            if self.config.class_embeddings_concat:
                emb = torch.cat([emb, class_emb], dim=-1)
            else:
                emb = emb + class_emb

        emb = emb + aug_emb if aug_emb is not None else emb

        if self.time_embed_act is not None:
            emb = self.time_embed_act(emb)

        # 2. pre-process
        sample = self.conv_in(sample)

        # 3. down
        down_block_res_samples = (sample,)
        for downsample_block in self.down_blocks:
            if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
                sample, res_samples = downsample_block(
                    hidden_states=sample,
                    temb=emb,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=attention_mask,
                    cross_attention_kwargs=cross_attention_kwargs,
                    encoder_attention_mask=encoder_attention_mask,
                    encoder_hidden_states_1=encoder_hidden_states_1,
                    encoder_attention_mask_1=encoder_attention_mask_1,
                )
            else:
                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)

            # accumulate skip connections for the up path
            down_block_res_samples += res_samples

        # 4. mid
        if self.mid_block is not None:
            sample = self.mid_block(
                sample,
                emb,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=attention_mask,
                cross_attention_kwargs=cross_attention_kwargs,
                encoder_attention_mask=encoder_attention_mask,
                encoder_hidden_states_1=encoder_hidden_states_1,
                encoder_attention_mask_1=encoder_attention_mask_1,
            )

        # 5. up
        for i, upsample_block in enumerate(self.up_blocks):
            is_final_block = i == len(self.up_blocks) - 1

            # consume the skip connections in reverse order
            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

            # if we have not reached the final block and need to forward the
            # upsample size, we do it here
            if not is_final_block and forward_upsample_size:
                upsample_size = down_block_res_samples[-1].shape[2:]

            if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    upsample_size=upsample_size,
                    attention_mask=attention_mask,
                    encoder_attention_mask=encoder_attention_mask,
                    encoder_hidden_states_1=encoder_hidden_states_1,
                    encoder_attention_mask_1=encoder_attention_mask_1,
                )
            else:
                sample = upsample_block(
                    hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
                )

        # 6. post-process
        if self.conv_norm_out:
            sample = self.conv_norm_out(sample)
            sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        if not return_dict:
            return (sample,)

        return UNet2DConditionOutput(sample=sample)
886
+
887
+
888
def get_down_block(
    down_block_type,
    num_layers,
    in_channels,
    out_channels,
    temb_channels,
    add_downsample,
    resnet_eps,
    resnet_act_fn,
    transformer_layers_per_block=1,
    num_attention_heads=None,
    resnet_groups=None,
    cross_attention_dim=None,
    downsample_padding=None,
    use_linear_projection=False,
    only_cross_attention=False,
    upcast_attention=False,
    resnet_time_scale_shift="default",
):
    """Factory for the down blocks of the AudioLDM2 UNet.

    Resolves `down_block_type` (with an optional legacy "UNetRes" prefix) to a
    `DownBlock2D` or `CrossAttnDownBlock2D` instance and raises `ValueError`
    for any other type.
    """
    # Accept legacy names such as "UNetResDownBlock2D" by stripping the prefix.
    block_name = down_block_type
    if block_name.startswith("UNetRes"):
        block_name = block_name[len("UNetRes") :]

    if block_name == "DownBlock2D":
        return DownBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            add_downsample=add_downsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            downsample_padding=downsample_padding,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )

    if block_name == "CrossAttnDownBlock2D":
        # cross-attention needs a conditioning dimension to build the transformers
        if cross_attention_dim is None:
            raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D")
        return CrossAttnDownBlock2D(
            num_layers=num_layers,
            transformer_layers_per_block=transformer_layers_per_block,
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            add_downsample=add_downsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            downsample_padding=downsample_padding,
            cross_attention_dim=cross_attention_dim,
            num_attention_heads=num_attention_heads,
            use_linear_projection=use_linear_projection,
            only_cross_attention=only_cross_attention,
            upcast_attention=upcast_attention,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )

    raise ValueError(f"{block_name} does not exist.")
943
+
944
+
945
def get_up_block(
    up_block_type,
    num_layers,
    in_channels,
    out_channels,
    prev_output_channel,
    temb_channels,
    add_upsample,
    resnet_eps,
    resnet_act_fn,
    transformer_layers_per_block=1,
    num_attention_heads=None,
    resnet_groups=None,
    cross_attention_dim=None,
    use_linear_projection=False,
    only_cross_attention=False,
    upcast_attention=False,
    resnet_time_scale_shift="default",
):
    """Factory for the up blocks of the AudioLDM2 UNet.

    Resolves `up_block_type` (with an optional legacy "UNetRes" prefix) to an
    `UpBlock2D` or `CrossAttnUpBlock2D` instance and raises `ValueError` for
    any other type.
    """
    # Accept legacy names such as "UNetResUpBlock2D" by stripping the prefix.
    block_name = up_block_type
    if block_name.startswith("UNetRes"):
        block_name = block_name[len("UNetRes") :]

    if block_name == "UpBlock2D":
        return UpBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            prev_output_channel=prev_output_channel,
            temb_channels=temb_channels,
            add_upsample=add_upsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )

    if block_name == "CrossAttnUpBlock2D":
        # cross-attention needs a conditioning dimension to build the transformers
        if cross_attention_dim is None:
            raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D")
        return CrossAttnUpBlock2D(
            num_layers=num_layers,
            transformer_layers_per_block=transformer_layers_per_block,
            in_channels=in_channels,
            out_channels=out_channels,
            prev_output_channel=prev_output_channel,
            temb_channels=temb_channels,
            add_upsample=add_upsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            cross_attention_dim=cross_attention_dim,
            num_attention_heads=num_attention_heads,
            use_linear_projection=use_linear_projection,
            only_cross_attention=only_cross_attention,
            upcast_attention=upcast_attention,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )

    raise ValueError(f"{block_name} does not exist.")
1000
+
1001
+
1002
class CrossAttnDownBlock2D(nn.Module):
    """Down block with resnets plus up to four transformers per resnet.

    Each resnet is followed by `len(cross_attention_dim)` `Transformer2DModel`s
    stored flat in `self.attentions`; a `None` entry in `cross_attention_dim`
    yields a self-attention-only transformer (`double_self_attention=True`).
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        transformer_layers_per_block: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        num_attention_heads=1,
        cross_attention_dim=1280,
        output_scale_factor=1.0,
        downsample_padding=1,
        add_downsample=True,
        use_linear_projection=False,
        only_cross_attention=False,
        upcast_attention=False,
    ):
        super().__init__()
        resnets = []
        attentions = []

        self.has_cross_attention = True
        self.num_attention_heads = num_attention_heads

        # normalize a scalar conditioning dim to a 1-tuple so the loops below
        # can always iterate over it
        if isinstance(cross_attention_dim, int):
            cross_attention_dim = (cross_attention_dim,)
        if isinstance(cross_attention_dim, (list, tuple)) and len(cross_attention_dim) > 4:
            raise ValueError(
                "Only up to 4 cross-attention layers are supported. Ensure that the length of cross-attention "
                f"dims is less than or equal to 4. Got cross-attention dims {cross_attention_dim} of length {len(cross_attention_dim)}"
            )
        self.cross_attention_dim = cross_attention_dim

        for i in range(num_layers):
            # only the first resnet maps in_channels -> out_channels
            in_channels = in_channels if i == 0 else out_channels
            resnets.append(
                ResnetBlock2D(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )
            # one transformer per conditioning dim, appended flat; forward()
            # recovers the per-resnet grouping via index arithmetic
            for j in range(len(cross_attention_dim)):
                attentions.append(
                    Transformer2DModel(
                        num_attention_heads,
                        out_channels // num_attention_heads,
                        in_channels=out_channels,
                        num_layers=transformer_layers_per_block,
                        cross_attention_dim=cross_attention_dim[j],
                        norm_num_groups=resnet_groups,
                        use_linear_projection=use_linear_projection,
                        only_cross_attention=only_cross_attention,
                        upcast_attention=upcast_attention,
                        double_self_attention=True if cross_attention_dim[j] is None else False,
                    )
                )
        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)

        if add_downsample:
            self.downsamplers = nn.ModuleList(
                [
                    Downsample2D(
                        out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
                    )
                ]
            )
        else:
            self.downsamplers = None

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        temb: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states_1: Optional[torch.Tensor] = None,
        encoder_attention_mask_1: Optional[torch.Tensor] = None,
    ):
        """Run resnets and their per-resnet transformers, then downsample.

        Returns `(hidden_states, output_states)` where `output_states` holds
        the residuals collected after each resnet+attention stage (plus the
        downsampled output, if any) for the up path's skip connections.
        """
        output_states = ()
        num_layers = len(self.resnets)
        # attentions are stored flat: this many transformers per resnet
        num_attention_per_layer = len(self.attentions) // num_layers

        # fall back to the primary conditioning when the secondary one is absent
        encoder_hidden_states_1 = (
            encoder_hidden_states_1 if encoder_hidden_states_1 is not None else encoder_hidden_states
        )
        # NOTE(review): this condition tests `encoder_hidden_states_1`, which the
        # line above has just made non-None whenever `encoder_hidden_states` is
        # set — so the fallback to `encoder_attention_mask` can then never
        # trigger. Possibly intended to test `encoder_attention_mask_1`; confirm
        # against upstream before changing.
        encoder_attention_mask_1 = (
            encoder_attention_mask_1 if encoder_hidden_states_1 is not None else encoder_attention_mask
        )

        for i in range(num_layers):
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                # checkpointed path: recompute activations in backward to save memory
                hidden_states = self._gradient_checkpointing_func(self.resnets[i], hidden_states, temb)
                for idx, cross_attention_dim in enumerate(self.cross_attention_dim):
                    # routing: transformers 0-1 attend to the primary conditioning,
                    # transformers 2+ to the secondary; a None dim means pure
                    # self-attention (no encoder states)
                    if cross_attention_dim is not None and idx <= 1:
                        forward_encoder_hidden_states = encoder_hidden_states
                        forward_encoder_attention_mask = encoder_attention_mask
                    elif cross_attention_dim is not None and idx > 1:
                        forward_encoder_hidden_states = encoder_hidden_states_1
                        forward_encoder_attention_mask = encoder_attention_mask_1
                    else:
                        forward_encoder_hidden_states = None
                        forward_encoder_attention_mask = None
                    hidden_states = self._gradient_checkpointing_func(
                        self.attentions[i * num_attention_per_layer + idx],
                        hidden_states,
                        forward_encoder_hidden_states,
                        None,  # timestep
                        None,  # class_labels
                        cross_attention_kwargs,
                        attention_mask,
                        forward_encoder_attention_mask,
                    )[0]
            else:
                hidden_states = self.resnets[i](hidden_states, temb)
                for idx, cross_attention_dim in enumerate(self.cross_attention_dim):
                    # same routing as the checkpointed branch above
                    if cross_attention_dim is not None and idx <= 1:
                        forward_encoder_hidden_states = encoder_hidden_states
                        forward_encoder_attention_mask = encoder_attention_mask
                    elif cross_attention_dim is not None and idx > 1:
                        forward_encoder_hidden_states = encoder_hidden_states_1
                        forward_encoder_attention_mask = encoder_attention_mask_1
                    else:
                        forward_encoder_hidden_states = None
                        forward_encoder_attention_mask = None
                    hidden_states = self.attentions[i * num_attention_per_layer + idx](
                        hidden_states,
                        attention_mask=attention_mask,
                        encoder_hidden_states=forward_encoder_hidden_states,
                        encoder_attention_mask=forward_encoder_attention_mask,
                        return_dict=False,
                    )[0]

            output_states = output_states + (hidden_states,)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)

            output_states = output_states + (hidden_states,)

        return hidden_states, output_states
1162
+
1163
+
1164
class UNetMidBlock2DCrossAttn(nn.Module):
    """Mid block: `num_layers + 1` resnets interleaved with transformers.

    Between each consecutive resnet pair sit `len(cross_attention_dim)`
    `Transformer2DModel`s (flat in `self.attentions`); a `None` entry in
    `cross_attention_dim` yields a self-attention-only transformer.
    """

    def __init__(
        self,
        in_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        transformer_layers_per_block: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        num_attention_heads=1,
        output_scale_factor=1.0,
        cross_attention_dim=1280,
        use_linear_projection=False,
        upcast_attention=False,
    ):
        super().__init__()

        self.has_cross_attention = True
        self.num_attention_heads = num_attention_heads
        resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)

        # normalize a scalar conditioning dim to a 1-tuple so the loops below
        # can always iterate over it
        if isinstance(cross_attention_dim, int):
            cross_attention_dim = (cross_attention_dim,)
        if isinstance(cross_attention_dim, (list, tuple)) and len(cross_attention_dim) > 4:
            raise ValueError(
                "Only up to 4 cross-attention layers are supported. Ensure that the length of cross-attention "
                f"dims is less than or equal to 4. Got cross-attention dims {cross_attention_dim} of length {len(cross_attention_dim)}"
            )
        self.cross_attention_dim = cross_attention_dim

        # there is always at least one resnet
        resnets = [
            ResnetBlock2D(
                in_channels=in_channels,
                out_channels=in_channels,
                temb_channels=temb_channels,
                eps=resnet_eps,
                groups=resnet_groups,
                dropout=dropout,
                time_embedding_norm=resnet_time_scale_shift,
                non_linearity=resnet_act_fn,
                output_scale_factor=output_scale_factor,
                pre_norm=resnet_pre_norm,
            )
        ]
        attentions = []

        for i in range(num_layers):
            # one transformer per conditioning dim, appended flat; forward()
            # recovers the per-layer grouping via index arithmetic
            for j in range(len(cross_attention_dim)):
                attentions.append(
                    Transformer2DModel(
                        num_attention_heads,
                        in_channels // num_attention_heads,
                        in_channels=in_channels,
                        num_layers=transformer_layers_per_block,
                        cross_attention_dim=cross_attention_dim[j],
                        norm_num_groups=resnet_groups,
                        use_linear_projection=use_linear_projection,
                        upcast_attention=upcast_attention,
                        double_self_attention=True if cross_attention_dim[j] is None else False,
                    )
                )
            resnets.append(
                ResnetBlock2D(
                    in_channels=in_channels,
                    out_channels=in_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )

        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        temb: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states_1: Optional[torch.Tensor] = None,
        encoder_attention_mask_1: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Apply the first resnet, then alternate attentions and resnets."""
        hidden_states = self.resnets[0](hidden_states, temb)
        # attentions are stored flat: this many transformers between each
        # consecutive pair of resnets
        num_attention_per_layer = len(self.attentions) // (len(self.resnets) - 1)

        # fall back to the primary conditioning when the secondary one is absent
        encoder_hidden_states_1 = (
            encoder_hidden_states_1 if encoder_hidden_states_1 is not None else encoder_hidden_states
        )
        # NOTE(review): this condition tests `encoder_hidden_states_1`, which the
        # line above has just made non-None whenever `encoder_hidden_states` is
        # set — so the fallback to `encoder_attention_mask` can then never
        # trigger. Possibly intended to test `encoder_attention_mask_1`; confirm
        # against upstream before changing.
        encoder_attention_mask_1 = (
            encoder_attention_mask_1 if encoder_hidden_states_1 is not None else encoder_attention_mask
        )

        # iterates num_layers times (the slice only serves to count the
        # remaining resnets after resnets[0])
        for i in range(len(self.resnets[1:])):
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                # checkpointed path: recompute activations in backward to save memory
                for idx, cross_attention_dim in enumerate(self.cross_attention_dim):
                    # routing: transformers 0-1 attend to the primary conditioning,
                    # transformers 2+ to the secondary; a None dim means pure
                    # self-attention (no encoder states)
                    if cross_attention_dim is not None and idx <= 1:
                        forward_encoder_hidden_states = encoder_hidden_states
                        forward_encoder_attention_mask = encoder_attention_mask
                    elif cross_attention_dim is not None and idx > 1:
                        forward_encoder_hidden_states = encoder_hidden_states_1
                        forward_encoder_attention_mask = encoder_attention_mask_1
                    else:
                        forward_encoder_hidden_states = None
                        forward_encoder_attention_mask = None
                    hidden_states = self._gradient_checkpointing_func(
                        self.attentions[i * num_attention_per_layer + idx],
                        hidden_states,
                        forward_encoder_hidden_states,
                        None,  # timestep
                        None,  # class_labels
                        cross_attention_kwargs,
                        attention_mask,
                        forward_encoder_attention_mask,
                    )[0]
                hidden_states = self._gradient_checkpointing_func(self.resnets[i + 1], hidden_states, temb)
            else:
                for idx, cross_attention_dim in enumerate(self.cross_attention_dim):
                    # same routing as the checkpointed branch above
                    if cross_attention_dim is not None and idx <= 1:
                        forward_encoder_hidden_states = encoder_hidden_states
                        forward_encoder_attention_mask = encoder_attention_mask
                    elif cross_attention_dim is not None and idx > 1:
                        forward_encoder_hidden_states = encoder_hidden_states_1
                        forward_encoder_attention_mask = encoder_attention_mask_1
                    else:
                        forward_encoder_hidden_states = None
                        forward_encoder_attention_mask = None
                    hidden_states = self.attentions[i * num_attention_per_layer + idx](
                        hidden_states,
                        attention_mask=attention_mask,
                        encoder_hidden_states=forward_encoder_hidden_states,
                        encoder_attention_mask=forward_encoder_attention_mask,
                        return_dict=False,
                    )[0]

                hidden_states = self.resnets[i + 1](hidden_states, temb)

        return hidden_states
1316
+
1317
+
1318
class CrossAttnUpBlock2D(nn.Module):
    """Up block with resnets plus up to four transformers per resnet.

    Each resnet consumes a skip connection from the down path; it is followed
    by `len(cross_attention_dim)` `Transformer2DModel`s stored flat in
    `self.attentions`. A `None` entry in `cross_attention_dim` yields a
    self-attention-only transformer (`double_self_attention=True`).
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        prev_output_channel: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        transformer_layers_per_block: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        num_attention_heads=1,
        cross_attention_dim=1280,
        output_scale_factor=1.0,
        add_upsample=True,
        use_linear_projection=False,
        only_cross_attention=False,
        upcast_attention=False,
    ):
        super().__init__()
        resnets = []
        attentions = []

        self.has_cross_attention = True
        self.num_attention_heads = num_attention_heads

        # normalize a scalar conditioning dim to a 1-tuple so the loops below
        # can always iterate over it
        if isinstance(cross_attention_dim, int):
            cross_attention_dim = (cross_attention_dim,)
        if isinstance(cross_attention_dim, (list, tuple)) and len(cross_attention_dim) > 4:
            raise ValueError(
                "Only up to 4 cross-attention layers are supported. Ensure that the length of cross-attention "
                f"dims is less than or equal to 4. Got cross-attention dims {cross_attention_dim} of length {len(cross_attention_dim)}"
            )
        self.cross_attention_dim = cross_attention_dim

        for i in range(num_layers):
            # channel bookkeeping: each resnet also consumes the concatenated
            # skip connection from the corresponding down block
            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
            resnet_in_channels = prev_output_channel if i == 0 else out_channels

            resnets.append(
                ResnetBlock2D(
                    in_channels=resnet_in_channels + res_skip_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )
            # one transformer per conditioning dim, appended flat; forward()
            # recovers the per-resnet grouping via index arithmetic
            for j in range(len(cross_attention_dim)):
                attentions.append(
                    Transformer2DModel(
                        num_attention_heads,
                        out_channels // num_attention_heads,
                        in_channels=out_channels,
                        num_layers=transformer_layers_per_block,
                        cross_attention_dim=cross_attention_dim[j],
                        norm_num_groups=resnet_groups,
                        use_linear_projection=use_linear_projection,
                        only_cross_attention=only_cross_attention,
                        upcast_attention=upcast_attention,
                        double_self_attention=True if cross_attention_dim[j] is None else False,
                    )
                )
        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)

        if add_upsample:
            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
        else:
            self.upsamplers = None

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        res_hidden_states_tuple: Tuple[torch.Tensor, ...],
        temb: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        upsample_size: Optional[int] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states_1: Optional[torch.Tensor] = None,
        encoder_attention_mask_1: Optional[torch.Tensor] = None,
    ):
        """Consume skip connections, run resnets + transformers, then upsample."""
        num_layers = len(self.resnets)
        # attentions are stored flat: this many transformers per resnet
        num_attention_per_layer = len(self.attentions) // num_layers

        # fall back to the primary conditioning when the secondary one is absent
        encoder_hidden_states_1 = (
            encoder_hidden_states_1 if encoder_hidden_states_1 is not None else encoder_hidden_states
        )
        # NOTE(review): this condition tests `encoder_hidden_states_1`, which the
        # line above has just made non-None whenever `encoder_hidden_states` is
        # set — so the fallback to `encoder_attention_mask` can then never
        # trigger. Possibly intended to test `encoder_attention_mask_1`; confirm
        # against upstream before changing.
        encoder_attention_mask_1 = (
            encoder_attention_mask_1 if encoder_hidden_states_1 is not None else encoder_attention_mask
        )

        for i in range(num_layers):
            # pop res hidden states (consumed in reverse order of collection)
            res_hidden_states = res_hidden_states_tuple[-1]
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

            if torch.is_grad_enabled() and self.gradient_checkpointing:
                # checkpointed path: recompute activations in backward to save memory
                hidden_states = self._gradient_checkpointing_func(self.resnets[i], hidden_states, temb)
                for idx, cross_attention_dim in enumerate(self.cross_attention_dim):
                    # routing: transformers 0-1 attend to the primary conditioning,
                    # transformers 2+ to the secondary; a None dim means pure
                    # self-attention (no encoder states)
                    if cross_attention_dim is not None and idx <= 1:
                        forward_encoder_hidden_states = encoder_hidden_states
                        forward_encoder_attention_mask = encoder_attention_mask
                    elif cross_attention_dim is not None and idx > 1:
                        forward_encoder_hidden_states = encoder_hidden_states_1
                        forward_encoder_attention_mask = encoder_attention_mask_1
                    else:
                        forward_encoder_hidden_states = None
                        forward_encoder_attention_mask = None
                    hidden_states = self._gradient_checkpointing_func(
                        self.attentions[i * num_attention_per_layer + idx],
                        hidden_states,
                        forward_encoder_hidden_states,
                        None,  # timestep
                        None,  # class_labels
                        cross_attention_kwargs,
                        attention_mask,
                        forward_encoder_attention_mask,
                    )[0]
            else:
                hidden_states = self.resnets[i](hidden_states, temb)
                for idx, cross_attention_dim in enumerate(self.cross_attention_dim):
                    # same routing as the checkpointed branch above
                    if cross_attention_dim is not None and idx <= 1:
                        forward_encoder_hidden_states = encoder_hidden_states
                        forward_encoder_attention_mask = encoder_attention_mask
                    elif cross_attention_dim is not None and idx > 1:
                        forward_encoder_hidden_states = encoder_hidden_states_1
                        forward_encoder_attention_mask = encoder_attention_mask_1
                    else:
                        forward_encoder_hidden_states = None
                        forward_encoder_attention_mask = None
                    hidden_states = self.attentions[i * num_attention_per_layer + idx](
                        hidden_states,
                        attention_mask=attention_mask,
                        encoder_hidden_states=forward_encoder_hidden_states,
                        encoder_attention_mask=forward_encoder_attention_mask,
                        return_dict=False,
                    )[0]

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states, upsample_size)

        return hidden_states
pythonProject/diffusers-main/build/lib/diffusers/pipelines/audioldm2/pipeline_audioldm2.py ADDED
@@ -0,0 +1,1104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 CVSSP, ByteDance and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ from typing import Any, Callable, Dict, List, Optional, Union
17
+
18
+ import numpy as np
19
+ import torch
20
+ from transformers import (
21
+ ClapFeatureExtractor,
22
+ ClapModel,
23
+ GPT2LMHeadModel,
24
+ RobertaTokenizer,
25
+ RobertaTokenizerFast,
26
+ SpeechT5HifiGan,
27
+ T5EncoderModel,
28
+ T5Tokenizer,
29
+ T5TokenizerFast,
30
+ VitsModel,
31
+ VitsTokenizer,
32
+ )
33
+
34
+ from ...models import AutoencoderKL
35
+ from ...schedulers import KarrasDiffusionSchedulers
36
+ from ...utils import (
37
+ is_accelerate_available,
38
+ is_accelerate_version,
39
+ is_librosa_available,
40
+ logging,
41
+ replace_example_docstring,
42
+ )
43
+ from ...utils.import_utils import is_transformers_version
44
+ from ...utils.torch_utils import empty_device_cache, randn_tensor
45
+ from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline
46
+ from .modeling_audioldm2 import AudioLDM2ProjectionModel, AudioLDM2UNet2DConditionModel
47
+
48
+
49
+ if is_librosa_available():
50
+ import librosa
51
+
52
+
53
+ from ...utils import is_torch_xla_available
54
+
55
+
56
+ if is_torch_xla_available():
57
+ import torch_xla.core.xla_model as xm
58
+
59
+ XLA_AVAILABLE = True
60
+ else:
61
+ XLA_AVAILABLE = False
62
+
63
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
64
+
65
+
66
+ EXAMPLE_DOC_STRING = """
67
+ Examples:
68
+ ```py
69
+ >>> import scipy
70
+ >>> import torch
71
+ >>> from diffusers import AudioLDM2Pipeline
72
+
73
+ >>> repo_id = "cvssp/audioldm2"
74
+ >>> pipe = AudioLDM2Pipeline.from_pretrained(repo_id, torch_dtype=torch.float16)
75
+ >>> pipe = pipe.to("cuda")
76
+
77
+ >>> # define the prompts
78
+ >>> prompt = "The sound of a hammer hitting a wooden surface."
79
+ >>> negative_prompt = "Low quality."
80
+
81
+ >>> # set the seed for generator
82
+ >>> generator = torch.Generator("cuda").manual_seed(0)
83
+
84
+ >>> # run the generation
85
+ >>> audio = pipe(
86
+ ... prompt,
87
+ ... negative_prompt=negative_prompt,
88
+ ... num_inference_steps=200,
89
+ ... audio_length_in_s=10.0,
90
+ ... num_waveforms_per_prompt=3,
91
+ ... generator=generator,
92
+ ... ).audios
93
+
94
+ >>> # save the best audio sample (index 0) as a .wav file
95
+ >>> scipy.io.wavfile.write("techno.wav", rate=16000, data=audio[0])
96
+ ```
97
+ ```
98
+ #Using AudioLDM2 for Text To Speech
99
+ >>> import scipy
100
+ >>> import torch
101
+ >>> from diffusers import AudioLDM2Pipeline
102
+
103
+ >>> repo_id = "anhnct/audioldm2_gigaspeech"
104
+ >>> pipe = AudioLDM2Pipeline.from_pretrained(repo_id, torch_dtype=torch.float16)
105
+ >>> pipe = pipe.to("cuda")
106
+
107
+ >>> # define the prompts
108
+ >>> prompt = "A female reporter is speaking"
109
+ >>> transcript = "wish you have a good day"
110
+
111
+ >>> # set the seed for generator
112
+ >>> generator = torch.Generator("cuda").manual_seed(0)
113
+
114
+ >>> # run the generation
115
+ >>> audio = pipe(
116
+ ... prompt,
117
+ ... transcription=transcript,
118
+ ... num_inference_steps=200,
119
+ ... audio_length_in_s=10.0,
120
+ ... num_waveforms_per_prompt=2,
121
+ ... generator=generator,
122
+ ... max_new_tokens=512, #Must set max_new_tokens equa to 512 for TTS
123
+ ... ).audios
124
+
125
+ >>> # save the best audio sample (index 0) as a .wav file
126
+ >>> scipy.io.wavfile.write("tts.wav", rate=16000, data=audio[0])
127
+ ```
128
+ """
129
+
130
+
131
def prepare_inputs_for_generation(
    inputs_embeds,
    attention_mask=None,
    past_key_values=None,
    **kwargs,
):
    """Assemble the keyword arguments for one language-model forward pass.

    When `past_key_values` is provided (a KV cache exists from previous steps),
    only the embedding of the most recent position is forwarded, since all
    earlier positions are already encoded in the cache.
    """
    has_cache = past_key_values is not None
    trimmed_embeds = inputs_embeds[:, -1:] if has_cache else inputs_embeds

    return {
        "inputs_embeds": trimmed_embeds,
        "attention_mask": attention_mask,
        "past_key_values": past_key_values,
        "use_cache": kwargs.get("use_cache"),
    }
147
+
148
+
149
+ class AudioLDM2Pipeline(DiffusionPipeline):
150
+ r"""
151
+ Pipeline for text-to-audio generation using AudioLDM2.
152
+
153
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
154
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
155
+
156
+ Args:
157
+ vae ([`AutoencoderKL`]):
158
+ Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
159
+ text_encoder ([`~transformers.ClapModel`]):
160
+ First frozen text-encoder. AudioLDM2 uses the joint audio-text embedding model
161
+ [CLAP](https://huggingface.co/docs/transformers/model_doc/clap#transformers.CLAPTextModelWithProjection),
162
+ specifically the [laion/clap-htsat-unfused](https://huggingface.co/laion/clap-htsat-unfused) variant. The
163
+ text branch is used to encode the text prompt to a prompt embedding. The full audio-text model is used to
164
+ rank generated waveforms against the text prompt by computing similarity scores.
165
+ text_encoder_2 ([`~transformers.T5EncoderModel`, `~transformers.VitsModel`]):
166
+ Second frozen text-encoder. AudioLDM2 uses the encoder of
167
+ [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
168
+ [google/flan-t5-large](https://huggingface.co/google/flan-t5-large) variant. Second frozen text-encoder use
169
+ for TTS. AudioLDM2 uses the encoder of
170
+ [Vits](https://huggingface.co/docs/transformers/model_doc/vits#transformers.VitsModel).
171
+ projection_model ([`AudioLDM2ProjectionModel`]):
172
+ A trained model used to linearly project the hidden-states from the first and second text encoder models
173
+ and insert learned SOS and EOS token embeddings. The projected hidden-states from the two text encoders are
174
+ concatenated to give the input to the language model. A Learned Position Embedding for the Vits
175
+ hidden-states
176
+ language_model ([`~transformers.GPT2Model`]):
177
+ An auto-regressive language model used to generate a sequence of hidden-states conditioned on the projected
178
+ outputs from the two text encoders.
179
+ tokenizer ([`~transformers.RobertaTokenizer`]):
180
+ Tokenizer to tokenize text for the first frozen text-encoder.
181
+ tokenizer_2 ([`~transformers.T5Tokenizer`, `~transformers.VitsTokenizer`]):
182
+ Tokenizer to tokenize text for the second frozen text-encoder.
183
+ feature_extractor ([`~transformers.ClapFeatureExtractor`]):
184
+ Feature extractor to pre-process generated audio waveforms to log-mel spectrograms for automatic scoring.
185
+ unet ([`UNet2DConditionModel`]):
186
+ A `UNet2DConditionModel` to denoise the encoded audio latents.
187
+ scheduler ([`SchedulerMixin`]):
188
+ A scheduler to be used in combination with `unet` to denoise the encoded audio latents. Can be one of
189
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
190
+ vocoder ([`~transformers.SpeechT5HifiGan`]):
191
+ Vocoder of class `SpeechT5HifiGan` to convert the mel-spectrogram latents to the final audio waveform.
192
+ """
193
+
194
    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: ClapModel,
        text_encoder_2: Union[T5EncoderModel, VitsModel],
        projection_model: AudioLDM2ProjectionModel,
        language_model: GPT2LMHeadModel,
        tokenizer: Union[RobertaTokenizer, RobertaTokenizerFast],
        tokenizer_2: Union[T5Tokenizer, T5TokenizerFast, VitsTokenizer],
        feature_extractor: ClapFeatureExtractor,
        unet: AudioLDM2UNet2DConditionModel,
        scheduler: KarrasDiffusionSchedulers,
        vocoder: SpeechT5HifiGan,
    ):
        # All components are documented on the class docstring above.
        super().__init__()

        # Register every sub-model so that `save_pretrained`/`from_pretrained`
        # and device placement handle them automatically.
        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            text_encoder_2=text_encoder_2,
            projection_model=projection_model,
            language_model=language_model,
            tokenizer=tokenizer,
            tokenizer_2=tokenizer_2,
            feature_extractor=feature_extractor,
            unet=unet,
            scheduler=scheduler,
            vocoder=vocoder,
        )
        # Spatial down-sampling factor of the VAE: each of the
        # `len(block_out_channels) - 1` down blocks halves the resolution.
        # Falls back to 8 when no VAE is registered (config-only loading).
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
224
+
225
+ # Copied from diffusers.pipelines.pipeline_utils.StableDiffusionMixin.enable_vae_slicing
226
    def enable_vae_slicing(self):
        r"""
        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
        """
        # Delegates directly to the registered `AutoencoderKL` instance.
        self.vae.enable_slicing()
232
+
233
+ # Copied from diffusers.pipelines.pipeline_utils.StableDiffusionMixin.disable_vae_slicing
234
    def disable_vae_slicing(self):
        r"""
        Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
        computing decoding in one step.
        """
        # Delegates directly to the registered `AutoencoderKL` instance.
        self.vae.disable_slicing()
240
+
241
    def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
        r"""
        Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
        to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
        method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
        `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
        """
        # `cpu_offload_with_hook` only exists from accelerate 0.17 onwards.
        if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
            from accelerate import cpu_offload_with_hook
        else:
            raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")

        torch_device = torch.device(device)
        device_index = torch_device.index

        # A device index may come from either argument, but not both.
        if gpu_id is not None and device_index is not None:
            raise ValueError(
                f"You have passed both `gpu_id`={gpu_id} and an index as part of the passed device `device`={device}"
                f"Cannot pass both. Please make sure to either not define `gpu_id` or not pass the index as part of the device: `device`={torch_device.type}"
            )

        # Rebuild the target device string from the type plus whichever index was supplied.
        device_type = torch_device.type
        device_str = device_type
        # NOTE(review): `gpu_id == 0` (and index 0) is falsy, so an explicit device 0 falls
        # through to the bare device type here — confirm this is the intended behavior.
        if gpu_id or torch_device.index:
            device_str = f"{device_str}:{gpu_id or torch_device.index}"
        device = torch.device(device_str)

        # Move everything back to CPU first so the offload hooks start from a clean state.
        if self.device.type != "cpu":
            self.to("cpu", silence_dtype_warnings=True)
            empty_device_cache(device.type)

        # Offload order mirrors the order the models are executed during a pipeline call.
        # NOTE(review): `self.text_encoder` appears at the end in addition to its
        # `text_model`/`text_projection` submodules at the start — presumably so the full
        # CLAP model is available for waveform scoring after generation; verify.
        model_sequence = [
            self.text_encoder.text_model,
            self.text_encoder.text_projection,
            self.text_encoder_2,
            self.projection_model,
            self.language_model,
            self.unet,
            self.vae,
            self.vocoder,
            self.text_encoder,
        ]

        hook = None
        for cpu_offloaded_model in model_sequence:
            # Each hook offloads the previous model when the next one starts running.
            _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)

        # We'll offload the last model manually.
        self.final_offload_hook = hook
290
+
291
    def generate_language_model(
        self,
        inputs_embeds: torch.Tensor = None,
        max_new_tokens: int = 8,
        **model_kwargs,
    ):
        """

        Generates a sequence of hidden-states from the language model, conditioned on the embedding inputs.

        Parameters:
            inputs_embeds (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
                The sequence used as a prompt for the generation.
            max_new_tokens (`int`):
                Number of new tokens to generate.
            model_kwargs (`Dict[str, Any]`, *optional*):
                Ad hoc parametrization of additional model-specific kwargs that will be forwarded to the `forward`
                function of the model.

        Return:
            `inputs_embeds (`torch.Tensor` of shape `(batch_size, max_new_tokens, hidden_size)`):
                The hidden-states of the newly generated positions only (the last `max_new_tokens` of the
                accumulated sequence).
        """
        # `_get_initial_cache_position` changed signature in transformers 4.52.1:
        # it used to take `input_ids`; newer versions take `seq_length`/`device`/`model_kwargs`.
        cache_position_kwargs = {}
        if is_transformers_version("<", "4.52.1"):
            cache_position_kwargs["input_ids"] = inputs_embeds
        else:
            # NOTE(review): `shape[0]` is the batch dimension of `(batch, seq_len, hidden)`;
            # a sequence length would normally be `shape[1]` — confirm against
            # `GenerationMixin._get_initial_cache_position` in the pinned transformers version.
            cache_position_kwargs["seq_length"] = inputs_embeds.shape[0]
            cache_position_kwargs["device"] = (
                self.language_model.device if getattr(self, "language_model", None) is not None else self.device
            )
            cache_position_kwargs["model_kwargs"] = model_kwargs
        max_new_tokens = max_new_tokens if max_new_tokens is not None else self.language_model.config.max_new_tokens
        model_kwargs = self.language_model._get_initial_cache_position(**cache_position_kwargs)

        # Greedy-style rollout in embedding space: each step appends the last hidden-state
        # instead of a sampled token id.
        for _ in range(max_new_tokens):
            # prepare model inputs
            model_inputs = prepare_inputs_for_generation(inputs_embeds, **model_kwargs)

            # forward pass to get next hidden states
            output = self.language_model(**model_inputs, output_hidden_states=True, return_dict=True)

            next_hidden_states = output.hidden_states[-1]

            # Update the model input
            inputs_embeds = torch.cat([inputs_embeds, next_hidden_states[:, -1:, :]], dim=1)

            # Update generated hidden states, model inputs, and length for next step
            model_kwargs = self.language_model._update_model_kwargs_for_generation(output, model_kwargs)

        # Drop the prompt positions; return only what was generated in this call.
        return inputs_embeds[:, -max_new_tokens:, :]
342
+
343
    def encode_prompt(
        self,
        prompt,
        device,
        num_waveforms_per_prompt,
        do_classifier_free_guidance,
        transcription=None,
        negative_prompt=None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        generated_prompt_embeds: Optional[torch.Tensor] = None,
        negative_generated_prompt_embeds: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        negative_attention_mask: Optional[torch.LongTensor] = None,
        max_new_tokens: Optional[int] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            transcription (`str` or `List[str]`):
                transcription of text to speech
            device (`torch.device`):
                torch device
            num_waveforms_per_prompt (`int`):
                number of waveforms that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the audio generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-computed text embeddings from the Flan T5 model. Can be used to easily tweak text inputs, *e.g.*
                prompt weighting. If not provided, text embeddings will be computed from `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-computed negative text embeddings from the Flan T5 model. Can be used to easily tweak text inputs,
                *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be computed from
                `negative_prompt` input argument.
            generated_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings from the GPT2 language model. Can be used to easily tweak text inputs,
                *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input
                argument.
            negative_generated_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings from the GPT2 language model. Can be used to easily tweak text
                inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be computed from
                `negative_prompt` input argument.
            attention_mask (`torch.LongTensor`, *optional*):
                Pre-computed attention mask to be applied to the `prompt_embeds`. If not provided, attention mask will
                be computed from `prompt` input argument.
            negative_attention_mask (`torch.LongTensor`, *optional*):
                Pre-computed attention mask to be applied to the `negative_prompt_embeds`. If not provided, attention
                mask will be computed from `negative_prompt` input argument.
            max_new_tokens (`int`, *optional*, defaults to None):
                The number of new tokens to generate with the GPT2 language model.
        Returns:
            prompt_embeds (`torch.Tensor`):
                Text embeddings from the Flan T5 model.
            attention_mask (`torch.LongTensor`):
                Attention mask to be applied to the `prompt_embeds`.
            generated_prompt_embeds (`torch.Tensor`):
                Text embeddings generated from the GPT2 language model.

        Example:

        ```python
        >>> import scipy
        >>> import torch
        >>> from diffusers import AudioLDM2Pipeline

        >>> repo_id = "cvssp/audioldm2"
        >>> pipe = AudioLDM2Pipeline.from_pretrained(repo_id, torch_dtype=torch.float16)
        >>> pipe = pipe.to("cuda")

        >>> # Get text embedding vectors
        >>> prompt_embeds, attention_mask, generated_prompt_embeds = pipe.encode_prompt(
        ...     prompt="Techno music with a strong, upbeat tempo and high melodic riffs",
        ...     device="cuda",
        ...     do_classifier_free_guidance=True,
        ... )

        >>> # Pass text embeddings to pipeline for text-conditional audio generation
        >>> audio = pipe(
        ...     prompt_embeds=prompt_embeds,
        ...     attention_mask=attention_mask,
        ...     generated_prompt_embeds=generated_prompt_embeds,
        ...     num_inference_steps=200,
        ...     audio_length_in_s=10.0,
        ... ).audios[0]

        >>> # save generated audio sample
        >>> scipy.io.wavfile.write("techno.wav", rate=16000, data=audio)
        ```"""
        # Infer batch size from the prompt (str -> 1, list -> len) or from pre-computed embeddings.
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        # Define tokenizers and text encoders
        tokenizers = [self.tokenizer, self.tokenizer_2]
        is_vits_text_encoder = isinstance(self.text_encoder_2, VitsModel)

        # TTS checkpoints replace the T5 second encoder with a VITS model; its inner
        # `text_encoder` submodule is used for encoding.
        if is_vits_text_encoder:
            text_encoders = [self.text_encoder, self.text_encoder_2.text_encoder]
        else:
            text_encoders = [self.text_encoder, self.text_encoder_2]

        if prompt_embeds is None:
            prompt_embeds_list = []
            attention_mask_list = []

            # Run the prompt (or, for VITS, the transcription) through both encoders in turn.
            for tokenizer, text_encoder in zip(tokenizers, text_encoders):
                # CLAP/T5 tokenizers consume the `prompt`; the VITS tokenizer consumes `transcription`.
                use_prompt = isinstance(
                    tokenizer, (RobertaTokenizer, RobertaTokenizerFast, T5Tokenizer, T5TokenizerFast)
                )
                text_inputs = tokenizer(
                    prompt if use_prompt else transcription,
                    padding="max_length"
                    if isinstance(tokenizer, (RobertaTokenizer, RobertaTokenizerFast, VitsTokenizer))
                    else True,
                    max_length=tokenizer.model_max_length,
                    truncation=True,
                    return_tensors="pt",
                )
                text_input_ids = text_inputs.input_ids
                attention_mask = text_inputs.attention_mask
                untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

                # Warn the user when the prompt was silently truncated to the model max length.
                if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                    text_input_ids, untruncated_ids
                ):
                    removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
                    logger.warning(
                        f"The following part of your input was truncated because {text_encoder.config.model_type} can "
                        f"only handle sequences up to {tokenizer.model_max_length} tokens: {removed_text}"
                    )

                text_input_ids = text_input_ids.to(device)
                attention_mask = attention_mask.to(device)

                if text_encoder.config.model_type == "clap":
                    # CLAP yields a single pooled text embedding per prompt.
                    prompt_embeds = text_encoder.get_text_features(
                        text_input_ids,
                        attention_mask=attention_mask,
                    )
                    # append the seq-len dim: (bs, hidden_size) -> (bs, seq_len, hidden_size)
                    prompt_embeds = prompt_embeds[:, None, :]
                    # make sure that we attend to this single hidden-state
                    attention_mask = attention_mask.new_ones((batch_size, 1))
                elif is_vits_text_encoder:
                    # Add end_token_id and attention mask in the end of sequence phonemes
                    for text_input_id, text_attention_mask in zip(text_input_ids, attention_mask):
                        for idx, phoneme_id in enumerate(text_input_id):
                            if phoneme_id == 0:
                                text_input_id[idx] = 182
                                text_attention_mask[idx] = 1
                                break
                    prompt_embeds = text_encoder(
                        text_input_ids, attention_mask=attention_mask, padding_mask=attention_mask.unsqueeze(-1)
                    )
                    prompt_embeds = prompt_embeds[0]
                else:
                    # T5 path: full per-token hidden-states.
                    prompt_embeds = text_encoder(
                        text_input_ids,
                        attention_mask=attention_mask,
                    )
                    prompt_embeds = prompt_embeds[0]

                prompt_embeds_list.append(prompt_embeds)
                attention_mask_list.append(attention_mask)

            # Project both encoder outputs to a common space and concatenate them as the LM prompt.
            projection_output = self.projection_model(
                hidden_states=prompt_embeds_list[0],
                hidden_states_1=prompt_embeds_list[1],
                attention_mask=attention_mask_list[0],
                attention_mask_1=attention_mask_list[1],
            )
            projected_prompt_embeds = projection_output.hidden_states
            projected_attention_mask = projection_output.attention_mask

            generated_prompt_embeds = self.generate_language_model(
                projected_prompt_embeds,
                attention_mask=projected_attention_mask,
                max_new_tokens=max_new_tokens,
            )

        # NOTE: after the loop, `prompt_embeds`/`attention_mask` hold the outputs of the
        # *second* (T5/VITS) encoder; the first encoder only contributes via the LM.
        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
        attention_mask = (
            attention_mask.to(device=device)
            if attention_mask is not None
            else torch.ones(prompt_embeds.shape[:2], dtype=torch.long, device=device)
        )
        generated_prompt_embeds = generated_prompt_embeds.to(dtype=self.language_model.dtype, device=device)

        bs_embed, seq_len, hidden_size = prompt_embeds.shape
        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_waveforms_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(bs_embed * num_waveforms_per_prompt, seq_len, hidden_size)

        # duplicate attention mask for each generation per prompt
        attention_mask = attention_mask.repeat(1, num_waveforms_per_prompt)
        attention_mask = attention_mask.view(bs_embed * num_waveforms_per_prompt, seq_len)

        bs_embed, seq_len, hidden_size = generated_prompt_embeds.shape
        # duplicate generated embeddings for each generation per prompt, using mps friendly method
        generated_prompt_embeds = generated_prompt_embeds.repeat(1, num_waveforms_per_prompt, 1)
        generated_prompt_embeds = generated_prompt_embeds.view(
            bs_embed * num_waveforms_per_prompt, seq_len, hidden_size
        )

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            # Mirror the positive pass for the negative prompt: encode with both encoders,
            # project, then roll out the language model.
            negative_prompt_embeds_list = []
            negative_attention_mask_list = []
            max_length = prompt_embeds.shape[1]
            for tokenizer, text_encoder in zip(tokenizers, text_encoders):
                uncond_input = tokenizer(
                    uncond_tokens,
                    padding="max_length",
                    max_length=tokenizer.model_max_length
                    if isinstance(tokenizer, (RobertaTokenizer, RobertaTokenizerFast, VitsTokenizer))
                    else max_length,
                    truncation=True,
                    return_tensors="pt",
                )

                uncond_input_ids = uncond_input.input_ids.to(device)
                negative_attention_mask = uncond_input.attention_mask.to(device)

                if text_encoder.config.model_type == "clap":
                    negative_prompt_embeds = text_encoder.get_text_features(
                        uncond_input_ids,
                        attention_mask=negative_attention_mask,
                    )
                    # append the seq-len dim: (bs, hidden_size) -> (bs, seq_len, hidden_size)
                    negative_prompt_embeds = negative_prompt_embeds[:, None, :]
                    # make sure that we attend to this single hidden-state
                    negative_attention_mask = negative_attention_mask.new_ones((batch_size, 1))
                elif is_vits_text_encoder:
                    # VITS negative branch: all-zero embeddings and mask instead of encoding "".
                    negative_prompt_embeds = torch.zeros(
                        batch_size,
                        tokenizer.model_max_length,
                        text_encoder.config.hidden_size,
                    ).to(dtype=self.text_encoder_2.dtype, device=device)
                    negative_attention_mask = torch.zeros(batch_size, tokenizer.model_max_length).to(
                        dtype=self.text_encoder_2.dtype, device=device
                    )
                else:
                    negative_prompt_embeds = text_encoder(
                        uncond_input_ids,
                        attention_mask=negative_attention_mask,
                    )
                    negative_prompt_embeds = negative_prompt_embeds[0]

                negative_prompt_embeds_list.append(negative_prompt_embeds)
                negative_attention_mask_list.append(negative_attention_mask)

            projection_output = self.projection_model(
                hidden_states=negative_prompt_embeds_list[0],
                hidden_states_1=negative_prompt_embeds_list[1],
                attention_mask=negative_attention_mask_list[0],
                attention_mask_1=negative_attention_mask_list[1],
            )
            negative_projected_prompt_embeds = projection_output.hidden_states
            negative_projected_attention_mask = projection_output.attention_mask

            negative_generated_prompt_embeds = self.generate_language_model(
                negative_projected_prompt_embeds,
                attention_mask=negative_projected_attention_mask,
                max_new_tokens=max_new_tokens,
            )

        if do_classifier_free_guidance:
            seq_len = negative_prompt_embeds.shape[1]

            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
            negative_attention_mask = (
                negative_attention_mask.to(device=device)
                if negative_attention_mask is not None
                else torch.ones(negative_prompt_embeds.shape[:2], dtype=torch.long, device=device)
            )
            negative_generated_prompt_embeds = negative_generated_prompt_embeds.to(
                dtype=self.language_model.dtype, device=device
            )

            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_waveforms_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_waveforms_per_prompt, seq_len, -1)

            # duplicate unconditional attention mask for each generation per prompt
            negative_attention_mask = negative_attention_mask.repeat(1, num_waveforms_per_prompt)
            negative_attention_mask = negative_attention_mask.view(batch_size * num_waveforms_per_prompt, seq_len)

            # duplicate unconditional generated embeddings for each generation per prompt
            seq_len = negative_generated_prompt_embeds.shape[1]
            negative_generated_prompt_embeds = negative_generated_prompt_embeds.repeat(1, num_waveforms_per_prompt, 1)
            negative_generated_prompt_embeds = negative_generated_prompt_embeds.view(
                batch_size * num_waveforms_per_prompt, seq_len, -1
            )

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
            attention_mask = torch.cat([negative_attention_mask, attention_mask])
            generated_prompt_embeds = torch.cat([negative_generated_prompt_embeds, generated_prompt_embeds])

        return prompt_embeds, attention_mask, generated_prompt_embeds
674
+
675
+ # Copied from diffusers.pipelines.audioldm.pipeline_audioldm.AudioLDMPipeline.mel_spectrogram_to_waveform
676
    def mel_spectrogram_to_waveform(self, mel_spectrogram):
        """Convert mel spectrograms to audio waveforms with the SpeechT5 HiFi-GAN vocoder.

        Args:
            mel_spectrogram (`torch.Tensor`): 3-D spectrogram batch, or 4-D with a singleton
                dimension at index 1 which is squeezed out first.

        Returns:
            `torch.Tensor`: the vocoder output, moved to CPU and cast to float32.
        """
        if mel_spectrogram.dim() == 4:
            mel_spectrogram = mel_spectrogram.squeeze(1)

        waveform = self.vocoder(mel_spectrogram)
        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        waveform = waveform.cpu().float()
        return waveform
684
+
685
    def score_waveforms(self, text, audio, num_waveforms_per_prompt, device, dtype):
        """Re-order generated waveforms by CLAP audio-text similarity to the prompt.

        The waveforms are resampled (via `librosa`) to the feature extractor's sampling
        rate, scored against `text` with the CLAP model, and returned sorted so the
        best-matching `num_waveforms_per_prompt` candidates per prompt come first.
        If `librosa` is not installed, the audios are returned in generation order.
        """
        if not is_librosa_available():
            logger.info(
                "Automatic scoring of the generated audio waveforms against the input prompt text requires the "
                "`librosa` package to resample the generated waveforms. Returning the audios in the order they were "
                "generated. To enable automatic scoring, install `librosa` with: `pip install librosa`."
            )
            return audio
        inputs = self.tokenizer(text, return_tensors="pt", padding=True)
        # The vocoder and the CLAP feature extractor may use different sampling rates.
        resampled_audio = librosa.resample(
            audio.numpy(), orig_sr=self.vocoder.config.sampling_rate, target_sr=self.feature_extractor.sampling_rate
        )
        inputs["input_features"] = self.feature_extractor(
            list(resampled_audio), return_tensors="pt", sampling_rate=self.feature_extractor.sampling_rate
        ).input_features.type(dtype)
        inputs = inputs.to(device)

        # compute the audio-text similarity score using the CLAP model
        logits_per_text = self.text_encoder(**inputs).logits_per_text
        # sort by the highest matching generations per prompt
        indices = torch.argsort(logits_per_text, dim=1, descending=True)[:, :num_waveforms_per_prompt]
        audio = torch.index_select(audio, 0, indices.reshape(-1).cpu())
        return audio
708
+
709
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
710
+ def prepare_extra_step_kwargs(self, generator, eta):
711
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
712
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
713
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
714
+ # and should be between [0, 1]
715
+
716
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
717
+ extra_step_kwargs = {}
718
+ if accepts_eta:
719
+ extra_step_kwargs["eta"] = eta
720
+
721
+ # check if the scheduler accepts generator
722
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
723
+ if accepts_generator:
724
+ extra_step_kwargs["generator"] = generator
725
+ return extra_step_kwargs
726
+
727
+ def check_inputs(
728
+ self,
729
+ prompt,
730
+ audio_length_in_s,
731
+ vocoder_upsample_factor,
732
+ callback_steps,
733
+ transcription=None,
734
+ negative_prompt=None,
735
+ prompt_embeds=None,
736
+ negative_prompt_embeds=None,
737
+ generated_prompt_embeds=None,
738
+ negative_generated_prompt_embeds=None,
739
+ attention_mask=None,
740
+ negative_attention_mask=None,
741
+ ):
742
+ min_audio_length_in_s = vocoder_upsample_factor * self.vae_scale_factor
743
+ if audio_length_in_s < min_audio_length_in_s:
744
+ raise ValueError(
745
+ f"`audio_length_in_s` has to be a positive value greater than or equal to {min_audio_length_in_s}, but "
746
+ f"is {audio_length_in_s}."
747
+ )
748
+
749
+ if self.vocoder.config.model_in_dim % self.vae_scale_factor != 0:
750
+ raise ValueError(
751
+ f"The number of frequency bins in the vocoder's log-mel spectrogram has to be divisible by the "
752
+ f"VAE scale factor, but got {self.vocoder.config.model_in_dim} bins and a scale factor of "
753
+ f"{self.vae_scale_factor}."
754
+ )
755
+
756
+ if (callback_steps is None) or (
757
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
758
+ ):
759
+ raise ValueError(
760
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
761
+ f" {type(callback_steps)}."
762
+ )
763
+
764
+ if prompt is not None and prompt_embeds is not None:
765
+ raise ValueError(
766
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
767
+ " only forward one of the two."
768
+ )
769
+ elif prompt is None and (prompt_embeds is None or generated_prompt_embeds is None):
770
+ raise ValueError(
771
+ "Provide either `prompt`, or `prompt_embeds` and `generated_prompt_embeds`. Cannot leave "
772
+ "`prompt` undefined without specifying both `prompt_embeds` and `generated_prompt_embeds`."
773
+ )
774
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
775
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
776
+
777
+ if negative_prompt is not None and negative_prompt_embeds is not None:
778
+ raise ValueError(
779
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
780
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
781
+ )
782
+ elif negative_prompt_embeds is not None and negative_generated_prompt_embeds is None:
783
+ raise ValueError(
784
+ "Cannot forward `negative_prompt_embeds` without `negative_generated_prompt_embeds`. Ensure that"
785
+ "both arguments are specified"
786
+ )
787
+
788
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
789
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
790
+ raise ValueError(
791
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
792
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
793
+ f" {negative_prompt_embeds.shape}."
794
+ )
795
+ if attention_mask is not None and attention_mask.shape != prompt_embeds.shape[:2]:
796
+ raise ValueError(
797
+ "`attention_mask should have the same batch size and sequence length as `prompt_embeds`, but got:"
798
+ f"`attention_mask: {attention_mask.shape} != `prompt_embeds` {prompt_embeds.shape}"
799
+ )
800
+
801
+ if transcription is None:
802
+ if self.text_encoder_2.config.model_type == "vits":
803
+ raise ValueError("Cannot forward without transcription. Please make sure to have transcription")
804
+ elif transcription is not None and (
805
+ not isinstance(transcription, str) and not isinstance(transcription, list)
806
+ ):
807
+ raise ValueError(f"`transcription` has to be of type `str` or `list` but is {type(transcription)}")
808
+
809
+ if generated_prompt_embeds is not None and negative_generated_prompt_embeds is not None:
810
+ if generated_prompt_embeds.shape != negative_generated_prompt_embeds.shape:
811
+ raise ValueError(
812
+ "`generated_prompt_embeds` and `negative_generated_prompt_embeds` must have the same shape when "
813
+ f"passed directly, but got: `generated_prompt_embeds` {generated_prompt_embeds.shape} != "
814
+ f"`negative_generated_prompt_embeds` {negative_generated_prompt_embeds.shape}."
815
+ )
816
+ if (
817
+ negative_attention_mask is not None
818
+ and negative_attention_mask.shape != negative_prompt_embeds.shape[:2]
819
+ ):
820
+ raise ValueError(
821
+ "`attention_mask should have the same batch size and sequence length as `prompt_embeds`, but got:"
822
+ f"`attention_mask: {negative_attention_mask.shape} != `prompt_embeds` {negative_prompt_embeds.shape}"
823
+ )
824
+
825
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents with width->self.vocoder.config.model_in_dim
826
+ def prepare_latents(self, batch_size, num_channels_latents, height, dtype, device, generator, latents=None):
827
+ shape = (
828
+ batch_size,
829
+ num_channels_latents,
830
+ int(height) // self.vae_scale_factor,
831
+ int(self.vocoder.config.model_in_dim) // self.vae_scale_factor,
832
+ )
833
+ if isinstance(generator, list) and len(generator) != batch_size:
834
+ raise ValueError(
835
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
836
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
837
+ )
838
+
839
+ if latents is None:
840
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
841
+ else:
842
+ latents = latents.to(device)
843
+
844
+ # scale the initial noise by the standard deviation required by the scheduler
845
+ latents = latents * self.scheduler.init_noise_sigma
846
+ return latents
847
+
848
    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        transcription: Union[str, List[str]] = None,
        audio_length_in_s: Optional[float] = None,
        num_inference_steps: int = 200,
        guidance_scale: float = 3.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_waveforms_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        generated_prompt_embeds: Optional[torch.Tensor] = None,
        negative_generated_prompt_embeds: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        negative_attention_mask: Optional[torch.LongTensor] = None,
        max_new_tokens: Optional[int] = None,
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
        callback_steps: Optional[int] = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        output_type: Optional[str] = "np",
    ):
        r"""
        The call function to the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide audio generation. If not defined, you need to pass `prompt_embeds`.
            transcription (`str` or `List[str]`, *optional*):
                The transcript for text to speech.
            audio_length_in_s (`int`, *optional*, defaults to 10.24):
                The length of the generated audio sample in seconds.
            num_inference_steps (`int`, *optional*, defaults to 200):
                The number of denoising steps. More denoising steps usually lead to a higher quality audio at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 3.5):
                A higher guidance scale value encourages the model to generate audio that is closely linked to the text
                `prompt` at the expense of lower sound quality. Guidance scale is enabled when `guidance_scale > 1`.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide what to not include in audio generation. If not defined, you need to
                pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
            num_waveforms_per_prompt (`int`, *optional*, defaults to 1):
                The number of waveforms to generate per prompt. If `num_waveforms_per_prompt > 1`, then automatic
                scoring is performed between the generated outputs and the text prompt. This scoring ranks the
                generated waveforms based on their cosine similarity with the text input in the joint text-audio
                embedding space.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only
                applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
            latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for spectrogram
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor is generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
                provided, text embeddings are generated from the `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
                not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
            generated_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings from the GPT2 language model. Can be used to easily tweak text inputs,
                *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input
                argument.
            negative_generated_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings from the GPT2 language model. Can be used to easily tweak text
                inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be computed from
                `negative_prompt` input argument.
            attention_mask (`torch.LongTensor`, *optional*):
                Pre-computed attention mask to be applied to the `prompt_embeds`. If not provided, attention mask will
                be computed from `prompt` input argument.
            negative_attention_mask (`torch.LongTensor`, *optional*):
                Pre-computed attention mask to be applied to the `negative_prompt_embeds`. If not provided, attention
                mask will be computed from `negative_prompt` input argument.
            max_new_tokens (`int`, *optional*, defaults to None):
                Number of new tokens to generate with the GPT2 language model. If not provided, number of tokens will
                be taken from the config of the model.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            callback (`Callable`, *optional*):
                A function that calls every `callback_steps` steps during inference. The function is called with the
                following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function is called. If not specified, the callback is called at
                every step.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            output_type (`str`, *optional*, defaults to `"np"`):
                The output format of the generated audio. Choose between `"np"` to return a NumPy `np.ndarray` or
                `"pt"` to return a PyTorch `torch.Tensor` object. Set to `"latent"` to return the latent diffusion
                model (LDM) output.

        Examples:

        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
                otherwise a `tuple` is returned where the first element is a list with the generated audio.
        """
        # 0. Convert audio input length from seconds to spectrogram height
        vocoder_upsample_factor = np.prod(self.vocoder.config.upsample_rates) / self.vocoder.config.sampling_rate

        if audio_length_in_s is None:
            audio_length_in_s = self.unet.config.sample_size * self.vae_scale_factor * vocoder_upsample_factor

        height = int(audio_length_in_s / vocoder_upsample_factor)

        original_waveform_length = int(audio_length_in_s * self.vocoder.config.sampling_rate)
        if height % self.vae_scale_factor != 0:
            # Round the spectrogram height up to a multiple of the VAE scale factor;
            # the surplus is trimmed back to `original_waveform_length` after decoding.
            height = int(np.ceil(height / self.vae_scale_factor)) * self.vae_scale_factor
            logger.info(
                f"Audio length in seconds {audio_length_in_s} is increased to {height * vocoder_upsample_factor} "
                f"so that it can be handled by the model. It will be cut to {audio_length_in_s} after the "
                f"denoising process."
            )

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            audio_length_in_s,
            vocoder_upsample_factor,
            callback_steps,
            transcription,
            negative_prompt,
            prompt_embeds,
            negative_prompt_embeds,
            generated_prompt_embeds,
            negative_generated_prompt_embeds,
            attention_mask,
            negative_attention_mask,
        )

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        prompt_embeds, attention_mask, generated_prompt_embeds = self.encode_prompt(
            prompt,
            device,
            num_waveforms_per_prompt,
            do_classifier_free_guidance,
            transcription,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            generated_prompt_embeds=generated_prompt_embeds,
            negative_generated_prompt_embeds=negative_generated_prompt_embeds,
            attention_mask=attention_mask,
            negative_attention_mask=negative_attention_mask,
            max_new_tokens=max_new_tokens,
        )

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 5. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_waveforms_per_prompt,
            num_channels_latents,
            height,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 6. Prepare extra step kwargs
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # predict the noise residual
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=generated_prompt_embeds,
                    encoder_hidden_states_1=prompt_embeds,
                    encoder_attention_mask_1=attention_mask,
                    return_dict=False,
                )[0]

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        step_idx = i // getattr(self.scheduler, "order", 1)
                        callback(step_idx, t, latents)

                if XLA_AVAILABLE:
                    xm.mark_step()

        self.maybe_free_model_hooks()

        # 8. Post-processing
        if not output_type == "latent":
            latents = 1 / self.vae.config.scaling_factor * latents
            mel_spectrogram = self.vae.decode(latents).sample
        else:
            # "latent" output short-circuits VAE decoding and vocoder synthesis.
            return AudioPipelineOutput(audios=latents)

        audio = self.mel_spectrogram_to_waveform(mel_spectrogram)

        # Trim the padding introduced when `height` was rounded up above.
        audio = audio[:, :original_waveform_length]

        # 9. Automatic scoring
        if num_waveforms_per_prompt > 1 and prompt is not None:
            audio = self.score_waveforms(
                text=prompt,
                audio=audio,
                num_waveforms_per_prompt=num_waveforms_per_prompt,
                device=device,
                dtype=prompt_embeds.dtype,
            )

        if output_type == "np":
            audio = audio.numpy()

        if not return_dict:
            return (audio,)

        return AudioPipelineOutput(audios=audio)
pythonProject/diffusers-main/build/lib/diffusers/pipelines/aura_flow/__init__.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Lazy-import shim for `diffusers.pipelines.aura_flow`.
#
# At import time this module only records which public names exist
# (`_import_structure`); the heavy torch/transformers-backed submodules are
# loaded on first attribute access via `_LazyModule`. When either dependency
# is missing, placeholder dummy objects (which raise a helpful error on use)
# are exported instead.
from typing import TYPE_CHECKING

from ...utils import (
    DIFFUSERS_SLOW_IMPORT,
    OptionalDependencyNotAvailable,
    _LazyModule,
    get_objects_from_module,
    is_torch_available,
    is_transformers_available,
)


_dummy_objects = {}
_import_structure = {}


try:
    if not (is_transformers_available() and is_torch_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    # Missing deps: collect dummy placeholders to export in place of the real classes.
    from ...utils import dummy_torch_and_transformers_objects  # noqa F403

    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
    _import_structure["pipeline_aura_flow"] = ["AuraFlowPipeline"]

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    # Eager path: static type checkers (and DIFFUSERS_SLOW_IMPORT) get real imports.
    try:
        if not (is_transformers_available() and is_torch_available()):
            raise OptionalDependencyNotAvailable()

    except OptionalDependencyNotAvailable:
        from ...utils.dummy_torch_and_transformers_objects import *
    else:
        from .pipeline_aura_flow import AuraFlowPipeline

else:
    import sys

    # Replace this module object with a lazy proxy that imports submodules on demand.
    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )

    # Dummy placeholders (if any) must be attached to the proxy module as well.
    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)
+ setattr(sys.modules[__name__], name, value)
pythonProject/diffusers-main/build/lib/diffusers/pipelines/aura_flow/pipeline_aura_flow.py ADDED
@@ -0,0 +1,677 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 AuraFlow Authors and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import inspect
15
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
16
+
17
+ import torch
18
+ from transformers import T5Tokenizer, UMT5EncoderModel
19
+
20
+ from ...callbacks import MultiPipelineCallbacks, PipelineCallback
21
+ from ...image_processor import VaeImageProcessor
22
+ from ...loaders import AuraFlowLoraLoaderMixin
23
+ from ...models import AuraFlowTransformer2DModel, AutoencoderKL
24
+ from ...models.attention_processor import AttnProcessor2_0, FusedAttnProcessor2_0, XFormersAttnProcessor
25
+ from ...schedulers import FlowMatchEulerDiscreteScheduler
26
+ from ...utils import (
27
+ USE_PEFT_BACKEND,
28
+ is_torch_xla_available,
29
+ logging,
30
+ replace_example_docstring,
31
+ scale_lora_layers,
32
+ unscale_lora_layers,
33
+ )
34
+ from ...utils.torch_utils import randn_tensor
35
+ from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
36
+
37
+
38
if is_torch_xla_available():
    # XLA (TPU) backend: `xm.mark_step()` must be called once per denoising step.
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
46
+
47
+
48
+ EXAMPLE_DOC_STRING = """
49
+ Examples:
50
+ ```py
51
+ >>> import torch
52
+ >>> from diffusers import AuraFlowPipeline
53
+
54
+ >>> pipe = AuraFlowPipeline.from_pretrained("fal/AuraFlow", torch_dtype=torch.float16)
55
+ >>> pipe = pipe.to("cuda")
56
+ >>> prompt = "A cat holding a sign that says hello world"
57
+ >>> image = pipe(prompt).images[0]
58
+ >>> image.save("aura_flow.png")
59
+ ```
60
+ """
61
+
62
+
63
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    r"""
    Call the scheduler's `set_timesteps` method and return the resulting schedule.

    Exactly one of `num_inference_steps`, `timesteps`, or `sigmas` selects the
    schedule. Custom `timesteps`/`sigmas` are only legal when the scheduler's
    `set_timesteps` accepts a keyword of that name. Any extra `kwargs` are
    forwarded to `set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    def _scheduler_accepts(keyword):
        # True when this scheduler's `set_timesteps` has a parameter named `keyword`.
        return keyword in inspect.signature(scheduler.set_timesteps).parameters

    if timesteps is not None:
        if not _scheduler_accepts("timesteps"):
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        num_inference_steps = len(scheduler.timesteps)
    elif sigmas is not None:
        if not _scheduler_accepts("sigmas"):
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        num_inference_steps = len(scheduler.timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)

    return scheduler.timesteps, num_inference_steps
121
+
122
+
123
class AuraFlowPipeline(DiffusionPipeline, AuraFlowLoraLoaderMixin):
    r"""
    Pipeline for text-to-image generation using AuraFlow.

    Args:
        tokenizer (`T5TokenizerFast`):
            Tokenizer of class
            [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
        text_encoder ([`T5EncoderModel`]):
            Frozen text-encoder. AuraFlow uses
            [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
            [EleutherAI/pile-t5-xl](https://huggingface.co/EleutherAI/pile-t5-xl) variant.
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        transformer ([`AuraFlowTransformer2DModel`]):
            Conditional Transformer (MMDiT and DiT) architecture to denoise the encoded image latents.
        scheduler ([`FlowMatchEulerDiscreteScheduler`]):
            A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
    """

    # All five modules are required; none may be omitted at construction time.
    _optional_components = []
    # Order in which sub-models are moved to the accelerator under CPU offloading.
    model_cpu_offload_seq = "text_encoder->transformer->vae"
    # Tensors that step-end callbacks are allowed to read and override.
    _callback_tensor_inputs = [
        "latents",
        "prompt_embeds",
    ]

    def __init__(
        self,
        tokenizer: T5Tokenizer,
        text_encoder: UMT5EncoderModel,
        vae: AutoencoderKL,
        transformer: AuraFlowTransformer2DModel,
        scheduler: FlowMatchEulerDiscreteScheduler,
    ):
        super().__init__()

        self.register_modules(
            tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
        )

        # One factor-of-2 downsampling per VAE block transition; defaults to 8 when no VAE is registered.
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
164
+
165
+ def check_inputs(
166
+ self,
167
+ prompt,
168
+ height,
169
+ width,
170
+ negative_prompt,
171
+ prompt_embeds=None,
172
+ negative_prompt_embeds=None,
173
+ prompt_attention_mask=None,
174
+ negative_prompt_attention_mask=None,
175
+ callback_on_step_end_tensor_inputs=None,
176
+ ):
177
+ if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
178
+ raise ValueError(
179
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}."
180
+ )
181
+
182
+ if callback_on_step_end_tensor_inputs is not None and not all(
183
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
184
+ ):
185
+ raise ValueError(
186
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
187
+ )
188
+ if prompt is not None and prompt_embeds is not None:
189
+ raise ValueError(
190
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
191
+ " only forward one of the two."
192
+ )
193
+ elif prompt is None and prompt_embeds is None:
194
+ raise ValueError(
195
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
196
+ )
197
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
198
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
199
+
200
+ if prompt is not None and negative_prompt_embeds is not None:
201
+ raise ValueError(
202
+ f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
203
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
204
+ )
205
+
206
+ if negative_prompt is not None and negative_prompt_embeds is not None:
207
+ raise ValueError(
208
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
209
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
210
+ )
211
+
212
+ if prompt_embeds is not None and prompt_attention_mask is None:
213
+ raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
214
+
215
+ if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
216
+ raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
217
+
218
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
219
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
220
+ raise ValueError(
221
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
222
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
223
+ f" {negative_prompt_embeds.shape}."
224
+ )
225
+ if prompt_attention_mask.shape != negative_prompt_attention_mask.shape:
226
+ raise ValueError(
227
+ "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but"
228
+ f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`"
229
+ f" {negative_prompt_attention_mask.shape}."
230
+ )
231
+
232
+ def encode_prompt(
233
+ self,
234
+ prompt: Union[str, List[str]],
235
+ negative_prompt: Union[str, List[str]] = None,
236
+ do_classifier_free_guidance: bool = True,
237
+ num_images_per_prompt: int = 1,
238
+ device: Optional[torch.device] = None,
239
+ prompt_embeds: Optional[torch.Tensor] = None,
240
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
241
+ prompt_attention_mask: Optional[torch.Tensor] = None,
242
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
243
+ max_sequence_length: int = 256,
244
+ lora_scale: Optional[float] = None,
245
+ ):
246
+ r"""
247
+ Encodes the prompt into text encoder hidden states.
248
+
249
+ Args:
250
+ prompt (`str` or `List[str]`, *optional*):
251
+ prompt to be encoded
252
+ negative_prompt (`str` or `List[str]`, *optional*):
253
+ The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`
254
+ instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
255
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
256
+ whether to use classifier free guidance or not
257
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
258
+ number of images that should be generated per prompt
259
+ device: (`torch.device`, *optional*):
260
+ torch device to place the resulting embeddings on
261
+ prompt_embeds (`torch.Tensor`, *optional*):
262
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
263
+ provided, text embeddings will be generated from `prompt` input argument.
264
+ prompt_attention_mask (`torch.Tensor`, *optional*):
265
+ Pre-generated attention mask for text embeddings.
266
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
267
+ Pre-generated negative text embeddings.
268
+ negative_prompt_attention_mask (`torch.Tensor`, *optional*):
269
+ Pre-generated attention mask for negative text embeddings.
270
+ max_sequence_length (`int`, defaults to 256): Maximum sequence length to use for the prompt.
271
+ lora_scale (`float`, *optional*):
272
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
273
+ """
274
+ # set lora scale so that monkey patched LoRA
275
+ # function of text encoder can correctly access it
276
+ if lora_scale is not None and isinstance(self, AuraFlowLoraLoaderMixin):
277
+ self._lora_scale = lora_scale
278
+
279
+ # dynamically adjust the LoRA scale
280
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
281
+ scale_lora_layers(self.text_encoder, lora_scale)
282
+
283
+ if device is None:
284
+ device = self._execution_device
285
+ if prompt is not None and isinstance(prompt, str):
286
+ batch_size = 1
287
+ elif prompt is not None and isinstance(prompt, list):
288
+ batch_size = len(prompt)
289
+ else:
290
+ batch_size = prompt_embeds.shape[0]
291
+
292
+ max_length = max_sequence_length
293
+ if prompt_embeds is None:
294
+ text_inputs = self.tokenizer(
295
+ prompt,
296
+ truncation=True,
297
+ max_length=max_length,
298
+ padding="max_length",
299
+ return_tensors="pt",
300
+ )
301
+ text_input_ids = text_inputs["input_ids"]
302
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
303
+
304
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
305
+ text_input_ids, untruncated_ids
306
+ ):
307
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1])
308
+ logger.warning(
309
+ "The following part of your input was truncated because T5 can only handle sequences up to"
310
+ f" {max_length} tokens: {removed_text}"
311
+ )
312
+
313
+ text_inputs = {k: v.to(device) for k, v in text_inputs.items()}
314
+ prompt_embeds = self.text_encoder(**text_inputs)[0]
315
+ prompt_attention_mask = text_inputs["attention_mask"].unsqueeze(-1).expand(prompt_embeds.shape)
316
+ prompt_embeds = prompt_embeds * prompt_attention_mask
317
+
318
+ if self.text_encoder is not None:
319
+ dtype = self.text_encoder.dtype
320
+ elif self.transformer is not None:
321
+ dtype = self.transformer.dtype
322
+ else:
323
+ dtype = None
324
+
325
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
326
+
327
+ bs_embed, seq_len, _ = prompt_embeds.shape
328
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
329
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
330
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
331
+ prompt_attention_mask = prompt_attention_mask.reshape(bs_embed, -1)
332
+ prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
333
+
334
+ # get unconditional embeddings for classifier free guidance
335
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
336
+ negative_prompt = negative_prompt or ""
337
+ uncond_tokens = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt
338
+ max_length = prompt_embeds.shape[1]
339
+ uncond_input = self.tokenizer(
340
+ uncond_tokens,
341
+ truncation=True,
342
+ max_length=max_length,
343
+ padding="max_length",
344
+ return_tensors="pt",
345
+ )
346
+ uncond_input = {k: v.to(device) for k, v in uncond_input.items()}
347
+ negative_prompt_embeds = self.text_encoder(**uncond_input)[0]
348
+ negative_prompt_attention_mask = (
349
+ uncond_input["attention_mask"].unsqueeze(-1).expand(negative_prompt_embeds.shape)
350
+ )
351
+ negative_prompt_embeds = negative_prompt_embeds * negative_prompt_attention_mask
352
+
353
+ if do_classifier_free_guidance:
354
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
355
+ seq_len = negative_prompt_embeds.shape[1]
356
+
357
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
358
+
359
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
360
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
361
+
362
+ negative_prompt_attention_mask = negative_prompt_attention_mask.reshape(bs_embed, -1)
363
+ negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
364
+ else:
365
+ negative_prompt_embeds = None
366
+ negative_prompt_attention_mask = None
367
+
368
+ if self.text_encoder is not None:
369
+ if isinstance(self, AuraFlowLoraLoaderMixin) and USE_PEFT_BACKEND:
370
+ # Retrieve the original scale by scaling back the LoRA layers
371
+ unscale_lora_layers(self.text_encoder, lora_scale)
372
+
373
+ return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
374
+
375
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.prepare_latents
376
+ def prepare_latents(
377
+ self,
378
+ batch_size,
379
+ num_channels_latents,
380
+ height,
381
+ width,
382
+ dtype,
383
+ device,
384
+ generator,
385
+ latents=None,
386
+ ):
387
+ if latents is not None:
388
+ return latents.to(device=device, dtype=dtype)
389
+
390
+ shape = (
391
+ batch_size,
392
+ num_channels_latents,
393
+ int(height) // self.vae_scale_factor,
394
+ int(width) // self.vae_scale_factor,
395
+ )
396
+
397
+ if isinstance(generator, list) and len(generator) != batch_size:
398
+ raise ValueError(
399
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
400
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
401
+ )
402
+
403
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
404
+
405
+ return latents
406
+
407
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.upcast_vae
408
+ def upcast_vae(self):
409
+ dtype = self.vae.dtype
410
+ self.vae.to(dtype=torch.float32)
411
+ use_torch_2_0_or_xformers = isinstance(
412
+ self.vae.decoder.mid_block.attentions[0].processor,
413
+ (
414
+ AttnProcessor2_0,
415
+ XFormersAttnProcessor,
416
+ FusedAttnProcessor2_0,
417
+ ),
418
+ )
419
+ # if xformers or torch_2_0 is used attention block does not need
420
+ # to be in float32 which can save lots of memory
421
+ if use_torch_2_0_or_xformers:
422
+ self.vae.post_quant_conv.to(dtype)
423
+ self.vae.decoder.conv_in.to(dtype)
424
+ self.vae.decoder.mid_block.to(dtype)
425
+
426
+ @property
427
+ def guidance_scale(self):
428
+ return self._guidance_scale
429
+
430
+ @property
431
+ def attention_kwargs(self):
432
+ return self._attention_kwargs
433
+
434
+ @property
435
+ def num_timesteps(self):
436
+ return self._num_timesteps
437
+
438
+ @torch.no_grad()
439
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
440
+ def __call__(
441
+ self,
442
+ prompt: Union[str, List[str]] = None,
443
+ negative_prompt: Union[str, List[str]] = None,
444
+ num_inference_steps: int = 50,
445
+ sigmas: List[float] = None,
446
+ guidance_scale: float = 3.5,
447
+ num_images_per_prompt: Optional[int] = 1,
448
+ height: Optional[int] = 1024,
449
+ width: Optional[int] = 1024,
450
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
451
+ latents: Optional[torch.Tensor] = None,
452
+ prompt_embeds: Optional[torch.Tensor] = None,
453
+ prompt_attention_mask: Optional[torch.Tensor] = None,
454
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
455
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
456
+ max_sequence_length: int = 256,
457
+ output_type: Optional[str] = "pil",
458
+ return_dict: bool = True,
459
+ attention_kwargs: Optional[Dict[str, Any]] = None,
460
+ callback_on_step_end: Optional[
461
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
462
+ ] = None,
463
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
464
+ ) -> Union[ImagePipelineOutput, Tuple]:
465
+ r"""
466
+ Function invoked when calling the pipeline for generation.
467
+
468
+ Args:
469
+ prompt (`str` or `List[str]`, *optional*):
470
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
471
+ instead.
472
+ negative_prompt (`str` or `List[str]`, *optional*):
473
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
474
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
475
+ less than `1`).
476
+ height (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
477
+ The height in pixels of the generated image. This is set to 1024 by default for best results.
478
+ width (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
479
+ The width in pixels of the generated image. This is set to 1024 by default for best results.
480
+ num_inference_steps (`int`, *optional*, defaults to 50):
481
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
482
+ expense of slower inference.
483
+ sigmas (`List[float]`, *optional*):
484
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
485
+ `num_inference_steps` and `timesteps` must be `None`.
486
+ guidance_scale (`float`, *optional*, defaults to 5.0):
487
+ Guidance scale as defined in [Classifier-Free Diffusion
488
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
489
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
490
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
491
+ the text `prompt`, usually at the expense of lower image quality.
492
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
493
+ The number of images to generate per prompt.
494
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
495
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
496
+ to make generation deterministic.
497
+ latents (`torch.FloatTensor`, *optional*):
498
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
499
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
500
+ tensor will be generated by sampling using the supplied random `generator`.
501
+ prompt_embeds (`torch.FloatTensor`, *optional*):
502
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
503
+ provided, text embeddings will be generated from `prompt` input argument.
504
+ prompt_attention_mask (`torch.Tensor`, *optional*):
505
+ Pre-generated attention mask for text embeddings.
506
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
507
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
508
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
509
+ argument.
510
+ negative_prompt_attention_mask (`torch.Tensor`, *optional*):
511
+ Pre-generated attention mask for negative text embeddings.
512
+ output_type (`str`, *optional*, defaults to `"pil"`):
513
+ The output format of the generate image. Choose between
514
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
515
+ return_dict (`bool`, *optional*, defaults to `True`):
516
+ Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
517
+ of a plain tuple.
518
+ attention_kwargs (`dict`, *optional*):
519
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
520
+ `self.processor` in
521
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
522
+ callback_on_step_end (`Callable`, *optional*):
523
+ A function that calls at the end of each denoising steps during the inference. The function is called
524
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
525
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
526
+ `callback_on_step_end_tensor_inputs`.
527
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
528
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
529
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
530
+ `._callback_tensor_inputs` attribute of your pipeline class.
531
+ max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`.
532
+
533
+ Examples:
534
+
535
+ Returns: [`~pipelines.ImagePipelineOutput`] or `tuple`:
536
+ If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is returned
537
+ where the first element is a list with the generated images.
538
+ """
539
+ # 1. Check inputs. Raise error if not correct
540
+ height = height or self.transformer.config.sample_size * self.vae_scale_factor
541
+ width = width or self.transformer.config.sample_size * self.vae_scale_factor
542
+
543
+ self.check_inputs(
544
+ prompt,
545
+ height,
546
+ width,
547
+ negative_prompt,
548
+ prompt_embeds,
549
+ negative_prompt_embeds,
550
+ prompt_attention_mask,
551
+ negative_prompt_attention_mask,
552
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
553
+ )
554
+
555
+ self._guidance_scale = guidance_scale
556
+ self._attention_kwargs = attention_kwargs
557
+
558
+ # 2. Determine batch size.
559
+ if prompt is not None and isinstance(prompt, str):
560
+ batch_size = 1
561
+ elif prompt is not None and isinstance(prompt, list):
562
+ batch_size = len(prompt)
563
+ else:
564
+ batch_size = prompt_embeds.shape[0]
565
+
566
+ device = self._execution_device
567
+ lora_scale = self.attention_kwargs.get("scale", None) if self.attention_kwargs is not None else None
568
+
569
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
570
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
571
+ # corresponds to doing no classifier free guidance.
572
+ do_classifier_free_guidance = guidance_scale > 1.0
573
+
574
+ # 3. Encode input prompt
575
+ (
576
+ prompt_embeds,
577
+ prompt_attention_mask,
578
+ negative_prompt_embeds,
579
+ negative_prompt_attention_mask,
580
+ ) = self.encode_prompt(
581
+ prompt=prompt,
582
+ negative_prompt=negative_prompt,
583
+ do_classifier_free_guidance=do_classifier_free_guidance,
584
+ num_images_per_prompt=num_images_per_prompt,
585
+ device=device,
586
+ prompt_embeds=prompt_embeds,
587
+ negative_prompt_embeds=negative_prompt_embeds,
588
+ prompt_attention_mask=prompt_attention_mask,
589
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
590
+ max_sequence_length=max_sequence_length,
591
+ lora_scale=lora_scale,
592
+ )
593
+ if do_classifier_free_guidance:
594
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
595
+
596
+ # 4. Prepare timesteps
597
+
598
+ # sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
599
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas)
600
+
601
+ # 5. Prepare latents.
602
+ latent_channels = self.transformer.config.in_channels
603
+ latents = self.prepare_latents(
604
+ batch_size * num_images_per_prompt,
605
+ latent_channels,
606
+ height,
607
+ width,
608
+ prompt_embeds.dtype,
609
+ device,
610
+ generator,
611
+ latents,
612
+ )
613
+
614
+ # 6. Denoising loop
615
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
616
+ self._num_timesteps = len(timesteps)
617
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
618
+ for i, t in enumerate(timesteps):
619
+ # expand the latents if we are doing classifier free guidance
620
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
621
+
622
+ # aura use timestep value between 0 and 1, with t=1 as noise and t=0 as the image
623
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
624
+ timestep = torch.tensor([t / 1000]).expand(latent_model_input.shape[0])
625
+ timestep = timestep.to(latents.device, dtype=latents.dtype)
626
+
627
+ # predict noise model_output
628
+ noise_pred = self.transformer(
629
+ latent_model_input,
630
+ encoder_hidden_states=prompt_embeds,
631
+ timestep=timestep,
632
+ return_dict=False,
633
+ attention_kwargs=self.attention_kwargs,
634
+ )[0]
635
+
636
+ # perform guidance
637
+ if do_classifier_free_guidance:
638
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
639
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
640
+
641
+ # compute the previous noisy sample x_t -> x_t-1
642
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
643
+
644
+ if callback_on_step_end is not None:
645
+ callback_kwargs = {}
646
+ for k in callback_on_step_end_tensor_inputs:
647
+ callback_kwargs[k] = locals()[k]
648
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
649
+
650
+ latents = callback_outputs.pop("latents", latents)
651
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
652
+
653
+ # call the callback, if provided
654
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
655
+ progress_bar.update()
656
+
657
+ if XLA_AVAILABLE:
658
+ xm.mark_step()
659
+
660
+ if output_type == "latent":
661
+ image = latents
662
+ else:
663
+ # make sure the VAE is in float32 mode, as it overflows in float16
664
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
665
+ if needs_upcasting:
666
+ self.upcast_vae()
667
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
668
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
669
+ image = self.image_processor.postprocess(image, output_type=output_type)
670
+
671
+ # Offload all models
672
+ self.maybe_free_model_hooks()
673
+
674
+ if not return_dict:
675
+ return (image,)
676
+
677
+ return ImagePipelineOutput(images=image)
pythonProject/diffusers-main/build/lib/diffusers/pipelines/blip_diffusion/__init__.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import List, Optional, Union
3
+
4
+ import numpy as np
5
+ import PIL
6
+ from PIL import Image
7
+
8
+ from ...utils import OptionalDependencyNotAvailable, is_torch_available, is_transformers_available
9
+
10
+
11
+ try:
12
+ if not (is_transformers_available() and is_torch_available()):
13
+ raise OptionalDependencyNotAvailable()
14
+ except OptionalDependencyNotAvailable:
15
+ from ...utils.dummy_torch_and_transformers_objects import ShapEPipeline
16
+ else:
17
+ from .blip_image_processing import BlipImageProcessor
18
+ from .modeling_blip2 import Blip2QFormerModel
19
+ from .modeling_ctx_clip import ContextCLIPTextModel
20
+ from .pipeline_blip_diffusion import BlipDiffusionPipeline
pythonProject/diffusers-main/build/lib/diffusers/pipelines/blip_diffusion/blip_image_processing.py ADDED
@@ -0,0 +1,318 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Image processor class for BLIP."""
16
+
17
+ from typing import Dict, List, Optional, Union
18
+
19
+ import numpy as np
20
+ import torch
21
+ from transformers.image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
22
+ from transformers.image_transforms import convert_to_rgb, resize, to_channel_dimension_format
23
+ from transformers.image_utils import (
24
+ OPENAI_CLIP_MEAN,
25
+ OPENAI_CLIP_STD,
26
+ ChannelDimension,
27
+ ImageInput,
28
+ PILImageResampling,
29
+ infer_channel_dimension_format,
30
+ is_scaled_image,
31
+ make_list_of_images,
32
+ to_numpy_array,
33
+ valid_images,
34
+ )
35
+ from transformers.utils import TensorType, is_vision_available, logging
36
+
37
+ from diffusers.utils import numpy_to_pil
38
+
39
+
40
+ if is_vision_available():
41
+ import PIL.Image
42
+
43
+
44
+ logger = logging.get_logger(__name__)
45
+
46
+
47
+ # We needed some extra functions on top of the ones in transformers.image_processing_utils.BaseImageProcessor, namely center crop
48
+ # Copy-pasted from transformers.models.blip.image_processing_blip.BlipImageProcessor
49
+ class BlipImageProcessor(BaseImageProcessor):
50
+ r"""
51
+ Constructs a BLIP image processor.
52
+
53
+ Args:
54
+ do_resize (`bool`, *optional*, defaults to `True`):
55
+ Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by the
56
+ `do_resize` parameter in the `preprocess` method.
57
+ size (`dict`, *optional*, defaults to `{"height": 384, "width": 384}`):
58
+ Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess`
59
+ method.
60
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
61
+ Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`. Can be
62
+ overridden by the `resample` parameter in the `preprocess` method.
63
+ do_rescale (`bool`, *optional*, defaults to `True`):
64
+ Wwhether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the
65
+ `do_rescale` parameter in the `preprocess` method.
66
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
67
+ Scale factor to use if rescaling the image. Only has an effect if `do_rescale` is set to `True`. Can be
68
+ overridden by the `rescale_factor` parameter in the `preprocess` method.
69
+ do_normalize (`bool`, *optional*, defaults to `True`):
70
+ Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
71
+ method. Can be overridden by the `do_normalize` parameter in the `preprocess` method.
72
+ image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
73
+ Mean to use if normalizing the image. This is a float or list of floats the length of the number of
74
+ channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. Can be
75
+ overridden by the `image_mean` parameter in the `preprocess` method.
76
+ image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
77
+ Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
78
+ number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
79
+ Can be overridden by the `image_std` parameter in the `preprocess` method.
80
+ do_convert_rgb (`bool`, *optional*, defaults to `True`):
81
+ Whether to convert the image to RGB.
82
+ """
83
+
84
+ model_input_names = ["pixel_values"]
85
+
86
+ def __init__(
87
+ self,
88
+ do_resize: bool = True,
89
+ size: Dict[str, int] = None,
90
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
91
+ do_rescale: bool = True,
92
+ rescale_factor: Union[int, float] = 1 / 255,
93
+ do_normalize: bool = True,
94
+ image_mean: Optional[Union[float, List[float]]] = None,
95
+ image_std: Optional[Union[float, List[float]]] = None,
96
+ do_convert_rgb: bool = True,
97
+ do_center_crop: bool = True,
98
+ **kwargs,
99
+ ) -> None:
100
+ super().__init__(**kwargs)
101
+ size = size if size is not None else {"height": 224, "width": 224}
102
+ size = get_size_dict(size, default_to_square=True)
103
+
104
+ self.do_resize = do_resize
105
+ self.size = size
106
+ self.resample = resample
107
+ self.do_rescale = do_rescale
108
+ self.rescale_factor = rescale_factor
109
+ self.do_normalize = do_normalize
110
+ self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
111
+ self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
112
+ self.do_convert_rgb = do_convert_rgb
113
+ self.do_center_crop = do_center_crop
114
+
115
+ # Copy-pasted from transformers.models.vit.image_processing_vit.ViTImageProcessor.resize with PILImageResampling.BILINEAR->PILImageResampling.BICUBIC
116
+ def resize(
117
+ self,
118
+ image: np.ndarray,
119
+ size: Dict[str, int],
120
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
121
+ data_format: Optional[Union[str, ChannelDimension]] = None,
122
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
123
+ **kwargs,
124
+ ) -> np.ndarray:
125
+ """
126
+ Resize an image to `(size["height"], size["width"])`.
127
+
128
+ Args:
129
+ image (`np.ndarray`):
130
+ Image to resize.
131
+ size (`Dict[str, int]`):
132
+ Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image.
133
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
134
+ `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BICUBIC`.
135
+ data_format (`ChannelDimension` or `str`, *optional*):
136
+ The channel dimension format for the output image. If unset, the channel dimension format of the input
137
+ image is used. Can be one of:
138
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
139
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
140
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
141
+ input_data_format (`ChannelDimension` or `str`, *optional*):
142
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
143
+ from the input image. Can be one of:
144
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
145
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
146
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
147
+
148
+ Returns:
149
+ `np.ndarray`: The resized image.
150
+ """
151
+ size = get_size_dict(size)
152
+ if "height" not in size or "width" not in size:
153
+ raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()}")
154
+ output_size = (size["height"], size["width"])
155
+ return resize(
156
+ image,
157
+ size=output_size,
158
+ resample=resample,
159
+ data_format=data_format,
160
+ input_data_format=input_data_format,
161
+ **kwargs,
162
+ )
163
+
164
+ def preprocess(
165
+ self,
166
+ images: ImageInput,
167
+ do_resize: Optional[bool] = None,
168
+ size: Optional[Dict[str, int]] = None,
169
+ resample: PILImageResampling = None,
170
+ do_rescale: Optional[bool] = None,
171
+ do_center_crop: Optional[bool] = None,
172
+ rescale_factor: Optional[float] = None,
173
+ do_normalize: Optional[bool] = None,
174
+ image_mean: Optional[Union[float, List[float]]] = None,
175
+ image_std: Optional[Union[float, List[float]]] = None,
176
+ return_tensors: Optional[Union[str, TensorType]] = None,
177
+ do_convert_rgb: bool = None,
178
+ data_format: ChannelDimension = ChannelDimension.FIRST,
179
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
180
+ **kwargs,
181
+ ) -> PIL.Image.Image:
182
+ """
183
+ Preprocess an image or batch of images.
184
+
185
+ Args:
186
+ images (`ImageInput`):
187
+ Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
188
+ passing in images with pixel values between 0 and 1, set `do_rescale=False`.
189
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
190
+ Whether to resize the image.
191
+ size (`Dict[str, int]`, *optional*, defaults to `self.size`):
192
+ Controls the size of the image after `resize`. The shortest edge of the image is resized to
193
+ `size["shortest_edge"]` whilst preserving the aspect ratio. If the longest edge of this resized image
194
+ is > `int(size["shortest_edge"] * (1333 / 800))`, then the image is resized again to make the longest
195
+ edge equal to `int(size["shortest_edge"] * (1333 / 800))`.
196
+ resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
197
+ Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`.
198
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
199
+ Whether to rescale the image values between [0 - 1].
200
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
201
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
202
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
203
+ Whether to normalize the image.
204
+ image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
205
+ Image mean to normalize the image by if `do_normalize` is set to `True`.
206
+ image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
207
+ Image standard deviation to normalize the image by if `do_normalize` is set to `True`.
208
+ do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
209
+ Whether to convert the image to RGB.
210
+ return_tensors (`str` or `TensorType`, *optional*):
211
+ The type of tensors to return. Can be one of:
212
+ - Unset: Return a list of `np.ndarray`.
213
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
214
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
215
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
216
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
217
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
218
+ The channel dimension format for the output image. Can be one of:
219
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
220
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
221
+ - Unset: Use the channel dimension format of the input image.
222
+ input_data_format (`ChannelDimension` or `str`, *optional*):
223
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
224
+ from the input image. Can be one of:
225
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
226
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
227
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
228
+ """
229
+ do_resize = do_resize if do_resize is not None else self.do_resize
230
+ resample = resample if resample is not None else self.resample
231
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
232
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
233
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
234
+ image_mean = image_mean if image_mean is not None else self.image_mean
235
+ image_std = image_std if image_std is not None else self.image_std
236
+ do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
237
+ do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
238
+
239
+ size = size if size is not None else self.size
240
+ size = get_size_dict(size, default_to_square=False)
241
+ images = make_list_of_images(images)
242
+
243
+ if not valid_images(images):
244
+ raise ValueError(
245
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
246
+ "torch.Tensor, tf.Tensor or jax.ndarray."
247
+ )
248
+
249
+ if do_resize and size is None or resample is None:
250
+ raise ValueError("Size and resample must be specified if do_resize is True.")
251
+
252
+ if do_rescale and rescale_factor is None:
253
+ raise ValueError("Rescale factor must be specified if do_rescale is True.")
254
+
255
+ if do_normalize and (image_mean is None or image_std is None):
256
+ raise ValueError("Image mean and std must be specified if do_normalize is True.")
257
+
258
+ # PIL RGBA images are converted to RGB
259
+ if do_convert_rgb:
260
+ images = [convert_to_rgb(image) for image in images]
261
+
262
+ # All transformations expect numpy arrays.
263
+ images = [to_numpy_array(image) for image in images]
264
+
265
+ if is_scaled_image(images[0]) and do_rescale:
266
+ logger.warning_once(
267
+ "It looks like you are trying to rescale already rescaled images. If the input"
268
+ " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
269
+ )
270
+ if input_data_format is None:
271
+ # We assume that all images have the same channel dimension format.
272
+ input_data_format = infer_channel_dimension_format(images[0])
273
+
274
+ if do_resize:
275
+ images = [
276
+ self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
277
+ for image in images
278
+ ]
279
+
280
+ if do_rescale:
281
+ images = [
282
+ self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
283
+ for image in images
284
+ ]
285
+ if do_normalize:
286
+ images = [
287
+ self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
288
+ for image in images
289
+ ]
290
+ if do_center_crop:
291
+ images = [self.center_crop(image, size, input_data_format=input_data_format) for image in images]
292
+
293
+ images = [
294
+ to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
295
+ ]
296
+
297
+ encoded_outputs = BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors)
298
+ return encoded_outputs
299
+
300
+ # Follows diffusers.VaeImageProcessor.postprocess
301
+ def postprocess(self, sample: torch.Tensor, output_type: str = "pil"):
302
+ if output_type not in ["pt", "np", "pil"]:
303
+ raise ValueError(
304
+ f"output_type={output_type} is not supported. Make sure to choose one of ['pt', 'np', or 'pil']"
305
+ )
306
+
307
+ # Equivalent to diffusers.VaeImageProcessor.denormalize
308
+ sample = (sample / 2 + 0.5).clamp(0, 1)
309
+ if output_type == "pt":
310
+ return sample
311
+
312
+ # Equivalent to diffusers.VaeImageProcessor.pt_to_numpy
313
+ sample = sample.cpu().permute(0, 2, 3, 1).numpy()
314
+ if output_type == "np":
315
+ return sample
316
+ # Output_type must be 'pil'
317
+ sample = numpy_to_pil(sample)
318
+ return sample
pythonProject/diffusers-main/build/lib/diffusers/pipelines/blip_diffusion/modeling_blip2.py ADDED
@@ -0,0 +1,639 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Optional, Tuple, Union
15
+
16
+ import torch
17
+ import torch.utils.checkpoint
18
+ from torch import nn
19
+ from transformers import BertTokenizer
20
+ from transformers.activations import QuickGELUActivation as QuickGELU
21
+ from transformers.modeling_outputs import (
22
+ BaseModelOutputWithPastAndCrossAttentions,
23
+ BaseModelOutputWithPooling,
24
+ BaseModelOutputWithPoolingAndCrossAttentions,
25
+ )
26
+ from transformers.models.blip_2.configuration_blip_2 import Blip2Config, Blip2VisionConfig
27
+ from transformers.models.blip_2.modeling_blip_2 import (
28
+ Blip2Encoder,
29
+ Blip2PreTrainedModel,
30
+ Blip2QFormerAttention,
31
+ Blip2QFormerIntermediate,
32
+ Blip2QFormerOutput,
33
+ )
34
+ from transformers.pytorch_utils import apply_chunking_to_forward
35
+ from transformers.utils import (
36
+ logging,
37
+ replace_return_docstrings,
38
+ )
39
+
40
+
41
+ logger = logging.get_logger(__name__)
42
+
43
+
44
+ # There is an implementation of Blip2 in `transformers` : https://github.com/huggingface/transformers/blob/main/src/transformers/models/blip_2/modeling_blip_2.py.
45
+ # But it doesn't support getting multimodal embeddings. So, this module can be
46
+ # replaced with a future `transformers` version supports that.
47
class Blip2TextEmbeddings(nn.Module):
    """Construct the embeddings from word and position embeddings.

    Optionally prepends learned `query_embeds` to the text-token embeddings before the
    LayerNorm + dropout, as used by the Q-Former.
    """

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")

        self.config = config

    def forward(
        self,
        input_ids=None,
        position_ids=None,
        query_embeds=None,
        past_key_values_length=0,
    ):
        """Embed `input_ids` (word + absolute position embeddings), prepend `query_embeds`
        (repeated over the batch) when given, then apply LayerNorm and dropout.

        Raises:
            ValueError: if both `input_ids` and `query_embeds` are None.
        """
        if input_ids is not None:
            seq_length = input_ids.size()[1]
        else:
            seq_length = 0

        if position_ids is None:
            # Offset by past_key_values_length so cached decoding keeps absolute positions.
            position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length].clone()

        if input_ids is not None:
            embeddings = self.word_embeddings(input_ids)
            if self.position_embedding_type == "absolute":
                position_embeddings = self.position_embeddings(position_ids)
                embeddings = embeddings + position_embeddings

            if query_embeds is not None:
                batch_size = embeddings.shape[0]
                # repeat the query embeddings for batch size
                query_embeds = query_embeds.repeat(batch_size, 1, 1)
                embeddings = torch.cat((query_embeds, embeddings), dim=1)
        else:
            if query_embeds is None:
                # Previously this fell through to `None.to(...)` and raised an opaque AttributeError.
                raise ValueError("Either `input_ids` or `query_embeds` must be provided.")
            embeddings = query_embeds

        # Bug fix: the original unconditionally did `embeddings.to(query_embeds.dtype)`, which
        # crashed with AttributeError whenever only `input_ids` was passed. Cast only when query
        # embeddings (whose dtype is authoritative, e.g. fp16) are actually present.
        if query_embeds is not None:
            embeddings = embeddings.to(query_embeds.dtype)
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings
98
+
99
+
100
# Copy-pasted from transformers.models.blip.modeling_blip.BlipVisionEmbeddings with Blip->Blip2
class Blip2VisionEmbeddings(nn.Module):
    """Patchify pixel values with a strided conv, prepend a learned class token, and add
    learned position embeddings."""

    def __init__(self, config: Blip2VisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.class_embedding = nn.Parameter(torch.randn(1, 1, self.embed_dim))

        self.patch_embedding = nn.Conv2d(
            in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size, bias=False
        )

        self.num_patches = (self.image_size // self.patch_size) ** 2
        self.num_positions = self.num_patches + 1

        self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim))

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Return token embeddings of shape (batch, num_patches + 1, embed_dim)."""
        num_images = pixel_values.shape[0]
        weight_dtype = self.patch_embedding.weight.dtype

        # (B, 3, H, W) -> (B, embed_dim, grid, grid) -> (B, grid*grid, embed_dim)
        patch_tokens = self.patch_embedding(pixel_values.to(dtype=weight_dtype))
        patch_tokens = patch_tokens.flatten(2).transpose(1, 2)

        # Prepend the class token, then add (a prefix of) the position embeddings.
        cls_token = self.class_embedding.expand(num_images, 1, -1).to(weight_dtype)
        tokens = torch.cat([cls_token, patch_tokens], dim=1)
        tokens = tokens + self.position_embedding[:, : tokens.size(1), :].to(weight_dtype)
        return tokens
130
+
131
+
132
# The Qformer encoder, which takes the visual embeddings, and the text input, to get multimodal embeddings
class Blip2QFormerEncoder(nn.Module):
    """Stack of `Blip2QFormerLayer`s run sequentially over the (query + text) hidden states,
    cross-attending to `encoder_hidden_states` (the image features) in the layers that have
    cross-attention."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        # One layer per config.num_hidden_layers; the layer index determines whether it
        # carries a cross-attention sub-module (see Blip2QFormerLayer).
        self.layer = nn.ModuleList(
            [Blip2QFormerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
        query_length=0,
    ):
        """Run all layers; optionally accumulate per-layer hidden states, attention maps and
        key/value caches.

        Returns a `BaseModelOutputWithPastAndCrossAttentions`, or (when `return_dict=False`)
        a tuple of its non-None fields in the same order.
        """
        # HF convention: accumulators are tuples when requested, None otherwise.
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions else None

        next_decoder_cache = () if use_cache else None

        for i in range(self.config.num_hidden_layers):
            layer_module = self.layer[i]
            if output_hidden_states:
                # Record the hidden states *entering* each layer (final state appended after loop).
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None
            past_key_value = past_key_values[i] if past_key_values is not None else None

            if getattr(self.config, "gradient_checkpointing", False) and torch.is_grad_enabled():
                if use_cache:
                    # Recomputing activations during backward is incompatible with returning a KV cache.
                    logger.warning(
                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                    )
                    use_cache = False

                layer_outputs = self._gradient_checkpointing_func(
                    layer_module,
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                    query_length,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                    query_length,
                )

            hidden_states = layer_outputs[0]
            if use_cache:
                # The layer appends its present key/value as the last element of its outputs.
                next_decoder_cache += (layer_outputs[-1],)
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)
                if layer_module.has_cross_attention:
                    # Cross-attention maps exist only for layers that have the sub-module.
                    all_cross_attentions = all_cross_attentions + (layer_outputs[2],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            # Tuple form drops any accumulator the caller did not request.
            return tuple(
                v
                for v in [
                    hidden_states,
                    next_decoder_cache,
                    all_hidden_states,
                    all_self_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )
230
+
231
+
232
# The layers making up the Qformer encoder
class Blip2QFormerLayer(nn.Module):
    """One Q-Former block: self-attention over the full (query + text) sequence, optional
    cross-attention for the query tokens only, then separate feed-forward paths for the
    query tokens and the text tokens."""

    def __init__(self, config, layer_idx):
        super().__init__()
        # A positive chunk size lets apply_chunking_to_forward split the feed-forward
        # computation along the sequence dimension (seq_len_dim) to save memory.
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = Blip2QFormerAttention(config)

        self.layer_idx = layer_idx

        # Cross-attention to the image features is only inserted every
        # `cross_attention_frequency` layers (including layer 0).
        if layer_idx % config.cross_attention_frequency == 0:
            self.crossattention = Blip2QFormerAttention(config, is_cross_attention=True)
            self.has_cross_attention = True
        else:
            self.has_cross_attention = False

        # Query and text positions use distinct feed-forward weights.
        self.intermediate = Blip2QFormerIntermediate(config)
        self.intermediate_query = Blip2QFormerIntermediate(config)
        self.output_query = Blip2QFormerOutput(config)
        self.output = Blip2QFormerOutput(config)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
        query_length=0,
    ):
        """Returns `(layer_output, [self_attn_probs, [cross_attn_probs]], present_key_value)`.

        The first `query_length` positions of `hidden_states` are treated as query tokens;
        only they go through cross-attention and the query feed-forward path.
        """
        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
            past_key_value=self_attn_past_key_value,
        )
        attention_output = self_attention_outputs[0]
        # Middle elements (if any) are the attention probabilities.
        outputs = self_attention_outputs[1:-1]

        present_key_value = self_attention_outputs[-1]

        if query_length > 0:
            query_attention_output = attention_output[:, :query_length, :]

            if self.has_cross_attention:
                if encoder_hidden_states is None:
                    raise ValueError("encoder_hidden_states must be given for cross-attention layers")
                cross_attention_outputs = self.crossattention(
                    query_attention_output,
                    attention_mask,
                    head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    output_attentions=output_attentions,
                )
                query_attention_output = cross_attention_outputs[0]
                # add cross attentions if we output attention weights
                outputs = outputs + cross_attention_outputs[1:-1]

            layer_output = apply_chunking_to_forward(
                self.feed_forward_chunk_query,
                self.chunk_size_feed_forward,
                self.seq_len_dim,
                query_attention_output,
            )

            if attention_output.shape[1] > query_length:
                # Text tokens (everything after the queries) go through the text FFN path,
                # then are re-concatenated behind the processed query tokens.
                layer_output_text = apply_chunking_to_forward(
                    self.feed_forward_chunk,
                    self.chunk_size_feed_forward,
                    self.seq_len_dim,
                    attention_output[:, query_length:, :],
                )
                layer_output = torch.cat([layer_output, layer_output_text], dim=1)
        else:
            # No query tokens: the whole sequence uses the text FFN path.
            layer_output = apply_chunking_to_forward(
                self.feed_forward_chunk,
                self.chunk_size_feed_forward,
                self.seq_len_dim,
                attention_output,
            )
        outputs = (layer_output,) + outputs

        outputs = outputs + (present_key_value,)

        return outputs

    def feed_forward_chunk(self, attention_output):
        # Feed-forward used for text positions.
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output

    def feed_forward_chunk_query(self, attention_output):
        # Feed-forward used for query positions (separate weights from the text path).
        intermediate_output = self.intermediate_query(attention_output)
        layer_output = self.output_query(intermediate_output, attention_output)
        return layer_output
333
+
334
+
335
# ProjLayer used to project the multimodal Blip2 embeddings to be used in the text encoder
class ProjLayer(nn.Module):
    """Residual MLP projection: LayerNorm -> Dense1 -> QuickGELU -> Dense2 -> Dropout,
    with the (un-normalized) input added back."""

    def __init__(self, in_dim, out_dim, hidden_dim, drop_p=0.1, eps=1e-12):
        super().__init__()

        # Dense1 -> Act -> Dense2 -> Drop -> Res -> Norm
        self.dense1 = nn.Linear(in_dim, hidden_dim)
        self.act_fn = QuickGELU()
        self.dense2 = nn.Linear(hidden_dim, out_dim)
        self.dropout = nn.Dropout(drop_p)

        self.LayerNorm = nn.LayerNorm(out_dim, eps=eps)

    def forward(self, x):
        # Keep the raw input for the residual connection.
        residual = x

        normed = self.LayerNorm(x)
        hidden = self.act_fn(self.dense1(normed))
        projected = self.dropout(self.dense2(hidden))
        return projected + residual
355
+
356
+
357
# Copy-pasted from transformers.models.blip.modeling_blip.BlipVisionModel with Blip->Blip2, BLIP->BLIP_2
class Blip2VisionModel(Blip2PreTrainedModel):
    """Vision tower: patch/class embeddings -> pre-LayerNorm -> transformer encoder ->
    post-LayerNorm, with the first token's post-normalized state used as the pooled output."""

    main_input_name = "pixel_values"
    config_class = Blip2VisionConfig

    def __init__(self, config: Blip2VisionConfig):
        super().__init__(config)
        self.config = config
        embed_dim = config.hidden_size
        self.embeddings = Blip2VisionEmbeddings(config)
        self.pre_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
        self.encoder = Blip2Encoder(config)
        self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

        self.post_init()

    @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=Blip2VisionConfig)
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        r"""
        Returns:

        """
        # Fall back to config-level flags when the caller does not override them.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        hidden_states = self.embeddings(pixel_values)
        hidden_states = self.pre_layernorm(hidden_states)
        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        last_hidden_state = encoder_outputs[0]
        last_hidden_state = self.post_layernorm(last_hidden_state)

        # Pooled output = post-normalized class-token (position 0) state.
        # NOTE: post_layernorm is applied twice to that token (once via last_hidden_state,
        # once here) — kept as-is to match the checkpoint's original behavior.
        pooled_output = last_hidden_state[:, 0, :]
        pooled_output = self.post_layernorm(pooled_output)

        if not return_dict:
            return (last_hidden_state, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )

    def get_input_embeddings(self):
        # Returns the embeddings module itself (not a token embedding table).
        return self.embeddings
420
+
421
+
422
# Qformer model, used to get multimodal embeddings from the text and image inputs
class Blip2QFormerModel(Blip2PreTrainedModel):
    """
    Querying Transformer (Q-Former), used in BLIP-2.

    Tokenizes the raw `text_input` internally, prepends learned query tokens, encodes the
    image with a frozen vision tower, and fuses the two through the Q-Former encoder.
    """

    def __init__(self, config: Blip2Config):
        super().__init__(config)
        self.config = config
        self.embeddings = Blip2TextEmbeddings(config.qformer_config)
        self.visual_encoder = Blip2VisionModel(config.vision_config)
        # Learned query tokens shared across the batch; zero-initialized (weights come from
        # the pretrained checkpoint).
        self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size))
        # Loads a BERT tokenizer (may download on first use) unless the config names one.
        if not hasattr(config, "tokenizer") or config.tokenizer is None:
            self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", truncation_side="right")
        else:
            self.tokenizer = BertTokenizer.from_pretrained(config.tokenizer, truncation_side="right")
        self.tokenizer.add_special_tokens({"bos_token": "[DEC]"})
        self.proj_layer = ProjLayer(
            in_dim=config.qformer_config.hidden_size,
            out_dim=config.qformer_config.hidden_size,
            hidden_dim=config.qformer_config.hidden_size * 4,
            drop_p=0.1,
            eps=1e-12,
        )

        self.encoder = Blip2QFormerEncoder(config.qformer_config)

        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    def get_extended_attention_mask(
        self,
        attention_mask: torch.Tensor,
        input_shape: Tuple[int],
        device: torch.device,
        has_query: bool = False,  # unused in this implementation; kept for API compatibility
    ) -> torch.Tensor:
        """
        Makes broadcastable attention and causal masks so that future and masked tokens are ignored.

        Arguments:
            attention_mask (`torch.Tensor`):
                Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
            input_shape (`Tuple[int]`):
                The shape of the input to the model.
            device (`torch.device`):
                The device of the input to the model.

        Returns:
            `torch.Tensor` The extended attention mask, with a the same dtype as `attention_mask.dtype`.
        """
        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        if attention_mask.dim() == 3:
            extended_attention_mask = attention_mask[:, None, :, :]
        elif attention_mask.dim() == 2:
            # Provided a padding mask of dimensions [batch_size, seq_length]
            # - the model is an encoder, so make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
            extended_attention_mask = attention_mask[:, None, None, :]
        else:
            raise ValueError(
                "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
                    input_shape, attention_mask.shape
                )
            )

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
        return extended_attention_mask

    def forward(
        self,
        text_input=None,
        image_input=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        encoder_hidden_states  (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, `optional`):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
            the model is configured as a decoder.
        encoder_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, `optional`):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
            the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
        past_key_values (`tuple(tuple(torch.Tensor))` of length `config.n_layers` with each tuple having 4 tensors of:
            shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): Contains precomputed key and
            value hidden states of the attention blocks. Can be used to speed up decoding. If `past_key_values` are
            used, the user can optionally input only the last `decoder_input_ids` (those that don't have their past key
            value states given to this model) of shape `(batch_size, 1)` instead of all `decoder_input_ids` of shape
            `(batch_size, sequence_length)`.
        use_cache (`bool`, `optional`):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        """

        # Tokenize the raw text here (unusual for an HF model's forward, but this module
        # owns its tokenizer) and build a mask covering query tokens + text tokens.
        text = self.tokenizer(text_input, return_tensors="pt", padding=True)
        text = text.to(self.device)
        input_ids = text.input_ids
        batch_size = input_ids.shape[0]
        query_atts = torch.ones((batch_size, self.query_tokens.size()[1]), dtype=torch.long).to(self.device)
        attention_mask = torch.cat([query_atts, text.attention_mask], dim=1)

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # past_key_values_length
        past_key_values_length = (
            past_key_values[0][0].shape[2] - self.config.query_length if past_key_values is not None else 0
        )

        query_length = self.query_tokens.shape[1]

        embedding_output = self.embeddings(
            input_ids=input_ids,
            query_embeds=self.query_tokens,
            past_key_values_length=past_key_values_length,
        )

        # embedding_output = self.layernorm(query_embeds)
        # embedding_output = self.dropout(embedding_output)

        input_shape = embedding_output.size()[:-1]
        batch_size, seq_length = input_shape
        device = embedding_output.device

        # Encode the image with the (frozen) vision tower; its last hidden state becomes
        # the cross-attention memory, overriding any `encoder_hidden_states` argument.
        image_embeds_frozen = self.visual_encoder(image_input).last_hidden_state
        # image_embeds_frozen = torch.ones_like(image_embeds_frozen)
        encoder_hidden_states = image_embeds_frozen

        # Defensive default; attention_mask was constructed above, so this branch is not
        # reached in practice.
        if attention_mask is None:
            attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, device)

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if encoder_hidden_states is not None:
            if isinstance(encoder_hidden_states, list):
                encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
            else:
                encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)

            if isinstance(encoder_attention_mask, list):
                encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
            elif encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
                encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
            else:
                encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_extended_attention_mask = None

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.qformer_config.num_hidden_layers)

        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            query_length=query_length,
        )
        sequence_output = encoder_outputs[0]
        pooled_output = sequence_output[:, 0, :]

        # Non-dict path returns only the *projected query-token* states, not the usual
        # (sequence_output, pooled_output, ...) tuple — callers rely on this shape.
        if not return_dict:
            return self.proj_layer(sequence_output[:, :query_length, :])

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            past_key_values=encoder_outputs.past_key_values,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            cross_attentions=encoder_outputs.cross_attentions,
        )
pythonProject/diffusers-main/build/lib/diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py ADDED
@@ -0,0 +1,223 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 Salesforce.com, inc.
2
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ from typing import Optional, Tuple, Union
16
+
17
+ import torch
18
+ from torch import nn
19
+ from transformers import CLIPPreTrainedModel
20
+ from transformers.modeling_outputs import BaseModelOutputWithPooling
21
+ from transformers.models.clip.configuration_clip import CLIPTextConfig
22
+ from transformers.models.clip.modeling_clip import CLIPEncoder
23
+
24
+
25
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
26
+ """
27
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
28
+ """
29
+ bsz, src_len = mask.size()
30
+ tgt_len = tgt_len if tgt_len is not None else src_len
31
+
32
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
33
+
34
+ inverted_mask = 1.0 - expanded_mask
35
+
36
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
37
+
38
+
39
+ # This is a modified version of the CLIPTextModel from transformers.models.clip.modeling_clip
40
+ # Which allows for an extra input of "context embeddings", which are the query embeddings used in Qformer
41
+ # They pass through the clip model, along with the text embeddings, and interact with them using self attention
42
class ContextCLIPTextModel(CLIPPreTrainedModel):
    """CLIP text model extended to accept extra "context" (Qformer query) embeddings.

    Thin wrapper around [`ContextCLIPTextTransformer`], which splices `ctx_embeddings`
    into the token embedding sequence at `ctx_begin_pos` before running the CLIP encoder.
    """

    config_class = CLIPTextConfig

    _no_split_modules = ["CLIPEncoderLayer"]

    def __init__(self, config: CLIPTextConfig):
        super().__init__(config)
        self.text_model = ContextCLIPTextTransformer(config)
        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        ctx_embeddings: torch.Tensor = None,
        ctx_begin_pos: list = None,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        # Pure delegation: every argument is forwarded unchanged to the transformer.
        return self.text_model(
            ctx_embeddings=ctx_embeddings,
            ctx_begin_pos=ctx_begin_pos,
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
74
+
75
+
76
class ContextCLIPTextTransformer(nn.Module):
    """CLIP text transformer that can splice Qformer context embeddings into the token stream.

    Mirrors `transformers`' CLIPTextTransformer, except the embedding layer inserts
    `ctx_embeddings` at per-sample positions `ctx_begin_pos`, and the causal mask is
    enlarged accordingly.
    """

    def __init__(self, config: CLIPTextConfig):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size
        self.embeddings = ContextCLIPTextEmbeddings(config)
        self.encoder = CLIPEncoder(config)
        self.final_layer_norm = nn.LayerNorm(embed_dim)

    def forward(
        self,
        ctx_embeddings: torch.Tensor,
        ctx_begin_pos: list,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        r"""
        Returns:

        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Unlike the upstream CLIP model, `inputs_embeds` is not accepted here: token ids are required.
        if input_ids is None:
            raise ValueError("You have to specify either input_ids")

        input_shape = input_ids.size()
        input_ids = input_ids.view(-1, input_shape[-1])

        hidden_states = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            ctx_embeddings=ctx_embeddings,
            ctx_begin_pos=ctx_begin_pos,
        )

        bsz, seq_len = input_shape
        if ctx_embeddings is not None:
            # The context tokens were spliced into the sequence, so the causal mask must cover them too.
            seq_len += ctx_embeddings.size(1)
        # CLIP's text model uses causal mask, prepare it here.
        # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
        causal_attention_mask = self._build_causal_attention_mask(bsz, seq_len, hidden_states.dtype).to(
            hidden_states.device
        )
        # expand attention_mask
        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            attention_mask = _expand_mask(attention_mask, hidden_states.dtype)

        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            attention_mask=attention_mask,
            causal_attention_mask=causal_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_state = encoder_outputs[0]
        last_hidden_state = self.final_layer_norm(last_hidden_state)

        # text_embeds.shape = [batch_size, sequence_length, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
        # NOTE(review): the pooled index comes from the *original* input_ids, i.e. it is NOT shifted
        # by the number of inserted context tokens — confirm this offset is intentional.
        pooled_output = last_hidden_state[
            torch.arange(last_hidden_state.shape[0], device=input_ids.device),
            input_ids.to(torch.int).argmax(dim=-1),
        ]

        if not return_dict:
            return (last_hidden_state, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )

    def _build_causal_attention_mask(self, bsz, seq_len, dtype):
        # lazily create causal attention mask, with full attention between the vision tokens
        # pytorch uses additive attention mask; fill with -inf
        mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype)
        mask.fill_(torch.tensor(torch.finfo(dtype).min))
        mask.triu_(1)  # zero out the lower diagonal
        mask = mask.unsqueeze(1)  # expand mask
        return mask
170
+
171
+
172
class ContextCLIPTextEmbeddings(nn.Module):
    """Token + position embeddings with optional context embeddings spliced in per sample."""

    def __init__(self, config: CLIPTextConfig):
        super().__init__()
        embed_dim = config.hidden_size

        self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
        self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)

        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))

    def forward(
        self,
        ctx_embeddings: torch.Tensor,
        ctx_begin_pos: list,
        input_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Embed tokens, insert ``ctx_embeddings`` at each sample's begin position, add positions."""
        ctx_len = 0 if ctx_embeddings is None else ctx_embeddings.shape[1]

        # Total sequence length after the context tokens are inserted.
        if input_ids is not None:
            seq_length = input_ids.shape[-1] + ctx_len
        else:
            seq_length = inputs_embeds.shape[-2] + ctx_len

        if position_ids is None:
            position_ids = self.position_ids[:, :seq_length]

        if inputs_embeds is None:
            inputs_embeds = self.token_embedding(input_ids)

        if ctx_embeddings is not None:
            # Splice each sample's context embeddings in at its own begin position.
            spliced = []
            batch = inputs_embeds.shape[0]
            for sample_idx in range(batch):
                begin = ctx_begin_pos[sample_idx]
                head = inputs_embeds[sample_idx, :begin]
                tail = inputs_embeds[sample_idx, begin:]
                spliced.append(torch.cat([head, ctx_embeddings[sample_idx], tail], dim=0))
            inputs_embeds = torch.stack(spliced, dim=0)

        return inputs_embeds + self.position_embedding(position_ids)
pythonProject/diffusers-main/build/lib/diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py ADDED
@@ -0,0 +1,361 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 Salesforce.com, inc.
2
+ # Copyright 2025 The HuggingFace Team. All rights reserved.#
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import List, Optional, Union
15
+
16
+ import PIL.Image
17
+ import torch
18
+ from transformers import CLIPTokenizer
19
+
20
+ from ...models import AutoencoderKL, UNet2DConditionModel
21
+ from ...schedulers import PNDMScheduler
22
+ from ...utils import (
23
+ is_torch_xla_available,
24
+ logging,
25
+ replace_example_docstring,
26
+ )
27
+ from ...utils.torch_utils import randn_tensor
28
+ from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, ImagePipelineOutput
29
+ from .blip_image_processing import BlipImageProcessor
30
+ from .modeling_blip2 import Blip2QFormerModel
31
+ from .modeling_ctx_clip import ContextCLIPTextModel
32
+
33
+
34
+ if is_torch_xla_available():
35
+ import torch_xla.core.xla_model as xm
36
+
37
+ XLA_AVAILABLE = True
38
+ else:
39
+ XLA_AVAILABLE = False
40
+
41
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
42
+
43
+
44
+ EXAMPLE_DOC_STRING = """
45
+ Examples:
46
+ ```py
47
+ >>> from diffusers.pipelines import BlipDiffusionPipeline
48
+ >>> from diffusers.utils import load_image
49
+ >>> import torch
50
+
51
+ >>> blip_diffusion_pipe = BlipDiffusionPipeline.from_pretrained(
52
+ ... "Salesforce/blipdiffusion", torch_dtype=torch.float16
53
+ ... ).to("cuda")
54
+
55
+
56
+ >>> cond_subject = "dog"
57
+ >>> tgt_subject = "dog"
58
+ >>> text_prompt_input = "swimming underwater"
59
+
60
+ >>> cond_image = load_image(
61
+ ... "https://huggingface.co/datasets/ayushtues/blipdiffusion_images/resolve/main/dog.jpg"
62
+ ... )
63
+ >>> guidance_scale = 7.5
64
+ >>> num_inference_steps = 25
65
+ >>> negative_prompt = "over-exposure, under-exposure, saturated, duplicate, out of frame, lowres, cropped, worst quality, low quality, jpeg artifacts, morbid, mutilated, out of frame, ugly, bad anatomy, bad proportions, deformed, blurry, duplicate"
66
+
67
+
68
+ >>> output = blip_diffusion_pipe(
69
+ ... text_prompt_input,
70
+ ... cond_image,
71
+ ... cond_subject,
72
+ ... tgt_subject,
73
+ ... guidance_scale=guidance_scale,
74
+ ... num_inference_steps=num_inference_steps,
75
+ ... neg_prompt=negative_prompt,
76
+ ... height=512,
77
+ ... width=512,
78
+ ... ).images
79
+ >>> output[0].save("image.png")
80
+ ```
81
+ """
82
+
83
+
84
class BlipDiffusionPipeline(DeprecatedPipelineMixin, DiffusionPipeline):
    """
    Pipeline for Zero-Shot Subject Driven Generation using Blip Diffusion.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        tokenizer ([`CLIPTokenizer`]):
            Tokenizer for the text encoder
        text_encoder ([`ContextCLIPTextModel`]):
            Text encoder to encode the text prompt
        vae ([`AutoencoderKL`]):
            VAE model to map the latents to the image
        unet ([`UNet2DConditionModel`]):
            Conditional U-Net architecture to denoise the image embedding.
        scheduler ([`PNDMScheduler`]):
            A scheduler to be used in combination with `unet` to generate image latents.
        qformer ([`Blip2QFormerModel`]):
            QFormer model to get multi-modal embeddings from the text and image.
        image_processor ([`BlipImageProcessor`]):
            Image Processor to preprocess and postprocess the image.
        ctx_begin_pos (int, `optional`, defaults to 2):
            Position of the context token in the text encoder.
    """

    _last_supported_version = "0.33.1"
    model_cpu_offload_seq = "qformer->text_encoder->unet->vae"

    def __init__(
        self,
        tokenizer: CLIPTokenizer,
        text_encoder: ContextCLIPTextModel,
        vae: AutoencoderKL,
        unet: UNet2DConditionModel,
        scheduler: PNDMScheduler,
        qformer: Blip2QFormerModel,
        image_processor: BlipImageProcessor,
        ctx_begin_pos: int = 2,
        mean: List[float] = None,
        std: List[float] = None,
    ):
        super().__init__()

        self.register_modules(
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            vae=vae,
            unet=unet,
            scheduler=scheduler,
            qformer=qformer,
            image_processor=image_processor,
        )
        # `mean`/`std` are normalization stats passed through to the image processor at call time.
        self.register_to_config(ctx_begin_pos=ctx_begin_pos, mean=mean, std=std)

    def get_query_embeddings(self, input_image, src_subject):
        """Return the Qformer's multi-modal query embeddings for the reference image and source subject text."""
        return self.qformer(image_input=input_image, text_input=src_subject, return_dict=False)

    # from the original Blip Diffusion code, specifies the target subject and augments the prompt by repeating it
    def _build_prompt(self, prompts, tgt_subjects, prompt_strength=1.0, prompt_reps=20):
        rv = []
        for prompt, tgt_subject in zip(prompts, tgt_subjects):
            prompt = f"a {tgt_subject} {prompt.strip()}"
            # a trick to amplify the prompt
            rv.append(", ".join([prompt] * int(prompt_strength * prompt_reps)))

        return rv

    # Copied from diffusers.pipelines.consistency_models.pipeline_consistency_models.ConsistencyModelPipeline.prepare_latents
    def prepare_latents(self, batch_size, num_channels, height, width, dtype, device, generator, latents=None):
        shape = (batch_size, num_channels, height, width)
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            latents = latents.to(device=device, dtype=dtype)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma
        return latents

    def encode_prompt(self, query_embeds, prompt, device=None):
        """Encode `prompt` with the context CLIP text encoder, splicing in the Qformer `query_embeds`."""
        device = device or self._execution_device

        # embeddings for prompt, with query_embeds as context
        # Reserve room in the token sequence for the query tokens that get spliced in.
        max_len = self.text_encoder.text_model.config.max_position_embeddings
        max_len -= self.qformer.config.num_query_tokens

        tokenized_prompt = self.tokenizer(
            prompt,
            padding="max_length",
            truncation=True,
            max_length=max_len,
            return_tensors="pt",
        ).to(device)

        batch_size = query_embeds.shape[0]
        ctx_begin_pos = [self.config.ctx_begin_pos] * batch_size

        text_embeddings = self.text_encoder(
            input_ids=tokenized_prompt.input_ids,
            ctx_embeddings=query_embeds,
            ctx_begin_pos=ctx_begin_pos,
        )[0]

        return text_embeddings

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: List[str],
        reference_image: PIL.Image.Image,
        source_subject_category: List[str],
        target_subject_category: List[str],
        latents: Optional[torch.Tensor] = None,
        guidance_scale: float = 7.5,
        height: int = 512,
        width: int = 512,
        num_inference_steps: int = 50,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        neg_prompt: Optional[str] = "",
        prompt_strength: float = 1.0,
        prompt_reps: int = 20,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
    ):
        """
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`List[str]`):
                The prompt or prompts to guide the image generation.
            reference_image (`PIL.Image.Image`):
                The reference image to condition the generation on.
            source_subject_category (`List[str]`):
                The source subject category.
            target_subject_category (`List[str]`):
                The target subject category.
            latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by random sampling.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion
                Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
                of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
                `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
                the text `prompt`, usually at the expense of lower image quality.
            height (`int`, *optional*, defaults to 512):
                The height of the generated image.
            width (`int`, *optional*, defaults to 512):
                The width of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            neg_prompt (`str`, *optional*, defaults to ""):
                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
                if `guidance_scale` is less than `1`).
            prompt_strength (`float`, *optional*, defaults to 1.0):
                The strength of the prompt. Specifies the number of times the prompt is repeated along with prompt_reps
                to amplify the prompt.
            prompt_reps (`int`, *optional*, defaults to 20):
                The number of times the prompt is repeated along with prompt_strength to amplify the prompt.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
                (`np.array`) or `"pt"` (`torch.Tensor`).
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
        Examples:

        Returns:
            [`~pipelines.ImagePipelineOutput`] or `tuple`
        """
        device = self._execution_device

        reference_image = self.image_processor.preprocess(
            reference_image, image_mean=self.config.mean, image_std=self.config.std, return_tensors="pt"
        )["pixel_values"]
        reference_image = reference_image.to(device)

        # Accept bare strings for convenience; everything below assumes lists.
        if isinstance(prompt, str):
            prompt = [prompt]
        if isinstance(source_subject_category, str):
            source_subject_category = [source_subject_category]
        if isinstance(target_subject_category, str):
            target_subject_category = [target_subject_category]

        batch_size = len(prompt)

        prompt = self._build_prompt(
            prompts=prompt,
            tgt_subjects=target_subject_category,
            prompt_strength=prompt_strength,
            prompt_reps=prompt_reps,
        )
        query_embeds = self.get_query_embeddings(reference_image, source_subject_category)
        text_embeddings = self.encode_prompt(query_embeds, prompt, device)
        do_classifier_free_guidance = guidance_scale > 1.0
        if do_classifier_free_guidance:
            # Unconditional branch uses no context embeddings, so the full max_length is available;
            # the conditional branch's shorter tokenization plus spliced query tokens yields the
            # same total sequence length.
            max_length = self.text_encoder.text_model.config.max_position_embeddings

            uncond_input = self.tokenizer(
                [neg_prompt] * batch_size,
                padding="max_length",
                max_length=max_length,
                return_tensors="pt",
            )
            uncond_embeddings = self.text_encoder(
                input_ids=uncond_input.input_ids.to(device),
                ctx_embeddings=None,
            )[0]
            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        # Latents live at the UNet's downsampled resolution (one factor of 2 per down block).
        scale_down_factor = 2 ** (len(self.unet.config.block_out_channels) - 1)
        latents = self.prepare_latents(
            batch_size=batch_size,
            num_channels=self.unet.config.in_channels,
            height=height // scale_down_factor,
            width=width // scale_down_factor,
            generator=generator,
            latents=latents,
            dtype=self.unet.dtype,
            device=device,
        )
        # set timesteps
        extra_set_kwargs = {}
        self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)

        for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
            # expand the latents if we are doing classifier free guidance
            # NOTE(review): this recomputation is redundant (guidance_scale does not change in the loop).
            do_classifier_free_guidance = guidance_scale > 1.0

            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents

            noise_pred = self.unet(
                latent_model_input,
                timestep=t,
                encoder_hidden_states=text_embeddings,
                down_block_additional_residuals=None,
                mid_block_additional_residual=None,
            )["sample"]

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            latents = self.scheduler.step(
                noise_pred,
                t,
                latents,
            )["prev_sample"]

            if XLA_AVAILABLE:
                xm.mark_step()

        # Map latents back to pixel space and convert to the requested output format.
        image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
        image = self.image_processor.postprocess(image, output_type=output_type)

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image)
pythonProject/diffusers-main/build/lib/diffusers/pipelines/bria/__init__.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from typing import TYPE_CHECKING

from ...utils import (
    DIFFUSERS_SLOW_IMPORT,
    OptionalDependencyNotAvailable,
    _LazyModule,
    get_objects_from_module,
    is_torch_available,
    is_transformers_available,
)


# Placeholder objects registered when torch/transformers are missing, so imports still resolve.
_dummy_objects = {}
# Maps submodule name -> exported names; consumed by `_LazyModule` below.
_import_structure = {}


try:
    if not (is_transformers_available() and is_torch_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ...utils import dummy_torch_and_transformers_objects  # noqa F403

    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
    _import_structure["pipeline_bria"] = ["BriaPipeline"]

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    # Eager import path: used by type checkers or when lazy imports are disabled.
    try:
        if not (is_transformers_available() and is_torch_available()):
            raise OptionalDependencyNotAvailable()

    except OptionalDependencyNotAvailable:
        from ...utils.dummy_torch_and_transformers_objects import *
    else:
        from .pipeline_bria import BriaPipeline

else:
    import sys

    # Replace this module with a lazy proxy so the heavy pipeline import happens on first access.
    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )

    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)
pythonProject/diffusers-main/build/lib/diffusers/pipelines/bria/pipeline_bria.py ADDED
@@ -0,0 +1,729 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Callable, Dict, List, Optional, Union
2
+
3
+ import numpy as np
4
+ import torch
5
+ from transformers import (
6
+ CLIPImageProcessor,
7
+ CLIPVisionModelWithProjection,
8
+ T5EncoderModel,
9
+ T5TokenizerFast,
10
+ )
11
+
12
+ from ...image_processor import VaeImageProcessor
13
+ from ...loaders import FluxLoraLoaderMixin
14
+ from ...models import AutoencoderKL
15
+ from ...models.transformers.transformer_bria import BriaTransformer2DModel
16
+ from ...pipelines import DiffusionPipeline
17
+ from ...pipelines.bria.pipeline_output import BriaPipelineOutput
18
+ from ...pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps
19
+ from ...schedulers import (
20
+ DDIMScheduler,
21
+ EulerAncestralDiscreteScheduler,
22
+ FlowMatchEulerDiscreteScheduler,
23
+ KarrasDiffusionSchedulers,
24
+ )
25
+ from ...utils import (
26
+ USE_PEFT_BACKEND,
27
+ is_torch_xla_available,
28
+ logging,
29
+ replace_example_docstring,
30
+ scale_lora_layers,
31
+ unscale_lora_layers,
32
+ )
33
+ from ...utils.torch_utils import randn_tensor
34
+
35
+
36
+ if is_torch_xla_available():
37
+ import torch_xla.core.xla_model as xm
38
+
39
+ XLA_AVAILABLE = True
40
+ else:
41
+ XLA_AVAILABLE = False
42
+
43
+
44
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
45
+
46
+ EXAMPLE_DOC_STRING = """
47
+ Examples:
48
+ ```py
49
+ >>> import torch
50
+ >>> from diffusers import BriaPipeline
51
+
52
+ >>> pipe = BriaPipeline.from_pretrained("briaai/BRIA-3.2", torch_dtype=torch.bfloat16)
53
+ >>> pipe.to("cuda")
54
+ # BRIA's T5 text encoder is sensitive to precision. We need to cast it to bfloat16 and keep the final layer in float32.
55
+
56
+ >>> pipe.text_encoder = pipe.text_encoder.to(dtype=torch.bfloat16)
57
+ >>> for block in pipe.text_encoder.encoder.block:
58
+ ... block.layer[-1].DenseReluDense.wo.to(dtype=torch.float32)
59
+ # BRIA's VAE is not supported in mixed precision, so we use float32.
60
+
61
+ >>> if pipe.vae.config.shift_factor == 0:
62
+ ... pipe.vae.to(dtype=torch.float32)
63
+
64
+ >>> prompt = "Photorealistic food photography of a stack of fluffy pancakes on a white plate, with maple syrup being poured over them. On top of the pancakes are the words 'BRIA 3.2' in bold, yellow, 3D letters. The background is dark and out of focus."
65
+ >>> image = pipe(prompt).images[0]
66
+ >>> image.save("bria.png")
67
+ ```
68
+ """
69
+
70
+
71
def is_ng_none(negative_prompt):
    """Return True if the negative prompt is effectively absent.

    A negative prompt counts as absent when it is `None`, an empty string, or a list
    whose first element is `None` or an empty string (the list form is what callers
    pass for batched prompts).

    Args:
        negative_prompt (`str` or `List[str]` or `None`): The negative prompt to check.

    Returns:
        `bool`: Whether the negative prompt should be treated as missing.
    """
    # Original mixed `isinstance(x, list)` with `type(x) == list`; `isinstance` covers both.
    if negative_prompt is None or negative_prompt == "":
        return True
    if isinstance(negative_prompt, list):
        # Only the first element is inspected, matching the original behavior.
        return negative_prompt[0] is None or negative_prompt[0] == ""
    return False
78
+
79
+
80
def get_original_sigmas(num_train_timesteps=1000, num_inference_steps=1000):
    """Return `num_inference_steps` sigmas evenly subsampled from the training schedule.

    The full schedule is `timestep / num_train_timesteps` for timesteps counting down
    from `num_train_timesteps` to 1, so sigmas run from 1.0 down to 1/num_train_timesteps.
    """
    train_timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()
    all_sigmas = train_timesteps / num_train_timesteps

    # Evenly spaced (truncated-to-int) indices into the full schedule.
    picked = [int(idx) for idx in np.linspace(0, num_train_timesteps - 1, num_inference_steps)]
    return all_sigmas[picked]
87
+
88
+
89
class BriaPipeline(DiffusionPipeline):
    r"""
    Text-to-image pipeline for Bria models.

    Based on FluxPipeline with several changes:
    - no pooled embeddings
    - We use zero padding for prompts
    - No guidance embedding since this is not a distilled version

    Args:
        transformer ([`BriaTransformer2DModel`]):
            Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
        scheduler ([`FlowMatchEulerDiscreteScheduler`]):
            A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        text_encoder ([`T5EncoderModel`]):
            Frozen text-encoder. Bria uses
            [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
            [t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
        tokenizer (`T5TokenizerFast`):
            Tokenizer of class
            [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
    """

    # Order used by `enable_model_cpu_offload`.
    # NOTE(review): includes `text_encoder_2`, which this pipeline never
    # registers — presumably copied from Flux; confirm this is harmless.
    model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae"
    # Components allowed to be None at load time.
    _optional_components = ["image_encoder", "feature_extractor"]
    # Tensors that `callback_on_step_end` may read and override.
    _callback_tensor_inputs = ["latents", "prompt_embeds"]
116
    def __init__(
        self,
        transformer: BriaTransformer2DModel,
        scheduler: Union[FlowMatchEulerDiscreteScheduler, KarrasDiffusionSchedulers],
        vae: AutoencoderKL,
        text_encoder: T5EncoderModel,
        tokenizer: T5TokenizerFast,
        image_encoder: CLIPVisionModelWithProjection = None,
        feature_extractor: CLIPImageProcessor = None,
    ):
        # Register all sub-models with the DiffusionPipeline base so that
        # save/load, device placement and cpu-offload utilities can see them.
        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            transformer=transformer,
            scheduler=scheduler,
            image_encoder=image_encoder,
            feature_extractor=feature_extractor,
        )

        # NOTE(review): 2 ** len(block_out_channels) (most diffusers pipelines
        # use len - 1) — presumably intentional here to fold in the 2x
        # patchify factor; confirm against the model config.
        self.vae_scale_factor = (
            2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
        )
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
        self.default_sample_size = 64  # due to patchify => 128,128 => res of 1k,1k

        # Ensure a numeric shift factor so decoding
        # `latents / scaling_factor + shift_factor` is always well defined,
        # and run such a VAE in float32.
        if self.vae.config.shift_factor is None:
            self.vae.config.shift_factor = 0
            self.vae.to(dtype=torch.float32)
145
+
146
    def encode_prompt(
        self,
        prompt: Union[str, List[str]],
        device: Optional[torch.device] = None,
        num_images_per_prompt: int = 1,
        do_classifier_free_guidance: bool = True,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        max_sequence_length: int = 128,
        lora_scale: Optional[float] = None,
    ):
        r"""
        Encode `prompt` (and optionally `negative_prompt`) into T5 embeddings.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            device: (`torch.device`):
                torch device
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            max_sequence_length (`int`):
                Maximum token length; shorter prompts are zero-padded to this length.
            lora_scale (`float`, *optional*):
                LoRA scale applied to the text encoder while encoding.

        Returns:
            `(prompt_embeds, negative_prompt_embeds, text_ids)`. An empty negative prompt
            maps to all-zero embeddings (Bria uses zero padding instead of encoding "").
        """
        device = device or self._execution_device

        # set lora scale so that monkey patched LoRA
        # function of text encoder can correctly access it
        if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
            self._lora_scale = lora_scale

            # dynamically adjust the LoRA scale
            if self.text_encoder is not None and USE_PEFT_BACKEND:
                scale_lora_layers(self.text_encoder, lora_scale)

        prompt = [prompt] if isinstance(prompt, str) else prompt
        if prompt is not None:
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            prompt_embeds = self._get_t5_prompt_embeds(
                prompt=prompt,
                num_images_per_prompt=num_images_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
            )

        if do_classifier_free_guidance and negative_prompt_embeds is None:
            if not is_ng_none(negative_prompt):
                # Broadcast a single negative prompt to the whole batch.
                negative_prompt = (
                    batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
                )

                if prompt is not None and type(prompt) is not type(negative_prompt):
                    raise TypeError(
                        f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                        f" {type(prompt)}."
                    )
                elif batch_size != len(negative_prompt):
                    raise ValueError(
                        f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                        f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                        " the batch size of `prompt`."
                    )

                negative_prompt_embeds = self._get_t5_prompt_embeds(
                    prompt=negative_prompt,
                    num_images_per_prompt=num_images_per_prompt,
                    max_sequence_length=max_sequence_length,
                    device=device,
                )
            else:
                # Empty negative prompt -> zero embeddings (no encoding needed).
                negative_prompt_embeds = torch.zeros_like(prompt_embeds)

        if self.text_encoder is not None:
            if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
                # Retrieve the original scale by scaling back the LoRA layers
                unscale_lora_layers(self.text_encoder, lora_scale)

        # Positional ids for text tokens are all zeros (text carries no 2D position).
        text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device)
        text_ids = text_ids.repeat(num_images_per_prompt, 1, 1)

        return prompt_embeds, negative_prompt_embeds, text_ids
242
+
243
    @property
    def guidance_scale(self):
        # CFG weight set at call time by `__call__`.
        return self._guidance_scale

    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    @property
    def do_classifier_free_guidance(self):
        return self._guidance_scale > 1

    @property
    def attention_kwargs(self):
        # Extra kwargs forwarded to the attention processors (e.g. LoRA "scale").
        return self._attention_kwargs

    @attention_kwargs.setter
    def attention_kwargs(self, value):
        self._attention_kwargs = value

    @property
    def num_timesteps(self):
        # Number of denoising timesteps used by the most recent `__call__`.
        return self._num_timesteps

    @property
    def interrupt(self):
        # Cooperative cancellation flag checked at every denoising step.
        return self._interrupt
269
+
270
    def check_inputs(
        self,
        prompt,
        height,
        width,
        negative_prompt=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
        callback_on_step_end_tensor_inputs=None,
        max_sequence_length=None,
    ):
        """Validate the user-facing call arguments.

        Warns on recoverable issues (non-divisible image sizes) and raises on
        contradictory or malformed ones (prompt/embeds conflicts, unknown
        callback tensor names, oversized `max_sequence_length`).
        """
        # Sizes that are not divisible are only warned about: the image
        # processor resizes rather than fails.
        if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
            logger.warning(
                f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
            )
        if callback_on_step_end_tensor_inputs is not None and not all(
            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
        ):
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
            )

        # Exactly one of `prompt` / `prompt_embeds` must be supplied.
        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if prompt_embeds.shape != negative_prompt_embeds.shape:
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}."
                )

        if max_sequence_length is not None and max_sequence_length > 512:
            raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
320
+
321
+ def _get_t5_prompt_embeds(
322
+ self,
323
+ prompt: Union[str, List[str]] = None,
324
+ num_images_per_prompt: int = 1,
325
+ max_sequence_length: int = 128,
326
+ device: Optional[torch.device] = None,
327
+ ):
328
+ tokenizer = self.tokenizer
329
+ text_encoder = self.text_encoder
330
+ device = device or text_encoder.device
331
+
332
+ prompt = [prompt] if isinstance(prompt, str) else prompt
333
+ batch_size = len(prompt)
334
+ prompt_embeds_list = []
335
+ for p in prompt:
336
+ text_inputs = tokenizer(
337
+ p,
338
+ # padding="max_length",
339
+ max_length=max_sequence_length,
340
+ truncation=True,
341
+ add_special_tokens=True,
342
+ return_tensors="pt",
343
+ )
344
+ text_input_ids = text_inputs.input_ids
345
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
346
+
347
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
348
+ text_input_ids, untruncated_ids
349
+ ):
350
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
351
+ logger.warning(
352
+ "The following part of your input was truncated because `max_sequence_length` is set to "
353
+ f" {max_sequence_length} tokens: {removed_text}"
354
+ )
355
+
356
+ prompt_embeds = text_encoder(text_input_ids.to(device))[0]
357
+
358
+ # Concat zeros to max_sequence
359
+ b, seq_len, dim = prompt_embeds.shape
360
+ if seq_len < max_sequence_length:
361
+ padding = torch.zeros(
362
+ (b, max_sequence_length - seq_len, dim), dtype=prompt_embeds.dtype, device=prompt_embeds.device
363
+ )
364
+ prompt_embeds = torch.concat([prompt_embeds, padding], dim=1)
365
+ prompt_embeds_list.append(prompt_embeds)
366
+
367
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=0)
368
+ prompt_embeds = prompt_embeds.to(device=device)
369
+
370
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
371
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
372
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, max_sequence_length, -1)
373
+ prompt_embeds = prompt_embeds.to(dtype=self.transformer.dtype)
374
+ return prompt_embeds
375
+
376
+ def prepare_latents(
377
+ self,
378
+ batch_size,
379
+ num_channels_latents,
380
+ height,
381
+ width,
382
+ dtype,
383
+ device,
384
+ generator,
385
+ latents=None,
386
+ ):
387
+ # VAE applies 8x compression on images but we must also account for packing which requires
388
+ # latent height and width to be divisible by 2.
389
+ height = 2 * (int(height) // self.vae_scale_factor)
390
+ width = 2 * (int(width) // self.vae_scale_factor)
391
+
392
+ shape = (batch_size, num_channels_latents, height, width)
393
+
394
+ if latents is not None:
395
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
396
+ return latents.to(device=device, dtype=dtype), latent_image_ids
397
+
398
+ if isinstance(generator, list) and len(generator) != batch_size:
399
+ raise ValueError(
400
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
401
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
402
+ )
403
+
404
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
405
+ latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
406
+
407
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
408
+
409
+ return latents, latent_image_ids
410
+
411
+ @staticmethod
412
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
413
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
414
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
415
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
416
+
417
+ return latents
418
+
419
+ @staticmethod
420
+ def _unpack_latents(latents, height, width, vae_scale_factor):
421
+ batch_size, num_patches, channels = latents.shape
422
+
423
+ height = height // vae_scale_factor
424
+ width = width // vae_scale_factor
425
+
426
+ latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
427
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
428
+
429
+ latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
430
+
431
+ return latents
432
+
433
+ @staticmethod
434
+ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
435
+ latent_image_ids = torch.zeros(height, width, 3)
436
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
437
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
438
+
439
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
440
+
441
+ latent_image_ids = latent_image_ids.repeat(batch_size, 1, 1, 1)
442
+ latent_image_ids = latent_image_ids.reshape(
443
+ batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
444
+ )
445
+
446
+ return latent_image_ids.to(device=device, dtype=dtype)
447
+
448
    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 30,
        timesteps: List[int] = None,
        guidance_scale: float = 5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        attention_kwargs: Optional[Dict[str, Any]] = None,
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        max_sequence_length: int = 128,
        clip_value: Union[None, float] = None,
        normalize: bool = False,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
                instead.
            height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
                The height in pixels of the generated image. This is set to 1024 by default for the best results.
            width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
                The width in pixels of the generated image. This is set to 1024 by default for the best results.
            num_inference_steps (`int`, *optional*, defaults to 30):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            timesteps (`List[int]`, *optional*):
                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
                passed will be used. Must be in descending order.
            guidance_scale (`float`, *optional*, defaults to 5.0):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generate image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.bria.BriaPipelineOutput`] instead of a plain tuple.
            attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            callback_on_step_end (`Callable`, *optional*):
                A function that calls at the end of each denoising steps during the inference. The function is called
                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
                `callback_on_step_end_tensor_inputs`.
            callback_on_step_end_tensor_inputs (`List`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                `._callback_tensor_inputs` attribute of your pipeline class.
            max_sequence_length (`int` defaults to 128): Maximum sequence length to use with the `prompt`.
            clip_value (`float`, *optional*):
                If set (> 0), the guided noise prediction is clipped to [-clip_value, clip_value] each step.
            normalize (`bool`, *optional*, defaults to `False`):
                If True, partially rescales the guided noise prediction toward the std of the text-conditional
                prediction (a CFG-rescale variant).

        Examples:

        Returns:
            [`~pipelines.bria.BriaPipelineOutput`] or `tuple`: [`~pipelines.bria.BriaPipelineOutput`] if `return_dict`
            is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
            images.
        """

        height = height or self.default_sample_size * self.vae_scale_factor
        width = width or self.default_sample_size * self.vae_scale_factor

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt=prompt,
            height=height,
            width=width,
            prompt_embeds=prompt_embeds,
            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
            max_sequence_length=max_sequence_length,
        )

        self._guidance_scale = guidance_scale
        self.attention_kwargs = attention_kwargs
        self._interrupt = False

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        lora_scale = self.attention_kwargs.get("scale", None) if self.attention_kwargs is not None else None

        # 3. Encode prompts (negative embeds are zeros when no negative prompt given).
        (prompt_embeds, negative_prompt_embeds, text_ids) = self.encode_prompt(
            prompt=prompt,
            negative_prompt=negative_prompt,
            do_classifier_free_guidance=self.do_classifier_free_guidance,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            max_sequence_length=max_sequence_length,
            lora_scale=lora_scale,
        )

        # Concatenate [uncond, cond] so one transformer pass serves both CFG branches.
        if self.do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)

        # 5. Prepare latent variables
        num_channels_latents = self.transformer.config.in_channels // 4  # due to patch=2, we divide by 4
        latents, latent_image_ids = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 4. Prepare timesteps: dynamic shifting (resolution-aware) vs. the
        # scheduler's own schedule vs. sigmas sub-sampled from training.
        if (
            isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler)
            and self.scheduler.config["use_dynamic_shifting"]
        ):
            sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
            image_seq_len = latents.shape[1]

            mu = calculate_shift(
                image_seq_len,
                self.scheduler.config.base_image_seq_len,
                self.scheduler.config.max_image_seq_len,
                self.scheduler.config.base_shift,
                self.scheduler.config.max_shift,
            )
            timesteps, num_inference_steps = retrieve_timesteps(
                self.scheduler,
                num_inference_steps,
                device,
                timesteps,
                sigmas,
                mu=mu,
            )
        else:
            # Sample from training sigmas
            if isinstance(self.scheduler, DDIMScheduler) or isinstance(
                self.scheduler, EulerAncestralDiscreteScheduler
            ):
                timesteps, num_inference_steps = retrieve_timesteps(
                    self.scheduler, num_inference_steps, device, None, None
                )
            else:
                sigmas = get_original_sigmas(
                    num_train_timesteps=self.scheduler.config.num_train_timesteps,
                    num_inference_steps=num_inference_steps,
                )
                timesteps, num_inference_steps = retrieve_timesteps(
                    self.scheduler, num_inference_steps, device, timesteps, sigmas=sigmas
                )

        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
        self._num_timesteps = len(timesteps)

        # The transformer expects 2D id tensors; drop a redundant batch dim if present.
        if len(latent_image_ids.shape) == 3:
            latent_image_ids = latent_image_ids[0]
        if len(text_ids.shape) == 3:
            text_ids = text_ids[0]

        # 6. Denoising loop
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                if type(self.scheduler) != FlowMatchEulerDiscreteScheduler:
                    latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                timestep = t.expand(latent_model_input.shape[0])

                # Predicts "v" from flow-matching or eps from diffusion
                noise_pred = self.transformer(
                    hidden_states=latent_model_input,
                    timestep=timestep,
                    encoder_hidden_states=prompt_embeds,
                    attention_kwargs=self.attention_kwargs,
                    return_dict=False,
                    txt_ids=text_ids,
                    img_ids=latent_image_ids,
                )[0]

                # perform guidance
                if self.do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    # std of the conditional branch, kept for optional CFG-rescale below
                    cfg_noise_pred_text = noise_pred_text.std()
                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

                    # Optional rescale: blend toward the conditional std to
                    # counteract CFG over-saturation.
                    if normalize:
                        noise_pred = noise_pred * (0.7 * (cfg_noise_pred_text / noise_pred.std())) + 0.3 * noise_pred

                    if clip_value:
                        assert clip_value > 0
                        noise_pred = noise_pred.clip(-clip_value, clip_value)

                # compute the previous noisy sample x_t -> x_t-1
                latents_dtype = latents.dtype
                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

                if latents.dtype != latents_dtype:
                    if torch.backends.mps.is_available():
                        # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                        latents = latents.to(latents_dtype)

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

                if XLA_AVAILABLE:
                    xm.mark_step()

        if output_type == "latent":
            image = latents

        else:
            # Unpack tokens back to a latent grid, undo VAE scaling/shift, decode.
            latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
            latents = (latents.to(dtype=torch.float32) / self.vae.config.scaling_factor) + self.vae.config.shift_factor
            image = self.vae.decode(latents.to(dtype=self.vae.dtype), return_dict=False)[0]
            image = self.image_processor.postprocess(image, output_type=output_type)

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image,)

        return BriaPipelineOutput(images=image)
pythonProject/diffusers-main/build/lib/diffusers/pipelines/bria/pipeline_output.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import List, Union
3
+
4
+ import numpy as np
5
+ import PIL.Image
6
+
7
+ from ...utils import BaseOutput
8
+
9
+
10
@dataclass
class BriaPipelineOutput(BaseOutput):
    """
    Output class for Bria pipelines.

    Args:
        images (`List[PIL.Image.Image]` or `np.ndarray`)
            List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
            num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
    """

    images: Union[List[PIL.Image.Image], np.ndarray]
pythonProject/diffusers-main/build/lib/diffusers/pipelines/chroma/__init__.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from typing import TYPE_CHECKING

from ...utils import (
    DIFFUSERS_SLOW_IMPORT,
    OptionalDependencyNotAvailable,
    _LazyModule,
    get_objects_from_module,
    is_torch_available,
    is_transformers_available,
)


# Standard diffusers lazy-import scaffold: pipeline classes are imported on
# first attribute access (or eagerly when DIFFUSERS_SLOW_IMPORT / type
# checking), and dummy placeholder objects are exposed when the optional
# torch/transformers dependencies are missing.
_dummy_objects = {}
_additional_imports = {}
_import_structure = {"pipeline_output": ["ChromaPipelineOutput"]}

try:
    if not (is_transformers_available() and is_torch_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    # Dependencies missing: export dummies that raise a helpful error on use.
    from ...utils import dummy_torch_and_transformers_objects  # noqa F403

    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
    _import_structure["pipeline_chroma"] = ["ChromaPipeline"]
    _import_structure["pipeline_chroma_img2img"] = ["ChromaImg2ImgPipeline"]

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    # Eager imports so static type checkers see the real classes.
    try:
        if not (is_transformers_available() and is_torch_available()):
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        from ...utils.dummy_torch_and_transformers_objects import *  # noqa F403
    else:
        from .pipeline_chroma import ChromaPipeline
        from .pipeline_chroma_img2img import ChromaImg2ImgPipeline
else:
    import sys

    # Replace this module with a lazy proxy that imports submodules on demand.
    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )

    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)
    for name, value in _additional_imports.items():
        setattr(sys.modules[__name__], name, value)
pythonProject/diffusers-main/build/lib/diffusers/pipelines/chroma/pipeline_chroma.py ADDED
@@ -0,0 +1,949 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 Black Forest Labs and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ from typing import Any, Callable, Dict, List, Optional, Union
17
+
18
+ import numpy as np
19
+ import torch
20
+ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, T5EncoderModel, T5TokenizerFast
21
+
22
+ from ...image_processor import PipelineImageInput, VaeImageProcessor
23
+ from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
24
+ from ...models import AutoencoderKL, ChromaTransformer2DModel
25
+ from ...schedulers import FlowMatchEulerDiscreteScheduler
26
+ from ...utils import (
27
+ USE_PEFT_BACKEND,
28
+ is_torch_xla_available,
29
+ logging,
30
+ replace_example_docstring,
31
+ scale_lora_layers,
32
+ unscale_lora_layers,
33
+ )
34
+ from ...utils.torch_utils import randn_tensor
35
+ from ..pipeline_utils import DiffusionPipeline
36
+ from .pipeline_output import ChromaPipelineOutput
37
+
38
+
39
# Optional XLA (TPU) support: when torch_xla is installed, the sampling loop can
# call `xm.mark_step()` after each iteration to flush the lazy graph.
if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

# Usage snippet injected into `ChromaPipeline.__call__`'s docstring via
# `@replace_example_docstring` — the text below is runtime data, not a comment.
EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> import torch
        >>> from diffusers import ChromaPipeline

        >>> model_id = "lodestones/Chroma"
        >>> ckpt_path = "https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v37.safetensors"
        >>> transformer = ChromaTransformer2DModel.from_single_file(ckpt_path, torch_dtype=torch.bfloat16)
        >>> pipe = ChromaPipeline.from_pretrained(
        ...     model_id,
        ...     transformer=transformer,
        ...     torch_dtype=torch.bfloat16,
        ... )
        >>> pipe.enable_model_cpu_offload()
        >>> prompt = [
        ...     "A high-fashion close-up portrait of a blonde woman in clear sunglasses. The image uses a bold teal and red color split for dramatic lighting. The background is a simple teal-green. The photo is sharp and well-composed, and is designed for viewing with anaglyph 3D glasses for optimal effect. It looks professionally done."
        ... ]
        >>> negative_prompt = [
        ...     "low quality, ugly, unfinished, out of focus, deformed, disfigure, blurry, smudged, restricted palette, flat colors"
        ... ]
        >>> image = pipe(prompt, negative_prompt=negative_prompt).images[0]
        >>> image.save("chroma.png")
        ```
"""
74
+
75
+
76
def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
):
    """Linearly interpolate the flow-matching shift parameter ``mu`` from the image sequence length.

    The shift grows from ``base_shift`` at ``base_seq_len`` tokens to ``max_shift`` at
    ``max_seq_len`` tokens (and extrapolates linearly outside that range).
    """
    slope = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    intercept = base_shift - slope * base_seq_len
    return image_seq_len * slope + intercept
88
+
89
+
90
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    r"""
    Configure `scheduler` and return its resulting timestep schedule.

    Exactly one schedule source may be used: a plain step count (`num_inference_steps`), a custom
    `timesteps` list, or a custom `sigmas` list. Custom schedules are only accepted when the
    scheduler's `set_timesteps` signature supports the corresponding keyword. Any extra kwargs are
    forwarded to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            Number of diffusion steps; mutually exclusive with `timesteps`/`sigmas`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timestep schedule overriding the scheduler's spacing strategy.
        sigmas (`List[float]`, *optional*):
            Custom sigma schedule overriding the scheduler's spacing strategy.

    Returns:
        `Tuple[torch.Tensor, int]`: the timestep schedule and the number of inference steps.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    if timesteps is None and sigmas is None:
        # Default path: derive the schedule from the requested step count.
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    # Custom-schedule path: both branches differ only in the keyword used.
    arg_name, schedule = ("timesteps", timesteps) if timesteps is not None else ("sigmas", sigmas)
    supported_params = inspect.signature(scheduler.set_timesteps).parameters
    if arg_name not in supported_params:
        noun = "timestep" if arg_name == "timesteps" else "sigmas"
        raise ValueError(
            f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
            f" {noun} schedules. Please check whether you are using the correct scheduler."
        )

    scheduler.set_timesteps(device=device, **{arg_name: schedule}, **kwargs)
    derived_timesteps = scheduler.timesteps
    return derived_timesteps, len(derived_timesteps)
148
+
149
+
150
+ class ChromaPipeline(
151
+ DiffusionPipeline,
152
+ FluxLoraLoaderMixin,
153
+ FromSingleFileMixin,
154
+ TextualInversionLoaderMixin,
155
+ FluxIPAdapterMixin,
156
+ ):
157
+ r"""
158
+ The Chroma pipeline for text-to-image generation.
159
+
160
+ Reference: https://huggingface.co/lodestones/Chroma/
161
+
162
+ Args:
163
+ transformer ([`ChromaTransformer2DModel`]):
164
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
165
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
166
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
167
+ vae ([`AutoencoderKL`]):
168
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representation
169
+ text_encoder ([`T5EncoderModel`]):
170
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
171
+ the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
172
+ tokenizer (`T5TokenizerFast`):
173
+ Second Tokenizer of class
174
+ [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
175
+ """
176
+
177
+ model_cpu_offload_seq = "text_encoder->image_encoder->transformer->vae"
178
+ _optional_components = ["image_encoder", "feature_extractor"]
179
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
180
+
181
    def __init__(
        self,
        scheduler: FlowMatchEulerDiscreteScheduler,
        vae: AutoencoderKL,
        text_encoder: T5EncoderModel,
        tokenizer: T5TokenizerFast,
        transformer: ChromaTransformer2DModel,
        image_encoder: CLIPVisionModelWithProjection = None,
        feature_extractor: CLIPImageProcessor = None,
    ):
        """Register all sub-models and derive the image-processing constants from the VAE config.

        `image_encoder` and `feature_extractor` are optional and only needed for IP-Adapter support.
        """
        super().__init__()

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            transformer=transformer,
            scheduler=scheduler,
            image_encoder=image_encoder,
            feature_extractor=feature_extractor,
        )
        # 2**(num_blocks - 1) spatial compression of the VAE; falls back to 8 when no VAE is loaded.
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
        # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
        # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
        # Default latent grid size; with the usual 8x VAE this yields 1024-px default outputs.
        self.default_sample_size = 128
207
+
208
    def _get_t5_prompt_embeds(
        self,
        prompt: Union[str, List[str]] = None,
        num_images_per_prompt: int = 1,
        max_sequence_length: int = 512,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        """Tokenize and encode `prompt` with the T5 text encoder.

        Returns `(prompt_embeds, attention_mask)`, each repeated `num_images_per_prompt` times along
        the batch dimension. The attention mask is rebuilt so that exactly one padding token after
        the prompt stays unmasked, which Chroma's training requires.
        """
        device = device or self._execution_device
        # NOTE(review): this `dtype` is overwritten below before it is used for the returned
        # embeddings, so a caller-supplied `dtype` is effectively ignored — confirm intent.
        dtype = dtype or self.text_encoder.dtype

        prompt = [prompt] if isinstance(prompt, str) else prompt
        batch_size = len(prompt)

        # Expand textual-inversion placeholder tokens if any are loaded.
        if isinstance(self, TextualInversionLoaderMixin):
            prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            return_length=False,
            return_overflowing_tokens=False,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        attention_mask = text_inputs.attention_mask.clone()

        # Chroma requires the attention mask to include one padding token:
        # positions <= seq_length are kept (note `<=`, not `<`), so the first pad token is unmasked.
        seq_lengths = attention_mask.sum(dim=1)
        mask_indices = torch.arange(attention_mask.size(1)).unsqueeze(0).expand(batch_size, -1)
        attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).long()

        prompt_embeds = self.text_encoder(
            text_input_ids.to(device), output_hidden_states=False, attention_mask=attention_mask.to(device)
        )[0]

        # Returned embeddings always use the encoder's own dtype (see NOTE above).
        dtype = self.text_encoder.dtype
        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
        attention_mask = attention_mask.to(dtype=dtype, device=device)

        _, seq_len, _ = prompt_embeds.shape

        # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

        attention_mask = attention_mask.repeat(1, num_images_per_prompt)
        attention_mask = attention_mask.view(batch_size * num_images_per_prompt, seq_len)

        return prompt_embeds, attention_mask
260
+
261
    def encode_prompt(
        self,
        prompt: Union[str, List[str]],
        negative_prompt: Union[str, List[str]] = None,
        device: Optional[torch.device] = None,
        num_images_per_prompt: int = 1,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        prompt_attention_mask: Optional[torch.Tensor] = None,
        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
        do_classifier_free_guidance: bool = True,
        max_sequence_length: int = 512,
        lora_scale: Optional[float] = None,
    ):
        r"""
        Encode `prompt` (and, when classifier-free guidance is enabled, `negative_prompt`) with T5.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`
                instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
            device: (`torch.device`):
                torch device
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings; computed from `negative_prompt` when omitted.
            prompt_attention_mask (`torch.Tensor`, *optional*):
                Attention mask matching `prompt_embeds`; passed through unchanged when embeddings are supplied.
            negative_prompt_attention_mask (`torch.Tensor`, *optional*):
                Attention mask matching `negative_prompt_embeds`.
            do_classifier_free_guidance (`bool`, defaults to `True`):
                Whether negative embeddings are required for a CFG pass.
            max_sequence_length (`int`, defaults to 512):
                Maximum T5 token length.
            lora_scale (`float`, *optional*):
                A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.

        Returns:
            `(prompt_embeds, text_ids, prompt_attention_mask, negative_prompt_embeds, negative_text_ids,
            negative_prompt_attention_mask)` — the negative entries stay `None` when CFG is disabled.
        """
        device = device or self._execution_device

        # set lora scale so that monkey patched LoRA
        # function of text encoder can correctly access it
        if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
            self._lora_scale = lora_scale

            # dynamically adjust the LoRA scale
            if self.text_encoder is not None and USE_PEFT_BACKEND:
                scale_lora_layers(self.text_encoder, lora_scale)

        prompt = [prompt] if isinstance(prompt, str) else prompt

        if prompt is not None:
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds(
                prompt=prompt,
                num_images_per_prompt=num_images_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
            )

        dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
        # Text position ids are all-zero (one row of 3 per text token).
        text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
        negative_text_ids = None

        if do_classifier_free_guidance:
            if negative_prompt_embeds is None:
                negative_prompt = negative_prompt or ""
                negative_prompt = (
                    batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
                )

                if prompt is not None and type(prompt) is not type(negative_prompt):
                    raise TypeError(
                        f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                        f" {type(prompt)}."
                    )
                elif batch_size != len(negative_prompt):
                    raise ValueError(
                        f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                        f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                        " the batch size of `prompt`."
                    )

                negative_prompt_embeds, negative_prompt_attention_mask = self._get_t5_prompt_embeds(
                    prompt=negative_prompt,
                    num_images_per_prompt=num_images_per_prompt,
                    max_sequence_length=max_sequence_length,
                    device=device,
                )

            negative_text_ids = torch.zeros(negative_prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)

        if self.text_encoder is not None:
            if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
                # Retrieve the original scale by scaling back the LoRA layers
                unscale_lora_layers(self.text_encoder, lora_scale)

        return (
            prompt_embeds,
            text_ids,
            prompt_attention_mask,
            negative_prompt_embeds,
            negative_text_ids,
            negative_prompt_attention_mask,
        )
364
+
365
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_image
366
+ def encode_image(self, image, device, num_images_per_prompt):
367
+ dtype = next(self.image_encoder.parameters()).dtype
368
+
369
+ if not isinstance(image, torch.Tensor):
370
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
371
+
372
+ image = image.to(device=device, dtype=dtype)
373
+ image_embeds = self.image_encoder(image).image_embeds
374
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
375
+ return image_embeds
376
+
377
    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_ip_adapter_image_embeds
    def prepare_ip_adapter_image_embeds(
        self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
    ):
        """Return one image-embedding tensor per loaded IP-Adapter.

        Either encodes `ip_adapter_image` with `encode_image`, or passes through the pre-computed
        `ip_adapter_image_embeds`. The list length must equal the number of loaded IP-Adapters, and
        each embedding is tiled `num_images_per_prompt` times along the batch dimension.
        """
        image_embeds = []
        if ip_adapter_image_embeds is None:
            if not isinstance(ip_adapter_image, list):
                ip_adapter_image = [ip_adapter_image]

            if len(ip_adapter_image) != self.transformer.encoder_hid_proj.num_ip_adapters:
                raise ValueError(
                    f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
                )

            for single_ip_adapter_image in ip_adapter_image:
                # Encode once; batch expansion happens below via torch.cat.
                single_image_embeds = self.encode_image(single_ip_adapter_image, device, 1)
                image_embeds.append(single_image_embeds[None, :])
        else:
            if not isinstance(ip_adapter_image_embeds, list):
                ip_adapter_image_embeds = [ip_adapter_image_embeds]

            if len(ip_adapter_image_embeds) != self.transformer.encoder_hid_proj.num_ip_adapters:
                raise ValueError(
                    f"`ip_adapter_image_embeds` must have same length as the number of IP Adapters. Got {len(ip_adapter_image_embeds)} image embeds and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
                )

            for single_image_embeds in ip_adapter_image_embeds:
                image_embeds.append(single_image_embeds)

        # Tile each adapter's embedding for every image generated per prompt.
        ip_adapter_image_embeds = []
        for single_image_embeds in image_embeds:
            single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
            single_image_embeds = single_image_embeds.to(device=device)
            ip_adapter_image_embeds.append(single_image_embeds)

        return ip_adapter_image_embeds
413
+
414
    def check_inputs(
        self,
        prompt,
        height,
        width,
        negative_prompt=None,
        prompt_embeds=None,
        prompt_attention_mask=None,
        negative_prompt_embeds=None,
        negative_prompt_attention_mask=None,
        callback_on_step_end_tensor_inputs=None,
        max_sequence_length=None,
    ):
        """Validate `__call__` arguments, raising `ValueError` on inconsistent combinations.

        Non-divisible `height`/`width` only produce a warning (the image processor resizes);
        everything else is a hard error.
        """
        if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
            logger.warning(
                f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
            )

        if callback_on_step_end_tensor_inputs is not None and not all(
            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
        ):
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
            )

        # Exactly one of `prompt` / `prompt_embeds` must be supplied.
        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        # Pre-computed embeddings must always arrive with their matching attention masks.
        if prompt_embeds is not None and prompt_attention_mask is None:
            raise ValueError("Cannot provide `prompt_embeds` without also providing `prompt_attention_mask")

        if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
            raise ValueError(
                "Cannot provide `negative_prompt_embeds` without also providing `negative_prompt_attention_mask"
            )

        if max_sequence_length is not None and max_sequence_length > 512:
            raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
467
+
468
+ @staticmethod
469
+ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
470
+ latent_image_ids = torch.zeros(height, width, 3)
471
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
472
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
473
+
474
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
475
+
476
+ latent_image_ids = latent_image_ids.reshape(
477
+ latent_image_id_height * latent_image_id_width, latent_image_id_channels
478
+ )
479
+
480
+ return latent_image_ids.to(device=device, dtype=dtype)
481
+
482
+ @staticmethod
483
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
484
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
485
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
486
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
487
+
488
+ return latents
489
+
490
+ @staticmethod
491
+ def _unpack_latents(latents, height, width, vae_scale_factor):
492
+ batch_size, num_patches, channels = latents.shape
493
+
494
+ # VAE applies 8x compression on images but we must also account for packing which requires
495
+ # latent height and width to be divisible by 2.
496
+ height = 2 * (int(height) // (vae_scale_factor * 2))
497
+ width = 2 * (int(width) // (vae_scale_factor * 2))
498
+
499
+ latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
500
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
501
+
502
+ latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
503
+
504
+ return latents
505
+
506
    def enable_vae_slicing(self):
        r"""
        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
        """
        # Thin delegation to the underlying AutoencoderKL.
        self.vae.enable_slicing()

    def disable_vae_slicing(self):
        r"""
        Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
        computing decoding in one step.
        """
        self.vae.disable_slicing()

    def enable_vae_tiling(self):
        r"""
        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
        processing larger images.
        """
        self.vae.enable_tiling()

    def disable_vae_tiling(self):
        r"""
        Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
        computing decoding in one step.
        """
        self.vae.disable_tiling()
534
+
535
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_latents
536
+ def prepare_latents(
537
+ self,
538
+ batch_size,
539
+ num_channels_latents,
540
+ height,
541
+ width,
542
+ dtype,
543
+ device,
544
+ generator,
545
+ latents=None,
546
+ ):
547
+ # VAE applies 8x compression on images but we must also account for packing which requires
548
+ # latent height and width to be divisible by 2.
549
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
550
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
551
+
552
+ shape = (batch_size, num_channels_latents, height, width)
553
+
554
+ if latents is not None:
555
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
556
+ return latents.to(device=device, dtype=dtype), latent_image_ids
557
+
558
+ if isinstance(generator, list) and len(generator) != batch_size:
559
+ raise ValueError(
560
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
561
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
562
+ )
563
+
564
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
565
+ latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
566
+
567
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
568
+
569
+ return latents, latent_image_ids
570
+
571
+ def _prepare_attention_mask(
572
+ self,
573
+ batch_size,
574
+ sequence_length,
575
+ dtype,
576
+ attention_mask=None,
577
+ ):
578
+ if attention_mask is None:
579
+ return attention_mask
580
+
581
+ # Extend the prompt attention mask to account for image tokens in the final sequence
582
+ attention_mask = torch.cat(
583
+ [attention_mask, torch.ones(batch_size, sequence_length, device=attention_mask.device)],
584
+ dim=1,
585
+ )
586
+ attention_mask = attention_mask.to(dtype)
587
+
588
+ return attention_mask
589
+
590
    @property
    def guidance_scale(self):
        # CFG scale recorded at the start of `__call__`.
        return self._guidance_scale

    @property
    def joint_attention_kwargs(self):
        # Extra kwargs forwarded to the attention processors during denoising.
        return self._joint_attention_kwargs

    @property
    def do_classifier_free_guidance(self):
        # True CFG (separate negative pass) is active whenever guidance_scale > 1.
        return self._guidance_scale > 1

    @property
    def num_timesteps(self):
        # Number of denoising steps; `_num_timesteps` is presumably set inside `__call__` — confirm.
        return self._num_timesteps

    @property
    def current_timestep(self):
        # Timestep currently being denoised (`None` outside the sampling loop).
        return self._current_timestep

    @property
    def interrupt(self):
        # Cooperative cancellation flag; initialised to False at the start of `__call__`.
        return self._interrupt
613
+
614
+ @torch.no_grad()
615
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
616
+ def __call__(
617
+ self,
618
+ prompt: Union[str, List[str]] = None,
619
+ negative_prompt: Union[str, List[str]] = None,
620
+ height: Optional[int] = None,
621
+ width: Optional[int] = None,
622
+ num_inference_steps: int = 35,
623
+ sigmas: Optional[List[float]] = None,
624
+ guidance_scale: float = 5.0,
625
+ num_images_per_prompt: Optional[int] = 1,
626
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
627
+ latents: Optional[torch.Tensor] = None,
628
+ prompt_embeds: Optional[torch.Tensor] = None,
629
+ ip_adapter_image: Optional[PipelineImageInput] = None,
630
+ ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
631
+ negative_ip_adapter_image: Optional[PipelineImageInput] = None,
632
+ negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
633
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
634
+ prompt_attention_mask: Optional[torch.Tensor] = None,
635
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
636
+ output_type: Optional[str] = "pil",
637
+ return_dict: bool = True,
638
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
639
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
640
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
641
+ max_sequence_length: int = 512,
642
+ ):
643
+ r"""
644
+ Function invoked when calling the pipeline for generation.
645
+
646
+ Args:
647
+ prompt (`str` or `List[str]`, *optional*):
648
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
649
+ instead.
650
+ negative_prompt (`str` or `List[str]`, *optional*):
651
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
652
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
653
+ not greater than `1`).
654
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
655
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
656
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
657
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
658
+ num_inference_steps (`int`, *optional*, defaults to 50):
659
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
660
+ expense of slower inference.
661
+ sigmas (`List[float]`, *optional*):
662
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
663
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
664
+ will be used.
665
+ guidance_scale (`float`, *optional*, defaults to 3.5):
666
+ Embedded guiddance scale is enabled by setting `guidance_scale` > 1. Higher `guidance_scale` encourages
667
+ a model to generate images more aligned with `prompt` at the expense of lower image quality.
668
+
669
+ Guidance-distilled models approximates true classifer-free guidance for `guidance_scale` > 1. Refer to
670
+ the [paper](https://huggingface.co/papers/2210.03142) to learn more.
671
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
672
+ The number of images to generate per prompt.
673
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
674
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
675
+ to make generation deterministic.
676
+ latents (`torch.Tensor`, *optional*):
677
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
678
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
679
+ tensor will be generated by sampling using the supplied random `generator`.
680
+ prompt_embeds (`torch.Tensor`, *optional*):
681
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
682
+ provided, text embeddings will be generated from `prompt` input argument.
683
+ ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
684
+ ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
685
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
686
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
687
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
688
+ negative_ip_adapter_image:
689
+ (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
690
+ negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
691
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
692
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
693
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
694
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
695
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
696
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
697
+ argument.
698
+ prompt_attention_mask (torch.Tensor, *optional*):
699
+ Attention mask for the prompt embeddings. Used to mask out padding tokens in the prompt sequence.
700
+ Chroma requires a single padding token remain unmasked. Please refer to
701
+ https://huggingface.co/lodestones/Chroma#tldr-masking-t5-padding-tokens-enhanced-fidelity-and-increased-stability-during-training
702
+ negative_prompt_attention_mask (torch.Tensor, *optional*):
703
+ Attention mask for the negative prompt embeddings. Used to mask out padding tokens in the negative
704
+ prompt sequence. Chroma requires a single padding token remain unmasked. PLease refer to
705
+ https://huggingface.co/lodestones/Chroma#tldr-masking-t5-padding-tokens-enhanced-fidelity-and-increased-stability-during-training
706
+ output_type (`str`, *optional*, defaults to `"pil"`):
707
+ The output format of the generate image. Choose between
708
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
709
+ return_dict (`bool`, *optional*, defaults to `True`):
710
+ Whether or not to return a [`~pipelines.flux.ChromaPipelineOutput`] instead of a plain tuple.
711
+ joint_attention_kwargs (`dict`, *optional*):
712
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
713
+ `self.processor` in
714
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
715
+ callback_on_step_end (`Callable`, *optional*):
716
+ A function that calls at the end of each denoising steps during the inference. The function is called
717
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
718
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
719
+ `callback_on_step_end_tensor_inputs`.
720
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
721
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
722
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
723
+ `._callback_tensor_inputs` attribute of your pipeline class.
724
+ max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
725
+
726
+ Examples:
727
+
728
+ Returns:
729
+ [`~pipelines.chroma.ChromaPipelineOutput`] or `tuple`: [`~pipelines.chroma.ChromaPipelineOutput`] if
730
+ `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
731
+ generated images.
732
+ """
733
+
734
+ height = height or self.default_sample_size * self.vae_scale_factor
735
+ width = width or self.default_sample_size * self.vae_scale_factor
736
+
737
+ # 1. Check inputs. Raise error if not correct
738
+ self.check_inputs(
739
+ prompt,
740
+ height,
741
+ width,
742
+ negative_prompt=negative_prompt,
743
+ prompt_embeds=prompt_embeds,
744
+ prompt_attention_mask=prompt_attention_mask,
745
+ negative_prompt_embeds=negative_prompt_embeds,
746
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
747
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
748
+ max_sequence_length=max_sequence_length,
749
+ )
750
+
751
+ self._guidance_scale = guidance_scale
752
+ self._joint_attention_kwargs = joint_attention_kwargs
753
+ self._current_timestep = None
754
+ self._interrupt = False
755
+
756
+ # 2. Define call parameters
757
+ if prompt is not None and isinstance(prompt, str):
758
+ batch_size = 1
759
+ elif prompt is not None and isinstance(prompt, list):
760
+ batch_size = len(prompt)
761
+ else:
762
+ batch_size = prompt_embeds.shape[0]
763
+
764
+ device = self._execution_device
765
+
766
+ lora_scale = (
767
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
768
+ )
769
+ (
770
+ prompt_embeds,
771
+ text_ids,
772
+ prompt_attention_mask,
773
+ negative_prompt_embeds,
774
+ negative_text_ids,
775
+ negative_prompt_attention_mask,
776
+ ) = self.encode_prompt(
777
+ prompt=prompt,
778
+ negative_prompt=negative_prompt,
779
+ prompt_embeds=prompt_embeds,
780
+ negative_prompt_embeds=negative_prompt_embeds,
781
+ prompt_attention_mask=prompt_attention_mask,
782
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
783
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
784
+ device=device,
785
+ num_images_per_prompt=num_images_per_prompt,
786
+ max_sequence_length=max_sequence_length,
787
+ lora_scale=lora_scale,
788
+ )
789
+
790
+ # 4. Prepare latent variables
791
+ num_channels_latents = self.transformer.config.in_channels // 4
792
+ latents, latent_image_ids = self.prepare_latents(
793
+ batch_size * num_images_per_prompt,
794
+ num_channels_latents,
795
+ height,
796
+ width,
797
+ prompt_embeds.dtype,
798
+ device,
799
+ generator,
800
+ latents,
801
+ )
802
+
803
+ # 5. Prepare timesteps
804
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
805
+ image_seq_len = latents.shape[1]
806
+ mu = calculate_shift(
807
+ image_seq_len,
808
+ self.scheduler.config.get("base_image_seq_len", 256),
809
+ self.scheduler.config.get("max_image_seq_len", 4096),
810
+ self.scheduler.config.get("base_shift", 0.5),
811
+ self.scheduler.config.get("max_shift", 1.15),
812
+ )
813
+
814
+ attention_mask = self._prepare_attention_mask(
815
+ batch_size=latents.shape[0],
816
+ sequence_length=image_seq_len,
817
+ dtype=latents.dtype,
818
+ attention_mask=prompt_attention_mask,
819
+ )
820
+ negative_attention_mask = self._prepare_attention_mask(
821
+ batch_size=latents.shape[0],
822
+ sequence_length=image_seq_len,
823
+ dtype=latents.dtype,
824
+ attention_mask=negative_prompt_attention_mask,
825
+ )
826
+
827
+ timesteps, num_inference_steps = retrieve_timesteps(
828
+ self.scheduler,
829
+ num_inference_steps,
830
+ device,
831
+ sigmas=sigmas,
832
+ mu=mu,
833
+ )
834
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
835
+ self._num_timesteps = len(timesteps)
836
+
837
+ if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and (
838
+ negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None
839
+ ):
840
+ negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
841
+ negative_ip_adapter_image = [negative_ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
842
+
843
+ elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (
844
+ negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
845
+ ):
846
+ ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
847
+ ip_adapter_image = [ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
848
+
849
+ if self.joint_attention_kwargs is None:
850
+ self._joint_attention_kwargs = {}
851
+
852
+ image_embeds = None
853
+ negative_image_embeds = None
854
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
855
+ image_embeds = self.prepare_ip_adapter_image_embeds(
856
+ ip_adapter_image,
857
+ ip_adapter_image_embeds,
858
+ device,
859
+ batch_size * num_images_per_prompt,
860
+ )
861
+ if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None:
862
+ negative_image_embeds = self.prepare_ip_adapter_image_embeds(
863
+ negative_ip_adapter_image,
864
+ negative_ip_adapter_image_embeds,
865
+ device,
866
+ batch_size * num_images_per_prompt,
867
+ )
868
+
869
+ # 6. Denoising loop
870
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
871
+ for i, t in enumerate(timesteps):
872
+ if self.interrupt:
873
+ continue
874
+
875
+ self._current_timestep = t
876
+ if image_embeds is not None:
877
+ self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds
878
+
879
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
880
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
881
+
882
+ noise_pred = self.transformer(
883
+ hidden_states=latents,
884
+ timestep=timestep / 1000,
885
+ encoder_hidden_states=prompt_embeds,
886
+ txt_ids=text_ids,
887
+ img_ids=latent_image_ids,
888
+ attention_mask=attention_mask,
889
+ joint_attention_kwargs=self.joint_attention_kwargs,
890
+ return_dict=False,
891
+ )[0]
892
+
893
+ if self.do_classifier_free_guidance:
894
+ if negative_image_embeds is not None:
895
+ self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
896
+ neg_noise_pred = self.transformer(
897
+ hidden_states=latents,
898
+ timestep=timestep / 1000,
899
+ encoder_hidden_states=negative_prompt_embeds,
900
+ txt_ids=negative_text_ids,
901
+ img_ids=latent_image_ids,
902
+ attention_mask=negative_attention_mask,
903
+ joint_attention_kwargs=self.joint_attention_kwargs,
904
+ return_dict=False,
905
+ )[0]
906
+ noise_pred = neg_noise_pred + guidance_scale * (noise_pred - neg_noise_pred)
907
+
908
+ # compute the previous noisy sample x_t -> x_t-1
909
+ latents_dtype = latents.dtype
910
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
911
+
912
+ if latents.dtype != latents_dtype:
913
+ if torch.backends.mps.is_available():
914
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
915
+ latents = latents.to(latents_dtype)
916
+
917
+ if callback_on_step_end is not None:
918
+ callback_kwargs = {}
919
+ for k in callback_on_step_end_tensor_inputs:
920
+ callback_kwargs[k] = locals()[k]
921
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
922
+
923
+ latents = callback_outputs.pop("latents", latents)
924
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
925
+
926
+ # call the callback, if provided
927
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
928
+ progress_bar.update()
929
+
930
+ if XLA_AVAILABLE:
931
+ xm.mark_step()
932
+
933
+ self._current_timestep = None
934
+
935
+ if output_type == "latent":
936
+ image = latents
937
+ else:
938
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
939
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
940
+ image = self.vae.decode(latents, return_dict=False)[0]
941
+ image = self.image_processor.postprocess(image, output_type=output_type)
942
+
943
+ # Offload all models
944
+ self.maybe_free_model_hooks()
945
+
946
+ if not return_dict:
947
+ return (image,)
948
+
949
+ return ChromaPipelineOutput(images=image)
pythonProject/diffusers-main/build/lib/diffusers/pipelines/chroma/pipeline_chroma_img2img.py ADDED
@@ -0,0 +1,1034 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 Black Forest Labs and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ from typing import Any, Callable, Dict, List, Optional, Union
17
+
18
+ import numpy as np
19
+ import torch
20
+ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, T5EncoderModel, T5TokenizerFast
21
+
22
+ from ...image_processor import PipelineImageInput, VaeImageProcessor
23
+ from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
24
+ from ...models import AutoencoderKL, ChromaTransformer2DModel
25
+ from ...schedulers import FlowMatchEulerDiscreteScheduler
26
+ from ...utils import (
27
+ USE_PEFT_BACKEND,
28
+ is_torch_xla_available,
29
+ logging,
30
+ replace_example_docstring,
31
+ scale_lora_layers,
32
+ unscale_lora_layers,
33
+ )
34
+ from ...utils.torch_utils import randn_tensor
35
+ from ..pipeline_utils import DiffusionPipeline
36
+ from .pipeline_output import ChromaPipelineOutput
37
+
38
+
39
+ if is_torch_xla_available():
40
+ import torch_xla.core.xla_model as xm
41
+
42
+ XLA_AVAILABLE = True
43
+ else:
44
+ XLA_AVAILABLE = False
45
+
46
+
47
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
48
+
49
+ EXAMPLE_DOC_STRING = """
50
+ Examples:
51
+ ```py
52
+ >>> import torch
53
+ >>> from diffusers import ChromaTransformer2DModel, ChromaImg2ImgPipeline
+ >>> from diffusers.utils import load_image
54
+
55
+ >>> model_id = "lodestones/Chroma"
56
+ >>> ckpt_path = "https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v37.safetensors"
+ >>> transformer = ChromaTransformer2DModel.from_single_file(ckpt_path, torch_dtype=torch.bfloat16)
57
+ >>> pipe = ChromaImg2ImgPipeline.from_pretrained(
58
+ ... model_id,
59
+ ... transformer=transformer,
60
+ ... torch_dtype=torch.bfloat16,
61
+ ... )
62
+ >>> pipe.enable_model_cpu_offload()
63
+ >>> init_image = load_image(
64
+ ... "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
65
+ ... )
66
+ >>> prompt = "a scenic fastasy landscape with a river and mountains in the background, vibrant colors, detailed, high resolution"
67
+ >>> negative_prompt = "low quality, ugly, unfinished, out of focus, deformed, disfigure, blurry, smudged, restricted palette, flat colors"
68
+ >>> image = pipe(prompt, image=init_image, negative_prompt=negative_prompt).images[0]
69
+ >>> image.save("chroma-img2img.png")
70
+ ```
71
+ """
72
+
73
+
74
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
):
    """Linearly interpolate the timestep-shift parameter ``mu`` from the image sequence length.

    The shift follows the straight line through (`base_seq_len`, `base_shift`) and
    (`max_seq_len`, `max_shift`), evaluated at `image_seq_len`.
    """
    slope = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    intercept = base_shift - slope * base_seq_len
    return image_seq_len * slope + intercept
86
+
87
+
88
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
    """Pull a latent tensor out of a VAE encoder output.

    Distribution-style outputs are sampled (`sample_mode="sample"`) or reduced to their
    mode (`sample_mode="argmax"`); outputs that already carry `latents` are returned as-is.
    """
    if hasattr(encoder_output, "latent_dist"):
        if sample_mode == "sample":
            return encoder_output.latent_dist.sample(generator)
        if sample_mode == "argmax":
            return encoder_output.latent_dist.mode()
    if hasattr(encoder_output, "latents"):
        return encoder_output.latents
    raise AttributeError("Could not access latents of provided encoder_output")
100
+
101
+
102
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    r"""
    Configure the scheduler's timestep schedule and return it.

    Exactly one of `num_inference_steps`, `timesteps`, or `sigmas` drives the schedule; any extra
    `kwargs` are forwarded to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used,
            `timesteps` must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps overriding the scheduler's spacing strategy. Mutually exclusive with
            `num_inference_steps` and `sigmas`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas overriding the scheduler's spacing strategy. Mutually exclusive with
            `num_inference_steps` and `timesteps`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the
        scheduler and the second element is the number of inference steps.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    if timesteps is not None:
        # A custom timestep schedule only works when the scheduler explicitly accepts it.
        if "timesteps" not in set(inspect.signature(scheduler.set_timesteps).parameters.keys()):
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        num_inference_steps = len(scheduler.timesteps)
    elif sigmas is not None:
        # Likewise gate custom sigmas on scheduler support.
        if "sigmas" not in set(inspect.signature(scheduler.set_timesteps).parameters.keys()):
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        num_inference_steps = len(scheduler.timesteps)
    else:
        # Default path: the caller-provided step count is returned unchanged.
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)

    return scheduler.timesteps, num_inference_steps
160
+
161
+
162
+ class ChromaImg2ImgPipeline(
163
+ DiffusionPipeline,
164
+ FluxLoraLoaderMixin,
165
+ FromSingleFileMixin,
166
+ TextualInversionLoaderMixin,
167
+ FluxIPAdapterMixin,
168
+ ):
169
+ r"""
170
+ The Chroma pipeline for image-to-image generation.
171
+
172
+ Reference: https://huggingface.co/lodestones/Chroma/
173
+
174
+ Args:
175
+ transformer ([`ChromaTransformer2DModel`]):
176
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
177
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
178
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
179
+ vae ([`AutoencoderKL`]):
180
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representation
181
+ text_encoder ([`T5EncoderModel`]):
182
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
183
+ the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
184
+ tokenizer (`T5TokenizerFast`):
185
+ Second Tokenizer of class
186
+ [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
187
+ """
188
+
189
+ model_cpu_offload_seq = "text_encoder->image_encoder->transformer->vae"
190
+ _optional_components = ["image_encoder", "feature_extractor"]
191
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
192
+
193
+ def __init__(
194
+ self,
195
+ scheduler: FlowMatchEulerDiscreteScheduler,
196
+ vae: AutoencoderKL,
197
+ text_encoder: T5EncoderModel,
198
+ tokenizer: T5TokenizerFast,
199
+ transformer: ChromaTransformer2DModel,
200
+ image_encoder: CLIPVisionModelWithProjection = None,
201
+ feature_extractor: CLIPImageProcessor = None,
202
+ ):
203
+ super().__init__()
204
+
205
+ self.register_modules(
206
+ vae=vae,
207
+ text_encoder=text_encoder,
208
+ tokenizer=tokenizer,
209
+ transformer=transformer,
210
+ scheduler=scheduler,
211
+ image_encoder=image_encoder,
212
+ feature_extractor=feature_extractor,
213
+ )
214
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
215
+ self.latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
216
+
217
+ # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
218
+ # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
219
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
220
+ self.default_sample_size = 128
221
+
222
+ def _get_t5_prompt_embeds(
223
+ self,
224
+ prompt: Union[str, List[str]] = None,
225
+ num_images_per_prompt: int = 1,
226
+ max_sequence_length: int = 512,
227
+ device: Optional[torch.device] = None,
228
+ dtype: Optional[torch.dtype] = None,
229
+ ):
230
+ device = device or self._execution_device
231
+ dtype = dtype or self.text_encoder.dtype
232
+
233
+ prompt = [prompt] if isinstance(prompt, str) else prompt
234
+ batch_size = len(prompt)
235
+
236
+ if isinstance(self, TextualInversionLoaderMixin):
237
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
238
+
239
+ text_inputs = self.tokenizer(
240
+ prompt,
241
+ padding="max_length",
242
+ max_length=max_sequence_length,
243
+ truncation=True,
244
+ return_length=False,
245
+ return_overflowing_tokens=False,
246
+ return_tensors="pt",
247
+ )
248
+ text_input_ids = text_inputs.input_ids
249
+ attention_mask = text_inputs.attention_mask.clone()
250
+
251
+ # Chroma requires the attention mask to include one padding token
252
+ seq_lengths = attention_mask.sum(dim=1)
253
+ mask_indices = torch.arange(attention_mask.size(1)).unsqueeze(0).expand(batch_size, -1)
254
+ attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).long()
255
+
256
+ prompt_embeds = self.text_encoder(
257
+ text_input_ids.to(device), output_hidden_states=False, attention_mask=attention_mask.to(device)
258
+ )[0]
259
+
260
+ dtype = self.text_encoder.dtype
261
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
262
+ attention_mask = attention_mask.to(dtype=dtype, device=device)
263
+
264
+ _, seq_len, _ = prompt_embeds.shape
265
+
266
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
267
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
268
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
269
+
270
+ attention_mask = attention_mask.repeat(1, num_images_per_prompt)
271
+ attention_mask = attention_mask.view(batch_size * num_images_per_prompt, seq_len)
272
+
273
+ return prompt_embeds, attention_mask
274
+
275
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image
276
+ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
277
+ if isinstance(generator, list):
278
+ image_latents = [
279
+ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
280
+ for i in range(image.shape[0])
281
+ ]
282
+ image_latents = torch.cat(image_latents, dim=0)
283
+ else:
284
+ image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
285
+
286
+ image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
287
+
288
+ return image_latents
289
+
290
    def encode_prompt(
        self,
        prompt: Union[str, List[str]],
        negative_prompt: Union[str, List[str]] = None,
        device: Optional[torch.device] = None,
        num_images_per_prompt: int = 1,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        prompt_attention_mask: Optional[torch.Tensor] = None,
        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
        do_classifier_free_guidance: bool = True,
        max_sequence_length: int = 512,
        lora_scale: Optional[float] = None,
    ):
        r"""
        Encode the positive (and, under classifier-free guidance, negative) prompts into T5 embeddings
        plus zero-valued text position ids.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`
                instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
            device: (`torch.device`):
                torch device
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. If not provided, they are computed from `negative_prompt`
                when `do_classifier_free_guidance` is enabled.
            prompt_attention_mask (`torch.Tensor`, *optional*):
                Attention mask matching `prompt_embeds`; regenerated when embeddings are computed here.
            negative_prompt_attention_mask (`torch.Tensor`, *optional*):
                Attention mask matching `negative_prompt_embeds`; regenerated when embeddings are computed here.
            do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
                Whether negative embeddings and ids are produced at all.
            max_sequence_length (`int`, *optional*, defaults to 512):
                Maximum token length used by the T5 tokenizer.
            lora_scale (`float`, *optional*):
                A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.

        Returns:
            Tuple `(prompt_embeds, text_ids, prompt_attention_mask, negative_prompt_embeds,
            negative_text_ids, negative_prompt_attention_mask)`; the negative entries stay `None`
            when guidance is disabled and no negative embeddings were supplied.
        """
        device = device or self._execution_device

        # set lora scale so that monkey patched LoRA
        # function of text encoder can correctly access it
        if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
            self._lora_scale = lora_scale

            # dynamically adjust the LoRA scale
            if self.text_encoder is not None and USE_PEFT_BACKEND:
                scale_lora_layers(self.text_encoder, lora_scale)

        prompt = [prompt] if isinstance(prompt, str) else prompt

        if prompt is not None:
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds(
                prompt=prompt,
                num_images_per_prompt=num_images_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
            )

        # Text position ids are all-zero 3-vectors, one per token.
        dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
        text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
        negative_text_ids = None

        if do_classifier_free_guidance:
            if negative_prompt_embeds is None:
                # Broadcast a single negative prompt string across the whole batch.
                negative_prompt = negative_prompt or ""
                negative_prompt = (
                    batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
                )

                if prompt is not None and type(prompt) is not type(negative_prompt):
                    raise TypeError(
                        f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                        f" {type(prompt)}."
                    )
                elif batch_size != len(negative_prompt):
                    raise ValueError(
                        f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                        f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                        " the batch size of `prompt`."
                    )

                negative_prompt_embeds, negative_prompt_attention_mask = self._get_t5_prompt_embeds(
                    prompt=negative_prompt,
                    num_images_per_prompt=num_images_per_prompt,
                    max_sequence_length=max_sequence_length,
                    device=device,
                )

            negative_text_ids = torch.zeros(negative_prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)

        if self.text_encoder is not None:
            if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
                # Retrieve the original scale by scaling back the LoRA layers
                unscale_lora_layers(self.text_encoder, lora_scale)

        return (
            prompt_embeds,
            text_ids,
            prompt_attention_mask,
            negative_prompt_embeds,
            negative_text_ids,
            negative_prompt_attention_mask,
        )
393
+
394
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_image
395
+ def encode_image(self, image, device, num_images_per_prompt):
396
+ dtype = next(self.image_encoder.parameters()).dtype
397
+
398
+ if not isinstance(image, torch.Tensor):
399
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
400
+
401
+ image = image.to(device=device, dtype=dtype)
402
+ image_embeds = self.image_encoder(image).image_embeds
403
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
404
+ return image_embeds
405
+
406
+ def prepare_ip_adapter_image_embeds(
407
+ self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
408
+ ):
409
+ device = device or self._execution_device
410
+
411
+ image_embeds = []
412
+ if ip_adapter_image_embeds is None:
413
+ if not isinstance(ip_adapter_image, list):
414
+ ip_adapter_image = [ip_adapter_image]
415
+
416
+ if len(ip_adapter_image) != self.transformer.encoder_hid_proj.num_ip_adapters:
417
+ raise ValueError(
418
+ f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
419
+ )
420
+
421
+ for single_ip_adapter_image in ip_adapter_image:
422
+ single_image_embeds = self.encode_image(single_ip_adapter_image, device, 1)
423
+ image_embeds.append(single_image_embeds[None, :])
424
+ else:
425
+ if not isinstance(ip_adapter_image_embeds, list):
426
+ ip_adapter_image_embeds = [ip_adapter_image_embeds]
427
+
428
+ if len(ip_adapter_image_embeds) != self.transformer.encoder_hid_proj.num_ip_adapters:
429
+ raise ValueError(
430
+ f"`ip_adapter_image_embeds` must have same length as the number of IP Adapters. Got {len(ip_adapter_image_embeds)} image embeds and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
431
+ )
432
+
433
+ for single_image_embeds in ip_adapter_image_embeds:
434
+ image_embeds.append(single_image_embeds)
435
+
436
+ ip_adapter_image_embeds = []
437
+ for single_image_embeds in image_embeds:
438
+ single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
439
+ single_image_embeds = single_image_embeds.to(device=device)
440
+ ip_adapter_image_embeds.append(single_image_embeds)
441
+
442
+ return ip_adapter_image_embeds
443
+
444
    def check_inputs(
        self,
        prompt,
        height,
        width,
        strength,
        negative_prompt=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
        prompt_attention_mask=None,
        negative_prompt_attention_mask=None,
        callback_on_step_end_tensor_inputs=None,
        max_sequence_length=None,
    ):
        """Validate pipeline call arguments before running inference.

        Raises `ValueError` (or `TypeError` elsewhere in the pipeline) for: `strength` outside
        [0, 1]; unknown callback tensor names; mutually exclusive or missing prompt/embedding
        combinations; embeddings supplied without their attention masks; and
        `max_sequence_length` above 512. Non-divisible `height`/`width` only triggers a warning
        because the image processor resizes later.
        """
        if strength < 0 or strength > 1:
            raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")

        # Warning only: dimensions will be rounded by the image processor.
        if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
            logger.warning(
                f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
            )

        if callback_on_step_end_tensor_inputs is not None and not all(
            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
        ):
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
            )

        # Exactly one of `prompt` / `prompt_embeds` must be provided, and `prompt` must be str or list.
        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        # Pre-computed embeddings must come with their matching attention masks.
        if prompt_embeds is not None and prompt_attention_mask is None:
            raise ValueError("Cannot provide `prompt_embeds` without also providing `prompt_attention_mask")

        if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
            raise ValueError(
                "Cannot provide `negative_prompt_embeds` without also providing `negative_prompt_attention_mask"
            )

        if max_sequence_length is not None and max_sequence_length > 512:
            raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
501
+
502
+ @staticmethod
503
+ def _prepare_latent_image_ids(height, width, device, dtype):
504
+ latent_image_ids = torch.zeros(height, width, 3)
505
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
506
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
507
+
508
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
509
+
510
+ latent_image_ids = latent_image_ids.reshape(
511
+ latent_image_id_height * latent_image_id_width, latent_image_id_channels
512
+ )
513
+
514
+ return latent_image_ids.to(device=device, dtype=dtype)
515
+
516
+ @staticmethod
517
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
518
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
519
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
520
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
521
+
522
+ return latents
523
+
524
+ @staticmethod
525
+ def _unpack_latents(latents, height, width, vae_scale_factor):
526
+ batch_size, num_patches, channels = latents.shape
527
+
528
+ # VAE applies 8x compression on images but we must also account for packing which requires
529
+ # latent height and width to be divisible by 2.
530
+ height = 2 * (int(height) // (vae_scale_factor * 2))
531
+ width = 2 * (int(width) // (vae_scale_factor * 2))
532
+
533
+ latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
534
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
535
+
536
+ latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
537
+
538
+ return latents
539
+
540
+ def enable_vae_slicing(self):
541
+ r"""
542
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
543
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
544
+ """
545
+ self.vae.enable_slicing()
546
+
547
+ def disable_vae_slicing(self):
548
+ r"""
549
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
550
+ computing decoding in one step.
551
+ """
552
+ self.vae.disable_slicing()
553
+
554
+ def enable_vae_tiling(self):
555
+ r"""
556
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
557
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
558
+ processing larger images.
559
+ """
560
+ self.vae.enable_tiling()
561
+
562
+ def disable_vae_tiling(self):
563
+ r"""
564
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
565
+ computing decoding in one step.
566
+ """
567
+ self.vae.disable_tiling()
568
+
569
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps
570
+ def get_timesteps(self, num_inference_steps, strength, device):
571
+ # get the original timestep using init_timestep
572
+ init_timestep = min(num_inference_steps * strength, num_inference_steps)
573
+
574
+ t_start = int(max(num_inference_steps - init_timestep, 0))
575
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
576
+ if hasattr(self.scheduler, "set_begin_index"):
577
+ self.scheduler.set_begin_index(t_start * self.scheduler.order)
578
+
579
+ return timesteps, num_inference_steps - t_start
580
+
581
+ def prepare_latents(
582
+ self,
583
+ image,
584
+ timestep,
585
+ batch_size,
586
+ num_channels_latents,
587
+ height,
588
+ width,
589
+ dtype,
590
+ device,
591
+ generator,
592
+ latents=None,
593
+ ):
594
+ if isinstance(generator, list) and len(generator) != batch_size:
595
+ raise ValueError(
596
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
597
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
598
+ )
599
+
600
+ # VAE applies 8x compression on images but we must also account for packing which requires
601
+ # latent height and width to be divisible by 2.
602
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
603
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
604
+ shape = (batch_size, num_channels_latents, height, width)
605
+ latent_image_ids = self._prepare_latent_image_ids(height // 2, width // 2, device, dtype)
606
+
607
+ if latents is not None:
608
+ return latents.to(device=device, dtype=dtype), latent_image_ids
609
+
610
+ image = image.to(device=device, dtype=dtype)
611
+ if image.shape[1] != self.latent_channels:
612
+ image_latents = self._encode_vae_image(image=image, generator=generator)
613
+ else:
614
+ image_latents = image
615
+ if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
616
+ # expand init_latents for batch_size
617
+ additional_image_per_prompt = batch_size // image_latents.shape[0]
618
+ image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
619
+ elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
620
+ raise ValueError(
621
+ f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
622
+ )
623
+ else:
624
+ image_latents = torch.cat([image_latents], dim=0)
625
+
626
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
627
+ latents = self.scheduler.scale_noise(image_latents, timestep, noise)
628
+ latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
629
+ return latents, latent_image_ids
630
+
631
+ def _prepare_attention_mask(
632
+ self,
633
+ batch_size,
634
+ sequence_length,
635
+ dtype,
636
+ attention_mask=None,
637
+ ):
638
+ if attention_mask is None:
639
+ return attention_mask
640
+
641
+ # Extend the prompt attention mask to account for image tokens in the final sequence
642
+ attention_mask = torch.cat(
643
+ [attention_mask, torch.ones(batch_size, sequence_length, device=attention_mask.device)],
644
+ dim=1,
645
+ )
646
+ attention_mask = attention_mask.to(dtype)
647
+
648
+ return attention_mask
649
+
650
    @property
    def guidance_scale(self):
        # Guidance scale stored by the latest `__call__`; CFG runs when > 1.
        return self._guidance_scale
653
+
654
    @property
    def joint_attention_kwargs(self):
        # Extra kwargs forwarded to the transformer's attention processors.
        return self._joint_attention_kwargs
657
+
658
    @property
    def do_classifier_free_guidance(self):
        # A second (negative-prompt) transformer pass is run when guidance_scale > 1.
        return self._guidance_scale > 1
661
+
662
    @property
    def num_timesteps(self):
        # Number of denoising timesteps used by the latest `__call__`.
        return self._num_timesteps
665
+
666
    @property
    def current_timestep(self):
        # Timestep currently being denoised; `None` outside the denoising loop.
        return self._current_timestep
669
+
670
    @property
    def interrupt(self):
        # Set to True (e.g. from a callback) to skip the remaining denoising steps.
        return self._interrupt
673
+
674
    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        negative_prompt: Union[str, List[str]] = None,
        image: PipelineImageInput = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 35,
        sigmas: Optional[List[float]] = None,
        guidance_scale: float = 5.0,
        strength: float = 0.9,
        num_images_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        ip_adapter_image: Optional[PipelineImageInput] = None,
        ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
        negative_ip_adapter_image: Optional[PipelineImageInput] = None,
        negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        prompt_attention_mask: Optional[torch.Tensor] = None,
        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        max_sequence_length: int = 512,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
                instead.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                not greater than `1`).
            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The height in pixels of the generated image. This is set to 1024 by default for the best results.
            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The width in pixels of the generated image. This is set to 1024 by default for the best results.
            num_inference_steps (`int`, *optional*, defaults to 35):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            sigmas (`List[float]`, *optional*):
                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
                their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
                will be used.
            guidance_scale (`float`, *optional*, defaults to 5.0):
                Embedded guidance scale is enabled by setting `guidance_scale` > 1. Higher `guidance_scale` encourages
                a model to generate images more aligned with `prompt` at the expense of lower image quality.

                Guidance-distilled models approximates true classifier-free guidance for `guidance_scale` > 1. Refer to
                the [paper](https://huggingface.co/papers/2210.03142) to learn more.
            strength (`float`, *optional*, defaults to 0.9):
                Conceptually, indicates how much to transform the reference image. Must be between 0 and 1. image will
                be used as a starting point, adding more noise to it the larger the strength. The number of denoising
                steps depends on the amount of noise initially added. When strength is 1, added noise will be maximum
                and the denoising process will run for the full number of iterations specified in num_inference_steps.
                A value of 1, therefore, essentially ignores image.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
            ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
                Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
                IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
                provided, embeddings are computed from the `ip_adapter_image` input argument.
            negative_ip_adapter_image:
                (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
            negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
                Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
                IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
                provided, embeddings are computed from the `ip_adapter_image` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            prompt_attention_mask (torch.Tensor, *optional*):
                Attention mask for the prompt embeddings. Used to mask out padding tokens in the prompt sequence.
                Chroma requires a single padding token remain unmasked. Please refer to
                https://huggingface.co/lodestones/Chroma#tldr-masking-t5-padding-tokens-enhanced-fidelity-and-increased-stability-during-training
            negative_prompt_attention_mask (torch.Tensor, *optional*):
                Attention mask for the negative prompt embeddings. Used to mask out padding tokens in the negative
                prompt sequence. Chroma requires a single padding token remain unmasked. Please refer to
                https://huggingface.co/lodestones/Chroma#tldr-masking-t5-padding-tokens-enhanced-fidelity-and-increased-stability-during-training
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generate image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.chroma.ChromaPipelineOutput`] instead of a plain tuple.
            joint_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            callback_on_step_end (`Callable`, *optional*):
                A function that calls at the end of each denoising steps during the inference. The function is called
                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
                `callback_on_step_end_tensor_inputs`.
            callback_on_step_end_tensor_inputs (`List`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                `._callback_tensor_inputs` attribute of your pipeline class.
            max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.

        Examples:

        Returns:
            [`~pipelines.chroma.ChromaPipelineOutput`] or `tuple`: [`~pipelines.chroma.ChromaPipelineOutput`] if
            `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
            generated images.
        """

        height = height or self.default_sample_size * self.vae_scale_factor
        width = width or self.default_sample_size * self.vae_scale_factor

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            height,
            width,
            strength,
            negative_prompt=negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            prompt_attention_mask=prompt_attention_mask,
            negative_prompt_attention_mask=negative_prompt_attention_mask,
            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
            max_sequence_length=max_sequence_length,
        )

        self._guidance_scale = guidance_scale
        self._joint_attention_kwargs = joint_attention_kwargs
        self._current_timestep = None
        self._interrupt = False

        # 2. Preprocess image
        init_image = self.image_processor.preprocess(image, height=height, width=width)
        init_image = init_image.to(dtype=torch.float32)

        # 3. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device
        lora_scale = (
            self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
        )

        (
            prompt_embeds,
            text_ids,
            prompt_attention_mask,
            negative_prompt_embeds,
            negative_text_ids,
            negative_prompt_attention_mask,
        ) = self.encode_prompt(
            prompt=prompt,
            negative_prompt=negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            prompt_attention_mask=prompt_attention_mask,
            negative_prompt_attention_mask=negative_prompt_attention_mask,
            do_classifier_free_guidance=self.do_classifier_free_guidance,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            max_sequence_length=max_sequence_length,
            lora_scale=lora_scale,
        )

        # 4. Prepare timesteps
        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
        image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2)
        # Dynamic shifting of the flow-matching schedule based on sequence length.
        mu = calculate_shift(
            image_seq_len,
            self.scheduler.config.get("base_image_seq_len", 256),
            self.scheduler.config.get("max_image_seq_len", 4096),
            self.scheduler.config.get("base_shift", 0.5),
            self.scheduler.config.get("max_shift", 1.15),
        )
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler,
            num_inference_steps,
            device,
            sigmas=sigmas,
            mu=mu,
        )
        # Trim the schedule to the img2img `strength` window.
        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)

        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
        self._num_timesteps = len(timesteps)

        if num_inference_steps < 1:
            raise ValueError(
                f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
                f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
            )
        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)

        # 5. Prepare latent variables
        num_channels_latents = self.transformer.config.in_channels // 4
        latents, latent_image_ids = self.prepare_latents(
            init_image,
            latent_timestep,
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # Extend the text masks so image tokens are always attended to.
        attention_mask = self._prepare_attention_mask(
            batch_size=latents.shape[0],
            sequence_length=image_seq_len,
            dtype=latents.dtype,
            attention_mask=prompt_attention_mask,
        )
        negative_attention_mask = self._prepare_attention_mask(
            batch_size=latents.shape[0],
            sequence_length=image_seq_len,
            dtype=latents.dtype,
            attention_mask=negative_prompt_attention_mask,
        )

        # 6. Prepare image embeddings
        # If only one side (positive/negative) of the IP-Adapter inputs is given,
        # a black placeholder image is used for the missing side so both CFG
        # branches receive image embeddings.
        if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and (
            negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None
        ):
            negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
            negative_ip_adapter_image = [negative_ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters

        elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (
            negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
        ):
            ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
            ip_adapter_image = [ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters

        if self.joint_attention_kwargs is None:
            self._joint_attention_kwargs = {}

        image_embeds = None
        negative_image_embeds = None
        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
            image_embeds = self.prepare_ip_adapter_image_embeds(
                ip_adapter_image,
                ip_adapter_image_embeds,
                device,
                batch_size * num_images_per_prompt,
            )
        if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None:
            negative_image_embeds = self.prepare_ip_adapter_image_embeds(
                negative_ip_adapter_image,
                negative_ip_adapter_image_embeds,
                device,
                batch_size * num_images_per_prompt,
            )

        # 7. Denoising loop
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                self._current_timestep = t
                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                timestep = t.expand(latents.shape[0])

                if image_embeds is not None:
                    self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds

                # Timesteps are scaled to [0, 1] for the transformer.
                noise_pred = self.transformer(
                    hidden_states=latents,
                    timestep=timestep / 1000,
                    encoder_hidden_states=prompt_embeds,
                    txt_ids=text_ids,
                    img_ids=latent_image_ids,
                    attention_mask=attention_mask,
                    joint_attention_kwargs=self.joint_attention_kwargs,
                    return_dict=False,
                )[0]

                if self.do_classifier_free_guidance:
                    if negative_image_embeds is not None:
                        self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds

                    noise_pred_uncond = self.transformer(
                        hidden_states=latents,
                        timestep=timestep / 1000,
                        encoder_hidden_states=negative_prompt_embeds,
                        txt_ids=negative_text_ids,
                        img_ids=latent_image_ids,
                        attention_mask=negative_attention_mask,
                        joint_attention_kwargs=self.joint_attention_kwargs,
                        return_dict=False,
                    )[0]
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents_dtype = latents.dtype
                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

                if latents.dtype != latents_dtype:
                    if torch.backends.mps.is_available():
                        # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                        latents = latents.to(latents_dtype)

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

                if XLA_AVAILABLE:
                    xm.mark_step()

        self._current_timestep = None

        if output_type == "latent":
            image = latents
        else:
            latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
            latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
            image = self.vae.decode(latents, return_dict=False)[0]
            image = self.image_processor.postprocess(image, output_type=output_type)

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image,)

        return ChromaPipelineOutput(images=image)
pythonProject/diffusers-main/build/lib/diffusers/pipelines/chroma/pipeline_output.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import List, Union
3
+
4
+ import numpy as np
5
+ import PIL.Image
6
+
7
+ from ...utils import BaseOutput
8
+
9
+
10
@dataclass
class ChromaPipelineOutput(BaseOutput):
    """
    Output class for Chroma image-generation pipelines.

    Args:
        images (`List[PIL.Image.Image]` or `np.ndarray`)
            List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
            num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
    """

    # Decoded images: a list of PIL images or one (batch, H, W, C) numpy array.
    images: Union[List[PIL.Image.Image], np.ndarray]
pythonProject/diffusers-main/build/lib/diffusers/pipelines/cogvideo/__init__.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import TYPE_CHECKING
2
+
3
+ from ...utils import (
4
+ DIFFUSERS_SLOW_IMPORT,
5
+ OptionalDependencyNotAvailable,
6
+ _LazyModule,
7
+ get_objects_from_module,
8
+ is_torch_available,
9
+ is_transformers_available,
10
+ )
11
+
12
+
13
+ _dummy_objects = {}
14
+ _import_structure = {}
15
+
16
+
17
+ try:
18
+ if not (is_transformers_available() and is_torch_available()):
19
+ raise OptionalDependencyNotAvailable()
20
+ except OptionalDependencyNotAvailable:
21
+ from ...utils import dummy_torch_and_transformers_objects # noqa F403
22
+
23
+ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
24
+ else:
25
+ _import_structure["pipeline_cogvideox"] = ["CogVideoXPipeline"]
26
+ _import_structure["pipeline_cogvideox_fun_control"] = ["CogVideoXFunControlPipeline"]
27
+ _import_structure["pipeline_cogvideox_image2video"] = ["CogVideoXImageToVideoPipeline"]
28
+ _import_structure["pipeline_cogvideox_video2video"] = ["CogVideoXVideoToVideoPipeline"]
29
+
30
+ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
31
+ try:
32
+ if not (is_transformers_available() and is_torch_available()):
33
+ raise OptionalDependencyNotAvailable()
34
+
35
+ except OptionalDependencyNotAvailable:
36
+ from ...utils.dummy_torch_and_transformers_objects import *
37
+ else:
38
+ from .pipeline_cogvideox import CogVideoXPipeline
39
+ from .pipeline_cogvideox_fun_control import CogVideoXFunControlPipeline
40
+ from .pipeline_cogvideox_image2video import CogVideoXImageToVideoPipeline
41
+ from .pipeline_cogvideox_video2video import CogVideoXVideoToVideoPipeline
42
+
43
+ else:
44
+ import sys
45
+
46
+ sys.modules[__name__] = _LazyModule(
47
+ __name__,
48
+ globals()["__file__"],
49
+ _import_structure,
50
+ module_spec=__spec__,
51
+ )
52
+
53
+ for name, value in _dummy_objects.items():
54
+ setattr(sys.modules[__name__], name, value)
pythonProject/diffusers-main/build/lib/diffusers/pipelines/cogvideo/pipeline_cogvideox.py ADDED
@@ -0,0 +1,789 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import inspect
17
+ import math
18
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
19
+
20
+ import torch
21
+ from transformers import T5EncoderModel, T5Tokenizer
22
+
23
+ from ...callbacks import MultiPipelineCallbacks, PipelineCallback
24
+ from ...loaders import CogVideoXLoraLoaderMixin
25
+ from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
26
+ from ...models.embeddings import get_3d_rotary_pos_embed
27
+ from ...pipelines.pipeline_utils import DiffusionPipeline
28
+ from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
29
+ from ...utils import is_torch_xla_available, logging, replace_example_docstring
30
+ from ...utils.torch_utils import randn_tensor
31
+ from ...video_processor import VideoProcessor
32
+ from .pipeline_output import CogVideoXPipelineOutput
33
+
34
+
35
+ if is_torch_xla_available():
36
+ import torch_xla.core.xla_model as xm
37
+
38
+ XLA_AVAILABLE = True
39
+ else:
40
+ XLA_AVAILABLE = False
41
+
42
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
43
+
44
+
45
+ EXAMPLE_DOC_STRING = """
46
+ Examples:
47
+ ```python
48
+ >>> import torch
49
+ >>> from diffusers import CogVideoXPipeline
50
+ >>> from diffusers.utils import export_to_video
51
+
52
+ >>> # Models: "THUDM/CogVideoX-2b" or "THUDM/CogVideoX-5b"
53
+ >>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16).to("cuda")
54
+ >>> prompt = (
55
+ ... "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
56
+ ... "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
57
+ ... "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
58
+ ... "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
59
+ ... "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
60
+ ... "atmosphere of this unique musical performance."
61
+ ... )
62
+ >>> video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
63
+ >>> export_to_video(video, "output.mp4", fps=8)
64
+ ```
65
+ """
66
+
67
+
# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
    """Return the centered crop box for `src` after an aspect-preserving resize.

    `src` is an `(height, width)` grid size. The grid is scaled (preserving its
    aspect ratio) so that it fits inside `(tgt_height, tgt_width)`, then centered.
    Returns `((top, left), (bottom, right))` crop coordinates in the target frame.
    """
    src_height, src_width = src

    # Pick the scaling axis: compare aspect ratios via cross-multiplication
    # (src_height / src_width > tgt_height / tgt_width, all values positive).
    if src_height * tgt_width > tgt_height * src_width:
        resize_height = tgt_height
        resize_width = int(round(tgt_height / src_height * src_width))
    else:
        resize_width = tgt_width
        resize_height = int(round(tgt_width / src_width * src_height))

    # Center the resized grid inside the target rectangle.
    top = int(round((tgt_height - resize_height) / 2.0))
    left = int(round((tgt_width - resize_width) / 2.0))

    return (top, left), (top + resize_height, left + resize_width)
85
+
86
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    r"""
    Call the scheduler's `set_timesteps` method and return the resulting timestep schedule.

    Exactly one timestep-selection mechanism may be used: an explicit `timesteps` list, an
    explicit `sigmas` list, or a plain `num_inference_steps`. Any extra kwargs are forwarded
    to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    # Not every scheduler supports custom schedules; probe its signature before forwarding.
    def _accepts(param_name):
        return param_name in inspect.signature(scheduler.set_timesteps).parameters

    if timesteps is not None:
        if not _accepts("timesteps"):
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    elif sigmas is not None:
        if not _accepts("sigmas"):
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)

    schedule = scheduler.timesteps
    # With a custom schedule the effective step count is whatever the scheduler produced.
    if timesteps is not None or sigmas is not None:
        num_inference_steps = len(schedule)
    return schedule, num_inference_steps
145
+
146
+
147
class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
    r"""
    Pipeline for text-to-video generation using CogVideoX.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
        text_encoder ([`T5EncoderModel`]):
            Frozen text-encoder. CogVideoX uses
            [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the
            [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
        tokenizer (`T5Tokenizer`):
            Tokenizer of class
            [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
        transformer ([`CogVideoXTransformer3DModel`]):
            A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
    """

    # No components are optional for this pipeline.
    _optional_components = []
    # Order in which sub-models are moved on/off the accelerator for CPU offloading.
    model_cpu_offload_seq = "text_encoder->transformer->vae"

    # Tensors that may be inspected/replaced by `callback_on_step_end`.
    _callback_tensor_inputs = [
        "latents",
        "prompt_embeds",
        "negative_prompt_embeds",
    ]

    def __init__(
        self,
        tokenizer: T5Tokenizer,
        text_encoder: T5EncoderModel,
        vae: AutoencoderKLCogVideoX,
        transformer: CogVideoXTransformer3DModel,
        scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
    ):
        super().__init__()

        self.register_modules(
            tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
        )
        # Spatial downscale of the VAE: 2 ** (number of downsampling stages).
        # Fallbacks (8 / 4 / 0.7) are used when the pipeline is constructed without a VAE.
        self.vae_scale_factor_spatial = (
            2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
        )
        # Temporal downscale of the VAE (video frames -> latent frames).
        self.vae_scale_factor_temporal = (
            self.vae.config.temporal_compression_ratio if getattr(self, "vae", None) else 4
        )
        # Scalar used to (de)normalize latents when encoding/decoding through the VAE.
        self.vae_scaling_factor_image = self.vae.config.scaling_factor if getattr(self, "vae", None) else 0.7

        self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)

    def _get_t5_prompt_embeds(
        self,
        prompt: Union[str, List[str]] = None,
        num_videos_per_prompt: int = 1,
        max_sequence_length: int = 226,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        # Tokenize `prompt` and run it through the T5 encoder, returning embeddings of
        # shape [batch_size * num_videos_per_prompt, seq_len, hidden_dim].
        device = device or self._execution_device
        dtype = dtype or self.text_encoder.dtype

        prompt = [prompt] if isinstance(prompt, str) else prompt
        batch_size = len(prompt)

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

        # Warn the user when the prompt was longer than `max_sequence_length` and got truncated.
        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because `max_sequence_length` is set to "
                f" {max_sequence_length} tokens: {removed_text}"
            )

        prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

        # duplicate text embeddings for each generation per prompt, using mps friendly method
        _, seq_len, _ = prompt_embeds.shape
        prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)

        return prompt_embeds

    def encode_prompt(
        self,
        prompt: Union[str, List[str]],
        negative_prompt: Optional[Union[str, List[str]]] = None,
        do_classifier_free_guidance: bool = True,
        num_videos_per_prompt: int = 1,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        max_sequence_length: int = 226,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
                Whether to use classifier free guidance or not.
            num_videos_per_prompt (`int`, *optional*, defaults to 1):
                Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            device: (`torch.device`, *optional*):
                torch device
            dtype: (`torch.dtype`, *optional*):
                torch dtype
        """
        device = device or self._execution_device

        prompt = [prompt] if isinstance(prompt, str) else prompt
        if prompt is not None:
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            prompt_embeds = self._get_t5_prompt_embeds(
                prompt=prompt,
                num_videos_per_prompt=num_videos_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
                dtype=dtype,
            )

        # Only compute negative embeddings when CFG is enabled and none were supplied.
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            negative_prompt = negative_prompt or ""
            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt

            if prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )

            negative_prompt_embeds = self._get_t5_prompt_embeds(
                prompt=negative_prompt,
                num_videos_per_prompt=num_videos_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
                dtype=dtype,
            )

        return prompt_embeds, negative_prompt_embeds

    def prepare_latents(
        self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
    ):
        # Sample (or adopt) the initial latent tensor of shape
        # [B, latent_frames, C, H / spatial_scale, W / spatial_scale] and scale it
        # by the scheduler's initial noise sigma.
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        shape = (
            batch_size,
            (num_frames - 1) // self.vae_scale_factor_temporal + 1,
            num_channels_latents,
            height // self.vae_scale_factor_spatial,
            width // self.vae_scale_factor_spatial,
        )

        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            latents = latents.to(device)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma
        return latents

    def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
        # Undo the VAE scaling and decode latents back to pixel-space frames.
        latents = latents.permute(0, 2, 1, 3, 4)  # [batch_size, num_channels, num_frames, height, width]
        latents = 1 / self.vae_scaling_factor_image * latents

        frames = self.vae.decode(latents).sample
        return frames

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
    def prepare_extra_step_kwargs(self, generator, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
        # and should be between [0, 1]

        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

    # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
    def check_inputs(
        self,
        prompt,
        height,
        width,
        negative_prompt,
        callback_on_step_end_tensor_inputs,
        prompt_embeds=None,
        negative_prompt_embeds=None,
    ):
        # Validate mutually-exclusive and malformed call arguments before any heavy work.
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        if callback_on_step_end_tensor_inputs is not None and not all(
            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
        ):
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
            )
        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if prompt_embeds.shape != negative_prompt_embeds.shape:
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}."
                )

    def fuse_qkv_projections(self) -> None:
        r"""Enables fused QKV projections."""
        self.fusing_transformer = True
        self.transformer.fuse_qkv_projections()

    def unfuse_qkv_projections(self) -> None:
        r"""Disable QKV projection fusion if enabled."""
        # NOTE(review): `fusing_transformer` is only set in `fuse_qkv_projections`, so calling
        # this before ever fusing would raise AttributeError — confirm intended contract.
        if not self.fusing_transformer:
            logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.")
        else:
            self.transformer.unfuse_qkv_projections()
            self.fusing_transformer = False

    def _prepare_rotary_positional_embeddings(
        self,
        height: int,
        width: int,
        num_frames: int,
        device: torch.device,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Build the 3D RoPE cos/sin tables for the transformer. CogVideoX 1.0 models
        # (no temporal patching, `patch_size_t is None`) use a resize-crop grid; 1.5
        # models use sliced embeddings over a maximum grid size.
        grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
        grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)

        p = self.transformer.config.patch_size
        p_t = self.transformer.config.patch_size_t

        base_size_width = self.transformer.config.sample_width // p
        base_size_height = self.transformer.config.sample_height // p

        if p_t is None:
            # CogVideoX 1.0
            grid_crops_coords = get_resize_crop_region_for_grid(
                (grid_height, grid_width), base_size_width, base_size_height
            )
            freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
                embed_dim=self.transformer.config.attention_head_dim,
                crops_coords=grid_crops_coords,
                grid_size=(grid_height, grid_width),
                temporal_size=num_frames,
                device=device,
            )
        else:
            # CogVideoX 1.5
            base_num_frames = (num_frames + p_t - 1) // p_t  # ceil-divide frames by temporal patch size

            freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
                embed_dim=self.transformer.config.attention_head_dim,
                crops_coords=None,
                grid_size=(grid_height, grid_width),
                temporal_size=base_num_frames,
                grid_type="slice",
                max_size=(base_size_height, base_size_width),
                device=device,
            )

        return freqs_cos, freqs_sin

    @property
    def guidance_scale(self):
        # Current CFG scale; may be updated per-step when `use_dynamic_cfg` is enabled.
        return self._guidance_scale

    @property
    def num_timesteps(self):
        # Number of denoising timesteps for the active call.
        return self._num_timesteps

    @property
    def attention_kwargs(self):
        # Extra kwargs forwarded to the transformer's attention processors.
        return self._attention_kwargs

    @property
    def current_timestep(self):
        # Timestep currently being denoised; `None` outside the denoising loop.
        return self._current_timestep

    @property
    def interrupt(self):
        # When set (e.g. from a callback), remaining denoising steps are skipped.
        return self._interrupt

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Optional[Union[str, List[str]]] = None,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_frames: Optional[int] = None,
        num_inference_steps: int = 50,
        timesteps: Optional[List[int]] = None,
        guidance_scale: float = 6,
        use_dynamic_cfg: bool = False,
        num_videos_per_prompt: int = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: str = "pil",
        return_dict: bool = True,
        attention_kwargs: Optional[Dict[str, Any]] = None,
        callback_on_step_end: Optional[
            Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
        ] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        max_sequence_length: int = 226,
    ) -> Union[CogVideoXPipelineOutput, Tuple]:
        """
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
                instead.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            height (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial):
                The height in pixels of the generated image. This is set to 480 by default for the best results.
            width (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial):
                The width in pixels of the generated image. This is set to 720 by default for the best results.
            num_frames (`int`, defaults to `48`):
                Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will
                contain 1 extra frame because CogVideoX is conditioned with (num_seconds * fps + 1) frames where
                num_seconds is 6 and fps is 8. However, since videos can be saved at any fps, the only condition that
                needs to be satisfied is that of divisibility mentioned above.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            timesteps (`List[int]`, *optional*):
                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
                passed will be used. Must be in descending order.
            guidance_scale (`float`, *optional*, defaults to 7.0):
                Guidance scale as defined in [Classifier-Free Diffusion
                Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
                of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
                `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
                the text `prompt`, usually at the expense of lower image quality.
            num_videos_per_prompt (`int`, *optional*, defaults to 1):
                The number of videos to generate per prompt.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generate image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
                of a plain tuple.
            attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            callback_on_step_end (`Callable`, *optional*):
                A function that calls at the end of each denoising steps during the inference. The function is called
                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
                `callback_on_step_end_tensor_inputs`.
            callback_on_step_end_tensor_inputs (`List`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                `._callback_tensor_inputs` attribute of your pipeline class.
            max_sequence_length (`int`, defaults to `226`):
                Maximum sequence length in encoded prompt. Must be consistent with
                `self.transformer.config.max_text_seq_length` otherwise may lead to poor results.

        Examples:

        Returns:
            [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] or `tuple`:
            [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] if `return_dict` is True, otherwise a
            `tuple`. When returning a tuple, the first element is a list with the generated images.
        """

        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

        # Default resolution/frame count come from the transformer's training configuration.
        height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial
        width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial
        num_frames = num_frames or self.transformer.config.sample_frames

        # NOTE(review): the user-supplied `num_videos_per_prompt` is overridden to 1 here —
        # presumably only a single video per prompt is supported; confirm before relying on it.
        num_videos_per_prompt = 1

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            height,
            width,
            negative_prompt,
            callback_on_step_end_tensor_inputs,
            prompt_embeds,
            negative_prompt_embeds,
        )
        self._guidance_scale = guidance_scale
        self._attention_kwargs = attention_kwargs
        self._current_timestep = None
        self._interrupt = False

        # 2. Default call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
            prompt,
            negative_prompt,
            do_classifier_free_guidance,
            num_videos_per_prompt=num_videos_per_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            max_sequence_length=max_sequence_length,
            device=device,
        )
        if do_classifier_free_guidance:
            # Batch the unconditional and conditional embeddings so both branches run in one
            # transformer forward pass (the latents are duplicated to match below).
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)

        # 4. Prepare timesteps
        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
        self._num_timesteps = len(timesteps)

        # 5. Prepare latents
        latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1

        # For CogVideoX 1.5, the latent frames should be padded to make it divisible by patch_size_t
        patch_size_t = self.transformer.config.patch_size_t
        additional_frames = 0
        if patch_size_t is not None and latent_frames % patch_size_t != 0:
            additional_frames = patch_size_t - latent_frames % patch_size_t
            num_frames += additional_frames * self.vae_scale_factor_temporal

        latent_channels = self.transformer.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_videos_per_prompt,
            latent_channels,
            num_frames,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7. Create rotary embeds if required
        image_rotary_emb = (
            self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
            if self.transformer.config.use_rotary_positional_embeddings
            else None
        )

        # 8. Denoising loop
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            # for DPM-solver++
            old_pred_original_sample = None
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                self._current_timestep = t
                # Duplicate latents when CFG is on, matching the batched prompt embeddings above.
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                timestep = t.expand(latent_model_input.shape[0])

                # predict noise model_output
                with self.transformer.cache_context("cond_uncond"):
                    noise_pred = self.transformer(
                        hidden_states=latent_model_input,
                        encoder_hidden_states=prompt_embeds,
                        timestep=timestep,
                        image_rotary_emb=image_rotary_emb,
                        attention_kwargs=attention_kwargs,
                        return_dict=False,
                    )[0]
                    noise_pred = noise_pred.float()

                # perform guidance
                if use_dynamic_cfg:
                    # Cosine ramp of the guidance scale over the schedule (dynamic CFG).
                    self._guidance_scale = 1 + guidance_scale * (
                        (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
                    )
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                if not isinstance(self.scheduler, CogVideoXDPMScheduler):
                    latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
                else:
                    # CogVideoX DPM scheduler threads the previous x0 prediction and the
                    # previous timestep through successive steps.
                    latents, old_pred_original_sample = self.scheduler.step(
                        noise_pred,
                        old_pred_original_sample,
                        t,
                        timesteps[i - 1] if i > 0 else None,
                        latents,
                        **extra_step_kwargs,
                        return_dict=False,
                    )
                latents = latents.to(prompt_embeds.dtype)

                # call the callback, if provided
                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    # The callback may replace any of the registered tensors.
                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

                if XLA_AVAILABLE:
                    xm.mark_step()

        self._current_timestep = None

        if not output_type == "latent":
            # Discard any padding frames that were added for CogVideoX 1.5
            latents = latents[:, additional_frames:]
            video = self.decode_latents(latents)
            video = self.video_processor.postprocess_video(video=video, output_type=output_type)
        else:
            video = latents

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (video,)

        return CogVideoXPipelineOutput(frames=video)
pythonProject/diffusers-main/build/lib/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py ADDED
@@ -0,0 +1,842 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The CogVideoX team, Tsinghua University & ZhipuAI, Alibaba-PAI and The HuggingFace Team.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import inspect
17
+ import math
18
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
19
+
20
+ import torch
21
+ from PIL import Image
22
+ from transformers import T5EncoderModel, T5Tokenizer
23
+
24
+ from ...callbacks import MultiPipelineCallbacks, PipelineCallback
25
+ from ...loaders import CogVideoXLoraLoaderMixin
26
+ from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
27
+ from ...models.embeddings import get_3d_rotary_pos_embed
28
+ from ...pipelines.pipeline_utils import DiffusionPipeline
29
+ from ...schedulers import KarrasDiffusionSchedulers
30
+ from ...utils import is_torch_xla_available, logging, replace_example_docstring
31
+ from ...utils.torch_utils import randn_tensor
32
+ from ...video_processor import VideoProcessor
33
+ from .pipeline_output import CogVideoXPipelineOutput
34
+
35
+
36
+ if is_torch_xla_available():
37
+ import torch_xla.core.xla_model as xm
38
+
39
+ XLA_AVAILABLE = True
40
+ else:
41
+ XLA_AVAILABLE = False
42
+
43
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
44
+
45
+
46
+ EXAMPLE_DOC_STRING = """
47
+ Examples:
48
+ ```python
49
+ >>> import torch
50
+ >>> from diffusers import CogVideoXFunControlPipeline, DDIMScheduler
51
+ >>> from diffusers.utils import export_to_video, load_video
52
+
53
+ >>> pipe = CogVideoXFunControlPipeline.from_pretrained(
54
+ ... "alibaba-pai/CogVideoX-Fun-V1.1-5b-Pose", torch_dtype=torch.bfloat16
55
+ ... )
56
+ >>> pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
57
+ >>> pipe.to("cuda")
58
+
59
+ >>> control_video = load_video(
60
+ ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/hiker.mp4"
61
+ ... )
62
+ >>> prompt = (
63
+ ... "An astronaut stands triumphantly at the peak of a towering mountain. Panorama of rugged peaks and "
64
+ ... "valleys. Very futuristic vibe and animated aesthetic. Highlights of purple and golden colors in "
65
+ ... "the scene. The sky is looks like an animated/cartoonish dream of galaxies, nebulae, stars, planets, "
66
+ ... "moons, but the remainder of the scene is mostly realistic."
67
+ ... )
68
+
69
+ >>> video = pipe(prompt=prompt, control_video=control_video).frames[0]
70
+ >>> export_to_video(video, "output.mp4", fps=8)
71
+ ```
72
+ """
73
+
74
+
75
def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
    """Compute the centered crop region of a `tgt_height` x `tgt_width` grid that an
    aspect-preserving resize of `src` (h, w) would occupy.

    Returns:
        `((top, left), (bottom, right))` corner coordinates of the crop region.
    """
    src_h, src_w = src

    # Decide which target dimension the aspect-preserving resize should match.
    if src_h / src_w > tgt_height / tgt_width:
        # Source is taller than the target grid: match the target height.
        resized_h = tgt_height
        resized_w = int(round(tgt_height / src_h * src_w))
    else:
        # Source is wider than (or equal to) the target grid: match the target width.
        resized_w = tgt_width
        resized_h = int(round(tgt_width / src_w * src_h))

    # Center the resized region inside the target grid.
    top = int(round((tgt_height - resized_h) / 2.0))
    left = int(round((tgt_width - resized_w) / 2.0))

    return (top, left), (top + resized_h, left + resized_w)
92
+
93
+
94
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    r"""
    Configure the scheduler's timestep schedule and return it.

    The schedule is driven by exactly one of `num_inference_steps`, `timesteps` or
    `sigmas`. Custom `timesteps`/`sigmas` override the scheduler's own spacing strategy
    and are only valid for schedulers whose `set_timesteps` accepts the corresponding
    keyword; `num_inference_steps` uses the scheduler's built-in strategy. Any extra
    `kwargs` are forwarded to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`): the scheduler to configure.
        num_inference_steps (`int`, *optional*): number of denoising steps.
        device (`str` or `torch.device`, *optional*): device for the timesteps.
        timesteps (`List[int]`, *optional*): custom timestep schedule.
        sigmas (`List[float]`, *optional*): custom sigma schedule.

    Returns:
        `Tuple[torch.Tensor, int]`: the timestep schedule and the number of inference
        steps.

    Raises:
        ValueError: if both `timesteps` and `sigmas` are given, or the scheduler does
            not support the requested custom schedule.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    if timesteps is None and sigmas is None:
        # Default path: let the scheduler derive the schedule from the step count.
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    if timesteps is not None:
        if "timesteps" not in inspect.signature(scheduler.set_timesteps).parameters:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    else:
        if "sigmas" not in inspect.signature(scheduler.set_timesteps).parameters:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)

    schedule = scheduler.timesteps
    return schedule, len(schedule)
152
+
153
+
154
+ class CogVideoXFunControlPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
155
+ r"""
156
+ Pipeline for controlled text-to-video generation using CogVideoX Fun.
157
+
158
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
159
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
160
+
161
+ Args:
162
+ vae ([`AutoencoderKL`]):
163
+ Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
164
+ text_encoder ([`T5EncoderModel`]):
165
+ Frozen text-encoder. CogVideoX uses
166
+ [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the
167
+ [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
168
+ tokenizer (`T5Tokenizer`):
169
+ Tokenizer of class
170
+ [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
171
+ transformer ([`CogVideoXTransformer3DModel`]):
172
+ A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents.
173
+ scheduler ([`SchedulerMixin`]):
174
+ A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
175
+ """
176
+
177
+ _optional_components = []
178
+ model_cpu_offload_seq = "text_encoder->vae->transformer->vae"
179
+
180
+ _callback_tensor_inputs = [
181
+ "latents",
182
+ "prompt_embeds",
183
+ "negative_prompt_embeds",
184
+ ]
185
+
186
+ def __init__(
187
+ self,
188
+ tokenizer: T5Tokenizer,
189
+ text_encoder: T5EncoderModel,
190
+ vae: AutoencoderKLCogVideoX,
191
+ transformer: CogVideoXTransformer3DModel,
192
+ scheduler: KarrasDiffusionSchedulers,
193
+ ):
194
+ super().__init__()
195
+
196
+ self.register_modules(
197
+ tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
198
+ )
199
+ self.vae_scale_factor_spatial = (
200
+ 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
201
+ )
202
+ self.vae_scale_factor_temporal = (
203
+ self.vae.config.temporal_compression_ratio if getattr(self, "vae", None) else 4
204
+ )
205
+ self.vae_scaling_factor_image = self.vae.config.scaling_factor if getattr(self, "vae", None) else 0.7
206
+
207
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
208
+
209
+ # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._get_t5_prompt_embeds
210
+ def _get_t5_prompt_embeds(
211
+ self,
212
+ prompt: Union[str, List[str]] = None,
213
+ num_videos_per_prompt: int = 1,
214
+ max_sequence_length: int = 226,
215
+ device: Optional[torch.device] = None,
216
+ dtype: Optional[torch.dtype] = None,
217
+ ):
218
+ device = device or self._execution_device
219
+ dtype = dtype or self.text_encoder.dtype
220
+
221
+ prompt = [prompt] if isinstance(prompt, str) else prompt
222
+ batch_size = len(prompt)
223
+
224
+ text_inputs = self.tokenizer(
225
+ prompt,
226
+ padding="max_length",
227
+ max_length=max_sequence_length,
228
+ truncation=True,
229
+ add_special_tokens=True,
230
+ return_tensors="pt",
231
+ )
232
+ text_input_ids = text_inputs.input_ids
233
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
234
+
235
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
236
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
237
+ logger.warning(
238
+ "The following part of your input was truncated because `max_sequence_length` is set to "
239
+ f" {max_sequence_length} tokens: {removed_text}"
240
+ )
241
+
242
+ prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
243
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
244
+
245
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
246
+ _, seq_len, _ = prompt_embeds.shape
247
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
248
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
249
+
250
+ return prompt_embeds
251
+
252
+ # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.encode_prompt
253
+ def encode_prompt(
254
+ self,
255
+ prompt: Union[str, List[str]],
256
+ negative_prompt: Optional[Union[str, List[str]]] = None,
257
+ do_classifier_free_guidance: bool = True,
258
+ num_videos_per_prompt: int = 1,
259
+ prompt_embeds: Optional[torch.Tensor] = None,
260
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
261
+ max_sequence_length: int = 226,
262
+ device: Optional[torch.device] = None,
263
+ dtype: Optional[torch.dtype] = None,
264
+ ):
265
+ r"""
266
+ Encodes the prompt into text encoder hidden states.
267
+
268
+ Args:
269
+ prompt (`str` or `List[str]`, *optional*):
270
+ prompt to be encoded
271
+ negative_prompt (`str` or `List[str]`, *optional*):
272
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
273
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
274
+ less than `1`).
275
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
276
+ Whether to use classifier free guidance or not.
277
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
278
+ Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
279
+ prompt_embeds (`torch.Tensor`, *optional*):
280
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
281
+ provided, text embeddings will be generated from `prompt` input argument.
282
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
283
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
284
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
285
+ argument.
286
+ device: (`torch.device`, *optional*):
287
+ torch device
288
+ dtype: (`torch.dtype`, *optional*):
289
+ torch dtype
290
+ """
291
+ device = device or self._execution_device
292
+
293
+ prompt = [prompt] if isinstance(prompt, str) else prompt
294
+ if prompt is not None:
295
+ batch_size = len(prompt)
296
+ else:
297
+ batch_size = prompt_embeds.shape[0]
298
+
299
+ if prompt_embeds is None:
300
+ prompt_embeds = self._get_t5_prompt_embeds(
301
+ prompt=prompt,
302
+ num_videos_per_prompt=num_videos_per_prompt,
303
+ max_sequence_length=max_sequence_length,
304
+ device=device,
305
+ dtype=dtype,
306
+ )
307
+
308
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
309
+ negative_prompt = negative_prompt or ""
310
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
311
+
312
+ if prompt is not None and type(prompt) is not type(negative_prompt):
313
+ raise TypeError(
314
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
315
+ f" {type(prompt)}."
316
+ )
317
+ elif batch_size != len(negative_prompt):
318
+ raise ValueError(
319
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
320
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
321
+ " the batch size of `prompt`."
322
+ )
323
+
324
+ negative_prompt_embeds = self._get_t5_prompt_embeds(
325
+ prompt=negative_prompt,
326
+ num_videos_per_prompt=num_videos_per_prompt,
327
+ max_sequence_length=max_sequence_length,
328
+ device=device,
329
+ dtype=dtype,
330
+ )
331
+
332
+ return prompt_embeds, negative_prompt_embeds
333
+
334
+ # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.prepare_latents
335
+ def prepare_latents(
336
+ self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
337
+ ):
338
+ if isinstance(generator, list) and len(generator) != batch_size:
339
+ raise ValueError(
340
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
341
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
342
+ )
343
+
344
+ shape = (
345
+ batch_size,
346
+ (num_frames - 1) // self.vae_scale_factor_temporal + 1,
347
+ num_channels_latents,
348
+ height // self.vae_scale_factor_spatial,
349
+ width // self.vae_scale_factor_spatial,
350
+ )
351
+
352
+ if latents is None:
353
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
354
+ else:
355
+ latents = latents.to(device)
356
+
357
+ # scale the initial noise by the standard deviation required by the scheduler
358
+ latents = latents * self.scheduler.init_noise_sigma
359
+ return latents
360
+
361
+ # Adapted from https://github.com/aigc-apps/CogVideoX-Fun/blob/2a93e5c14e02b2b5921d533fd59fc8c0ed69fb24/cogvideox/pipeline/pipeline_cogvideox_control.py#L366
362
+ def prepare_control_latents(
363
+ self, mask: Optional[torch.Tensor] = None, masked_image: Optional[torch.Tensor] = None
364
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
365
+ if mask is not None:
366
+ masks = []
367
+ for i in range(mask.size(0)):
368
+ current_mask = mask[i].unsqueeze(0)
369
+ current_mask = self.vae.encode(current_mask)[0]
370
+ current_mask = current_mask.mode()
371
+ masks.append(current_mask)
372
+ mask = torch.cat(masks, dim=0)
373
+ mask = mask * self.vae.config.scaling_factor
374
+
375
+ if masked_image is not None:
376
+ mask_pixel_values = []
377
+ for i in range(masked_image.size(0)):
378
+ mask_pixel_value = masked_image[i].unsqueeze(0)
379
+ mask_pixel_value = self.vae.encode(mask_pixel_value)[0]
380
+ mask_pixel_value = mask_pixel_value.mode()
381
+ mask_pixel_values.append(mask_pixel_value)
382
+ masked_image_latents = torch.cat(mask_pixel_values, dim=0)
383
+ masked_image_latents = masked_image_latents * self.vae.config.scaling_factor
384
+ else:
385
+ masked_image_latents = None
386
+
387
+ return mask, masked_image_latents
388
+
389
+ # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.decode_latents
390
+ def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
391
+ latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width]
392
+ latents = 1 / self.vae_scaling_factor_image * latents
393
+
394
+ frames = self.vae.decode(latents).sample
395
+ return frames
396
+
397
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
398
+ def prepare_extra_step_kwargs(self, generator, eta):
399
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
400
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
401
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
402
+ # and should be between [0, 1]
403
+
404
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
405
+ extra_step_kwargs = {}
406
+ if accepts_eta:
407
+ extra_step_kwargs["eta"] = eta
408
+
409
+ # check if the scheduler accepts generator
410
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
411
+ if accepts_generator:
412
+ extra_step_kwargs["generator"] = generator
413
+ return extra_step_kwargs
414
+
415
    def check_inputs(
        self,
        prompt,
        height,
        width,
        negative_prompt,
        callback_on_step_end_tensor_inputs,
        prompt_embeds=None,
        negative_prompt_embeds=None,
        control_video=None,
        control_video_latents=None,
    ):
        """Validate `__call__` arguments, raising on any inconsistency.

        Checks, in order: spatial dims divisible by 8; callback tensor names are among
        `_callback_tensor_inputs`; exactly one of `prompt`/`prompt_embeds` is given and
        `prompt` has a valid type; prompts and pre-computed embeddings are not mixed;
        embed shapes match; and `control_video`/`control_video_latents` are mutually
        exclusive.

        Raises:
            ValueError: for any failed check below.
        """
        # Latent spatial dims are height/8 and width/8, so both must divide evenly.
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        if callback_on_step_end_tensor_inputs is not None and not all(
            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
        ):
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
            )
        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        # Both embeds are concatenated for classifier-free guidance, so shapes must agree.
        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if prompt_embeds.shape != negative_prompt_embeds.shape:
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}."
                )

        if control_video is not None and control_video_latents is not None:
            raise ValueError(
                "Cannot pass both `control_video` and `control_video_latents`. Please make sure to pass only one of these parameters."
            )
472
+
473
+ def fuse_qkv_projections(self) -> None:
474
+ r"""Enables fused QKV projections."""
475
+ self.fusing_transformer = True
476
+ self.transformer.fuse_qkv_projections()
477
+
478
+ def unfuse_qkv_projections(self) -> None:
479
+ r"""Disable QKV projection fusion if enabled."""
480
+ if not self.fusing_transformer:
481
+ logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.")
482
+ else:
483
+ self.transformer.unfuse_qkv_projections()
484
+ self.fusing_transformer = False
485
+
486
    # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._prepare_rotary_positional_embeddings
    def _prepare_rotary_positional_embeddings(
        self,
        height: int,
        width: int,
        num_frames: int,
        device: torch.device,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Build the 3D RoPE cos/sin frequency tables for a video of the given size.

        `height`/`width` are in pixels and are reduced to the transformer's patch grid;
        `num_frames` is the number of latent frames. Returns `(freqs_cos, freqs_sin)`.
        """
        # Patch-grid dimensions of the latent video.
        grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
        grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)

        p = self.transformer.config.patch_size
        # `patch_size_t` is only set for CogVideoX 1.5 (temporal patching).
        p_t = self.transformer.config.patch_size_t

        # Maximum grid the model was trained on, in patches.
        base_size_width = self.transformer.config.sample_width // p
        base_size_height = self.transformer.config.sample_height // p

        if p_t is None:
            # CogVideoX 1.0: no temporal patching; center-crop the spatial grid to the
            # trained base size before building the embeddings.
            grid_crops_coords = get_resize_crop_region_for_grid(
                (grid_height, grid_width), base_size_width, base_size_height
            )
            freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
                embed_dim=self.transformer.config.attention_head_dim,
                crops_coords=grid_crops_coords,
                grid_size=(grid_height, grid_width),
                temporal_size=num_frames,
                device=device,
            )
        else:
            # CogVideoX 1.5: time is patched too; round frames up to whole temporal patches.
            base_num_frames = (num_frames + p_t - 1) // p_t

            freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
                embed_dim=self.transformer.config.attention_head_dim,
                crops_coords=None,
                grid_size=(grid_height, grid_width),
                temporal_size=base_num_frames,
                grid_type="slice",
                max_size=(base_size_height, base_size_width),
                device=device,
            )

        return freqs_cos, freqs_sin
530
+
531
    @property
    def guidance_scale(self):
        """Classifier-free guidance weight `w` set at the start of `__call__`."""
        return self._guidance_scale

    @property
    def num_timesteps(self):
        """Number of denoising steps of the current/most recent run."""
        return self._num_timesteps

    @property
    def attention_kwargs(self):
        """Extra kwargs forwarded to the attention processors during denoising."""
        return self._attention_kwargs

    @property
    def current_timestep(self):
        """Timestep currently being denoised (`None` outside the denoising loop)."""
        return self._current_timestep

    @property
    def interrupt(self):
        """Flag that lets callbacks abort the denoising loop early."""
        return self._interrupt
550
+
551
+ @torch.no_grad()
552
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
553
+ def __call__(
554
+ self,
555
+ prompt: Optional[Union[str, List[str]]] = None,
556
+ negative_prompt: Optional[Union[str, List[str]]] = None,
557
+ control_video: Optional[List[Image.Image]] = None,
558
+ height: Optional[int] = None,
559
+ width: Optional[int] = None,
560
+ num_inference_steps: int = 50,
561
+ timesteps: Optional[List[int]] = None,
562
+ guidance_scale: float = 6,
563
+ use_dynamic_cfg: bool = False,
564
+ num_videos_per_prompt: int = 1,
565
+ eta: float = 0.0,
566
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
567
+ latents: Optional[torch.Tensor] = None,
568
+ control_video_latents: Optional[torch.Tensor] = None,
569
+ prompt_embeds: Optional[torch.Tensor] = None,
570
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
571
+ output_type: str = "pil",
572
+ return_dict: bool = True,
573
+ attention_kwargs: Optional[Dict[str, Any]] = None,
574
+ callback_on_step_end: Optional[
575
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
576
+ ] = None,
577
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
578
+ max_sequence_length: int = 226,
579
+ ) -> Union[CogVideoXPipelineOutput, Tuple]:
580
+ """
581
+ Function invoked when calling the pipeline for generation.
582
+
583
+ Args:
584
+ prompt (`str` or `List[str]`, *optional*):
585
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
586
+ instead.
587
+ negative_prompt (`str` or `List[str]`, *optional*):
588
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
589
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
590
+ less than `1`).
591
+ control_video (`List[PIL.Image.Image]`):
592
+ The control video to condition the generation on. Must be a list of images/frames of the video. If not
593
+ provided, `control_video_latents` must be provided.
594
+ height (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial):
595
+ The height in pixels of the generated image. This is set to 480 by default for the best results.
596
+ width (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial):
597
+ The width in pixels of the generated image. This is set to 720 by default for the best results.
598
+ num_inference_steps (`int`, *optional*, defaults to 50):
599
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
600
+ expense of slower inference.
601
+ timesteps (`List[int]`, *optional*):
602
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
603
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
604
+ passed will be used. Must be in descending order.
605
+ guidance_scale (`float`, *optional*, defaults to 6.0):
606
+ Guidance scale as defined in [Classifier-Free Diffusion
607
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
608
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
609
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
610
+ the text `prompt`, usually at the expense of lower image quality.
611
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
612
+ The number of videos to generate per prompt.
613
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
614
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
615
+ to make generation deterministic.
616
+ latents (`torch.Tensor`, *optional*):
617
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video
618
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
619
+ tensor will be generated by sampling using the supplied random `generator`.
620
+ control_video_latents (`torch.Tensor`, *optional*):
621
+ Pre-generated control latents, sampled from a Gaussian distribution, to be used as inputs for
622
+ controlled video generation. If not provided, `control_video` must be provided.
623
+ prompt_embeds (`torch.Tensor`, *optional*):
624
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
625
+ provided, text embeddings will be generated from `prompt` input argument.
626
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
627
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
628
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
629
+ argument.
630
+ output_type (`str`, *optional*, defaults to `"pil"`):
631
+ The output format of the generate image. Choose between
632
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
633
+ return_dict (`bool`, *optional*, defaults to `True`):
634
+ Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
635
+ of a plain tuple.
636
+ attention_kwargs (`dict`, *optional*):
637
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
638
+ `self.processor` in
639
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
640
+ callback_on_step_end (`Callable`, *optional*):
641
+ A function that calls at the end of each denoising steps during the inference. The function is called
642
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
643
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
644
+ `callback_on_step_end_tensor_inputs`.
645
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
646
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
647
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
648
+ `._callback_tensor_inputs` attribute of your pipeline class.
649
+ max_sequence_length (`int`, defaults to `226`):
650
+ Maximum sequence length in encoded prompt. Must be consistent with
651
+ `self.transformer.config.max_text_seq_length` otherwise may lead to poor results.
652
+
653
+ Examples:
654
+
655
+ Returns:
656
+ [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] or `tuple`:
657
+ [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] if `return_dict` is True, otherwise a
658
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
659
+ """
660
+
661
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
662
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
663
+
664
+ if control_video is not None and isinstance(control_video[0], Image.Image):
665
+ control_video = [control_video]
666
+
667
+ height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial
668
+ width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial
669
+ num_frames = len(control_video[0]) if control_video is not None else control_video_latents.size(2)
670
+
671
+ num_videos_per_prompt = 1
672
+
673
+ # 1. Check inputs. Raise error if not correct
674
+ self.check_inputs(
675
+ prompt,
676
+ height,
677
+ width,
678
+ negative_prompt,
679
+ callback_on_step_end_tensor_inputs,
680
+ prompt_embeds,
681
+ negative_prompt_embeds,
682
+ control_video,
683
+ control_video_latents,
684
+ )
685
+ self._guidance_scale = guidance_scale
686
+ self._attention_kwargs = attention_kwargs
687
+ self._current_timestep = None
688
+ self._interrupt = False
689
+
690
+ # 2. Default call parameters
691
+ if prompt is not None and isinstance(prompt, str):
692
+ batch_size = 1
693
+ elif prompt is not None and isinstance(prompt, list):
694
+ batch_size = len(prompt)
695
+ else:
696
+ batch_size = prompt_embeds.shape[0]
697
+
698
+ device = self._execution_device
699
+
700
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
701
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
702
+ # corresponds to doing no classifier free guidance.
703
+ do_classifier_free_guidance = guidance_scale > 1.0
704
+
705
+ # 3. Encode input prompt
706
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
707
+ prompt,
708
+ negative_prompt,
709
+ do_classifier_free_guidance,
710
+ num_videos_per_prompt=num_videos_per_prompt,
711
+ prompt_embeds=prompt_embeds,
712
+ negative_prompt_embeds=negative_prompt_embeds,
713
+ max_sequence_length=max_sequence_length,
714
+ device=device,
715
+ )
716
+ if do_classifier_free_guidance:
717
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
718
+
719
+ # 4. Prepare timesteps
720
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
721
+ self._num_timesteps = len(timesteps)
722
+
723
+ # 5. Prepare latents
724
+ latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
725
+
726
+ # For CogVideoX 1.5, the latent frames should be padded to make it divisible by patch_size_t
727
+ patch_size_t = self.transformer.config.patch_size_t
728
+ if patch_size_t is not None and latent_frames % patch_size_t != 0:
729
+ raise ValueError(
730
+ f"The number of latent frames must be divisible by `{patch_size_t=}` but the given video "
731
+ f"contains {latent_frames=}, which is not divisible."
732
+ )
733
+
734
+ latent_channels = self.transformer.config.in_channels // 2
735
+ latents = self.prepare_latents(
736
+ batch_size * num_videos_per_prompt,
737
+ latent_channels,
738
+ num_frames,
739
+ height,
740
+ width,
741
+ prompt_embeds.dtype,
742
+ device,
743
+ generator,
744
+ latents,
745
+ )
746
+
747
+ if control_video_latents is None:
748
+ control_video = self.video_processor.preprocess_video(control_video, height=height, width=width)
749
+ control_video = control_video.to(device=device, dtype=prompt_embeds.dtype)
750
+
751
+ _, control_video_latents = self.prepare_control_latents(None, control_video)
752
+ control_video_latents = control_video_latents.permute(0, 2, 1, 3, 4)
753
+
754
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
755
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
756
+
757
+ # 7. Create rotary embeds if required
758
+ image_rotary_emb = (
759
+ self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
760
+ if self.transformer.config.use_rotary_positional_embeddings
761
+ else None
762
+ )
763
+
764
+ # 8. Denoising loop
765
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
766
+
767
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
768
+ # for DPM-solver++
769
+ old_pred_original_sample = None
770
+ for i, t in enumerate(timesteps):
771
+ if self.interrupt:
772
+ continue
773
+
774
+ self._current_timestep = t
775
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
776
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
777
+
778
+ latent_control_input = (
779
+ torch.cat([control_video_latents] * 2) if do_classifier_free_guidance else control_video_latents
780
+ )
781
+ latent_model_input = torch.cat([latent_model_input, latent_control_input], dim=2)
782
+
783
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
784
+ timestep = t.expand(latent_model_input.shape[0])
785
+
786
+ # predict noise model_output
787
+ with self.transformer.cache_context("cond_uncond"):
788
+ noise_pred = self.transformer(
789
+ hidden_states=latent_model_input,
790
+ encoder_hidden_states=prompt_embeds,
791
+ timestep=timestep,
792
+ image_rotary_emb=image_rotary_emb,
793
+ attention_kwargs=attention_kwargs,
794
+ return_dict=False,
795
+ )[0]
796
+ noise_pred = noise_pred.float()
797
+
798
+ # perform guidance
799
+ if use_dynamic_cfg:
800
+ self._guidance_scale = 1 + guidance_scale * (
801
+ (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
802
+ )
803
+ if do_classifier_free_guidance:
804
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
805
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
806
+
807
+ # compute the previous noisy sample x_t -> x_t-1
808
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
809
+ latents = latents.to(prompt_embeds.dtype)
810
+
811
+ # call the callback, if provided
812
+ if callback_on_step_end is not None:
813
+ callback_kwargs = {}
814
+ for k in callback_on_step_end_tensor_inputs:
815
+ callback_kwargs[k] = locals()[k]
816
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
817
+
818
+ latents = callback_outputs.pop("latents", latents)
819
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
820
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
821
+
822
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
823
+ progress_bar.update()
824
+
825
+ if XLA_AVAILABLE:
826
+ xm.mark_step()
827
+
828
+ self._current_timestep = None
829
+
830
+ if not output_type == "latent":
831
+ video = self.decode_latents(latents)
832
+ video = self.video_processor.postprocess_video(video=video, output_type=output_type)
833
+ else:
834
+ video = latents
835
+
836
+ # Offload all models
837
+ self.maybe_free_model_hooks()
838
+
839
+ if not return_dict:
840
+ return (video,)
841
+
842
+ return CogVideoXPipelineOutput(frames=video)
pythonProject/diffusers-main/build/lib/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py ADDED
@@ -0,0 +1,903 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import inspect
17
+ import math
18
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
19
+
20
+ import PIL
21
+ import torch
22
+ from transformers import T5EncoderModel, T5Tokenizer
23
+
24
+ from ...callbacks import MultiPipelineCallbacks, PipelineCallback
25
+ from ...image_processor import PipelineImageInput
26
+ from ...loaders import CogVideoXLoraLoaderMixin
27
+ from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
28
+ from ...models.embeddings import get_3d_rotary_pos_embed
29
+ from ...pipelines.pipeline_utils import DiffusionPipeline
30
+ from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
31
+ from ...utils import (
32
+ is_torch_xla_available,
33
+ logging,
34
+ replace_example_docstring,
35
+ )
36
+ from ...utils.torch_utils import randn_tensor
37
+ from ...video_processor import VideoProcessor
38
+ from .pipeline_output import CogVideoXPipelineOutput
39
+
40
+
41
# Optional torch_xla (TPU) support: when available, `xm.mark_step` is called
# inside the denoising loop; the flag is resolved once at import time.
if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False

# Module-level logger following the diffusers convention.
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


# Usage example injected into the pipeline's `__call__` docstring via
# `replace_example_docstring`.
EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> import torch
        >>> from diffusers import CogVideoXImageToVideoPipeline
        >>> from diffusers.utils import export_to_video, load_image

        >>> pipe = CogVideoXImageToVideoPipeline.from_pretrained("THUDM/CogVideoX-5b-I2V", torch_dtype=torch.bfloat16)
        >>> pipe.to("cuda")

        >>> prompt = "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
        >>> image = load_image(
        ...     "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
        ... )
        >>> video = pipe(image, prompt, use_dynamic_cfg=True)
        >>> export_to_video(video.frames[0], "output.mp4", fps=8)
        ```
"""
69
+
70
+
71
# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
    """Fit a grid of size ``src`` inside a ``tgt_width`` x ``tgt_height`` canvas.

    The source grid is scaled down, preserving its aspect ratio, so that it fits the
    target canvas, then centered. Returns the crop region as two corner coordinates:
    ``(top, left), (bottom, right)``.
    """
    src_height, src_width = src

    # Compare aspect ratios to find the limiting target dimension.
    if src_height / src_width > tgt_height / tgt_width:
        # Source is taller than the target: match heights, scale the width down.
        resize_height = tgt_height
        resize_width = int(round(tgt_height / src_height * src_width))
    else:
        # Source is wider (or the same ratio): match widths, scale the height down.
        resize_width = tgt_width
        resize_height = int(round(tgt_width / src_width * src_height))

    # Center the resized grid on the target canvas.
    crop_top = int(round((tgt_height - resize_height) / 2.0))
    crop_left = int(round((tgt_width - resize_width) / 2.0))

    return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
88
+
89
+
90
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    r"""
    Call the scheduler's `set_timesteps` method and retrieve the resulting schedule. Supports custom timestep or
    sigma schedules; any extra kwargs are forwarded to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`, *optional*):
            The number of diffusion steps for the scheduler's default spacing. Must be `None` when `timesteps` or
            `sigmas` is given.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timestep schedule overriding the scheduler's spacing strategy. Mutually exclusive with `sigmas`.
        sigmas (`List[float]`, *optional*):
            Custom sigma schedule overriding the scheduler's spacing strategy. Mutually exclusive with `timesteps`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and
        the second element is the number of inference steps.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    def _supports(keyword):
        # Whether this scheduler's `set_timesteps` accepts the given keyword argument.
        return keyword in set(inspect.signature(scheduler.set_timesteps).parameters.keys())

    if timesteps is not None:
        if not _supports("timesteps"):
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        num_inference_steps = len(scheduler.timesteps)
    elif sigmas is not None:
        if not _supports("sigmas"):
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        num_inference_steps = len(scheduler.timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)

    return scheduler.timesteps, num_inference_steps
148
+
149
+
150
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
    """Extract a latent tensor from a VAE encoder output.

    With ``sample_mode="sample"`` the posterior is sampled (optionally seeded by
    `generator`); with ``sample_mode="argmax"`` its mode is returned. Outputs that
    expose a plain ``.latents`` attribute are returned as-is.
    """
    _missing = object()
    latent_dist = getattr(encoder_output, "latent_dist", _missing)
    if latent_dist is not _missing:
        if sample_mode == "sample":
            return latent_dist.sample(generator)
        if sample_mode == "argmax":
            return latent_dist.mode()
    if hasattr(encoder_output, "latents"):
        return encoder_output.latents
    raise AttributeError("Could not access latents of provided encoder_output")
162
+
163
+
164
class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
    r"""
    Pipeline for image-to-video generation using CogVideoX.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
        text_encoder ([`T5EncoderModel`]):
            Frozen text-encoder. CogVideoX uses
            [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the
            [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
        tokenizer (`T5Tokenizer`):
            Tokenizer of class
            [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
        transformer ([`CogVideoXTransformer3DModel`]):
            A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
    """

    # No components are optional for this pipeline.
    _optional_components = []
    # Order in which submodels are moved to the accelerator when CPU offloading is enabled.
    model_cpu_offload_seq = "text_encoder->transformer->vae"

    # Tensors that a `callback_on_step_end` callback may read and override at each step.
    _callback_tensor_inputs = [
        "latents",
        "prompt_embeds",
        "negative_prompt_embeds",
    ]
195
+
196
+ def __init__(
197
+ self,
198
+ tokenizer: T5Tokenizer,
199
+ text_encoder: T5EncoderModel,
200
+ vae: AutoencoderKLCogVideoX,
201
+ transformer: CogVideoXTransformer3DModel,
202
+ scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
203
+ ):
204
+ super().__init__()
205
+
206
+ self.register_modules(
207
+ tokenizer=tokenizer,
208
+ text_encoder=text_encoder,
209
+ vae=vae,
210
+ transformer=transformer,
211
+ scheduler=scheduler,
212
+ )
213
+ self.vae_scale_factor_spatial = (
214
+ 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
215
+ )
216
+ self.vae_scale_factor_temporal = (
217
+ self.vae.config.temporal_compression_ratio if getattr(self, "vae", None) else 4
218
+ )
219
+ self.vae_scaling_factor_image = self.vae.config.scaling_factor if getattr(self, "vae", None) else 0.7
220
+
221
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
222
+
223
+ # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._get_t5_prompt_embeds
224
+ def _get_t5_prompt_embeds(
225
+ self,
226
+ prompt: Union[str, List[str]] = None,
227
+ num_videos_per_prompt: int = 1,
228
+ max_sequence_length: int = 226,
229
+ device: Optional[torch.device] = None,
230
+ dtype: Optional[torch.dtype] = None,
231
+ ):
232
+ device = device or self._execution_device
233
+ dtype = dtype or self.text_encoder.dtype
234
+
235
+ prompt = [prompt] if isinstance(prompt, str) else prompt
236
+ batch_size = len(prompt)
237
+
238
+ text_inputs = self.tokenizer(
239
+ prompt,
240
+ padding="max_length",
241
+ max_length=max_sequence_length,
242
+ truncation=True,
243
+ add_special_tokens=True,
244
+ return_tensors="pt",
245
+ )
246
+ text_input_ids = text_inputs.input_ids
247
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
248
+
249
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
250
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
251
+ logger.warning(
252
+ "The following part of your input was truncated because `max_sequence_length` is set to "
253
+ f" {max_sequence_length} tokens: {removed_text}"
254
+ )
255
+
256
+ prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
257
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
258
+
259
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
260
+ _, seq_len, _ = prompt_embeds.shape
261
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
262
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
263
+
264
+ return prompt_embeds
265
+
266
    # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.encode_prompt
    def encode_prompt(
        self,
        prompt: Union[str, List[str]],
        negative_prompt: Optional[Union[str, List[str]]] = None,
        do_classifier_free_guidance: bool = True,
        num_videos_per_prompt: int = 1,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        max_sequence_length: int = 226,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
                Whether to use classifier free guidance or not.
            num_videos_per_prompt (`int`, *optional*, defaults to 1):
                Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            device: (`torch.device`, *optional*):
                torch device
            dtype: (`torch.dtype`, *optional*):
                torch dtype

        Returns:
            `Tuple[torch.Tensor, torch.Tensor]`: the prompt embeddings and (possibly `None`) negative prompt
            embeddings, each of shape ``(batch_size * num_videos_per_prompt, seq_len, dim)``.
        """
        device = device or self._execution_device

        prompt = [prompt] if isinstance(prompt, str) else prompt
        if prompt is not None:
            batch_size = len(prompt)
        else:
            # No prompt text given: infer the batch size from the precomputed embeddings.
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            prompt_embeds = self._get_t5_prompt_embeds(
                prompt=prompt,
                num_videos_per_prompt=num_videos_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
                dtype=dtype,
            )

        # Negative embeddings are only needed (and only computed) for classifier-free guidance.
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            negative_prompt = negative_prompt or ""
            # Broadcast a single negative prompt string across the whole batch.
            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt

            if prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )

            negative_prompt_embeds = self._get_t5_prompt_embeds(
                prompt=negative_prompt,
                num_videos_per_prompt=num_videos_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
                dtype=dtype,
            )

        return prompt_embeds, negative_prompt_embeds
347
+
348
    def prepare_latents(
        self,
        image: torch.Tensor,
        batch_size: int = 1,
        num_channels_latents: int = 16,
        num_frames: int = 13,
        height: int = 60,
        width: int = 90,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.Tensor] = None,
    ):
        """Create the initial noise latents and the VAE-encoded conditioning image latents.

        The conditioning image is encoded to a single latent frame which is zero-padded
        along the frame axis to the full latent video length (plus patch_size_t padding
        for CogVideoX 1.5). Returns ``(latents, image_latents)``, both in
        ``[B, F, C, H, W]`` layout.
        """
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        # Number of latent frames after temporal VAE compression.
        num_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
        shape = (
            batch_size,
            num_frames,
            num_channels_latents,
            height // self.vae_scale_factor_spatial,
            width // self.vae_scale_factor_spatial,
        )

        # For CogVideoX1.5, the latent should add 1 for padding (Not use)
        if self.transformer.config.patch_size_t is not None:
            shape = shape[:1] + (shape[1] + shape[1] % self.transformer.config.patch_size_t,) + shape[2:]

        image = image.unsqueeze(2)  # [B, C, F, H, W]

        # Encode each image independently so per-sample generators stay deterministic.
        if isinstance(generator, list):
            image_latents = [
                retrieve_latents(self.vae.encode(image[i].unsqueeze(0)), generator[i]) for i in range(batch_size)
            ]
        else:
            image_latents = [retrieve_latents(self.vae.encode(img.unsqueeze(0)), generator) for img in image]

        image_latents = torch.cat(image_latents, dim=0).to(dtype).permute(0, 2, 1, 3, 4)  # [B, F, C, H, W]

        if not self.vae.config.invert_scale_latents:
            image_latents = self.vae_scaling_factor_image * image_latents
        else:
            # This is awkward but required because the CogVideoX team forgot to multiply the
            # scaling factor during training :)
            image_latents = 1 / self.vae_scaling_factor_image * image_latents

        # Zero-pad the single encoded frame up to the full latent video length.
        padding_shape = (
            batch_size,
            num_frames - 1,
            num_channels_latents,
            height // self.vae_scale_factor_spatial,
            width // self.vae_scale_factor_spatial,
        )

        latent_padding = torch.zeros(padding_shape, device=device, dtype=dtype)
        image_latents = torch.cat([image_latents, latent_padding], dim=1)

        # Select the first frame along the second dimension
        if self.transformer.config.patch_size_t is not None:
            # Repeat the leading frame(s) so the frame count is divisible by patch_size_t.
            first_frame = image_latents[:, : image_latents.size(1) % self.transformer.config.patch_size_t, ...]
            image_latents = torch.cat([first_frame, image_latents], dim=1)

        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            latents = latents.to(device)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma
        return latents, image_latents
422
+
423
+ # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.decode_latents
424
+ def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
425
+ latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width]
426
+ latents = 1 / self.vae_scaling_factor_image * latents
427
+
428
+ frames = self.vae.decode(latents).sample
429
+ return frames
430
+
431
+ # Copied from diffusers.pipelines.animatediff.pipeline_animatediff_video2video.AnimateDiffVideoToVideoPipeline.get_timesteps
432
+ def get_timesteps(self, num_inference_steps, timesteps, strength, device):
433
+ # get the original timestep using init_timestep
434
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
435
+
436
+ t_start = max(num_inference_steps - init_timestep, 0)
437
+ timesteps = timesteps[t_start * self.scheduler.order :]
438
+
439
+ return timesteps, num_inference_steps - t_start
440
+
441
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
442
+ def prepare_extra_step_kwargs(self, generator, eta):
443
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
444
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
445
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
446
+ # and should be between [0, 1]
447
+
448
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
449
+ extra_step_kwargs = {}
450
+ if accepts_eta:
451
+ extra_step_kwargs["eta"] = eta
452
+
453
+ # check if the scheduler accepts generator
454
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
455
+ if accepts_generator:
456
+ extra_step_kwargs["generator"] = generator
457
+ return extra_step_kwargs
458
+
459
    def check_inputs(
        self,
        image,
        prompt,
        height,
        width,
        negative_prompt,
        callback_on_step_end_tensor_inputs,
        latents=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
    ):
        """Validate user-supplied call arguments, raising `ValueError`/`TypeError` on
        inconsistent or malformed input before any expensive work is done."""
        if (
            not isinstance(image, torch.Tensor)
            and not isinstance(image, PIL.Image.Image)
            and not isinstance(image, list)
        ):
            raise ValueError(
                "`image` has to be of type `torch.Tensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
                f" {type(image)}"
            )

        # Spatial dims must be divisible by 8 to survive VAE downsampling.
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        if callback_on_step_end_tensor_inputs is not None and not all(
            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
        ):
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
            )
        # Exactly one of `prompt` / `prompt_embeds` must be provided.
        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        # Precomputed embeddings must be shape-compatible for classifier-free guidance.
        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if prompt_embeds.shape != negative_prompt_embeds.shape:
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}."
                )
521
+
522
    # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.fuse_qkv_projections
    def fuse_qkv_projections(self) -> None:
        r"""Enables fused QKV projections."""
        # Track fusion state so `unfuse_qkv_projections` can warn on a no-op call.
        self.fusing_transformer = True
        self.transformer.fuse_qkv_projections()
527
+
528
+ # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.unfuse_qkv_projections
529
+ def unfuse_qkv_projections(self) -> None:
530
+ r"""Disable QKV projection fusion if enabled."""
531
+ if not self.fusing_transformer:
532
+ logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.")
533
+ else:
534
+ self.transformer.unfuse_qkv_projections()
535
+ self.fusing_transformer = False
536
+
537
+ # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._prepare_rotary_positional_embeddings
538
+ def _prepare_rotary_positional_embeddings(
539
+ self,
540
+ height: int,
541
+ width: int,
542
+ num_frames: int,
543
+ device: torch.device,
544
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
545
+ grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
546
+ grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
547
+
548
+ p = self.transformer.config.patch_size
549
+ p_t = self.transformer.config.patch_size_t
550
+
551
+ base_size_width = self.transformer.config.sample_width // p
552
+ base_size_height = self.transformer.config.sample_height // p
553
+
554
+ if p_t is None:
555
+ # CogVideoX 1.0
556
+ grid_crops_coords = get_resize_crop_region_for_grid(
557
+ (grid_height, grid_width), base_size_width, base_size_height
558
+ )
559
+ freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
560
+ embed_dim=self.transformer.config.attention_head_dim,
561
+ crops_coords=grid_crops_coords,
562
+ grid_size=(grid_height, grid_width),
563
+ temporal_size=num_frames,
564
+ device=device,
565
+ )
566
+ else:
567
+ # CogVideoX 1.5
568
+ base_num_frames = (num_frames + p_t - 1) // p_t
569
+
570
+ freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
571
+ embed_dim=self.transformer.config.attention_head_dim,
572
+ crops_coords=None,
573
+ grid_size=(grid_height, grid_width),
574
+ temporal_size=base_num_frames,
575
+ grid_type="slice",
576
+ max_size=(base_size_height, base_size_width),
577
+ device=device,
578
+ )
579
+
580
+ return freqs_cos, freqs_sin
581
+
582
+ @property
583
+ def guidance_scale(self):
584
+ return self._guidance_scale
585
+
586
+ @property
587
+ def num_timesteps(self):
588
+ return self._num_timesteps
589
+
590
+ @property
591
+ def attention_kwargs(self):
592
+ return self._attention_kwargs
593
+
594
+ @property
595
+ def current_timestep(self):
596
+ return self._current_timestep
597
+
598
+ @property
599
+ def interrupt(self):
600
+ return self._interrupt
601
+
602
+ @torch.no_grad()
603
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
604
+ def __call__(
605
+ self,
606
+ image: PipelineImageInput,
607
+ prompt: Optional[Union[str, List[str]]] = None,
608
+ negative_prompt: Optional[Union[str, List[str]]] = None,
609
+ height: Optional[int] = None,
610
+ width: Optional[int] = None,
611
+ num_frames: int = 49,
612
+ num_inference_steps: int = 50,
613
+ timesteps: Optional[List[int]] = None,
614
+ guidance_scale: float = 6,
615
+ use_dynamic_cfg: bool = False,
616
+ num_videos_per_prompt: int = 1,
617
+ eta: float = 0.0,
618
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
619
+ latents: Optional[torch.FloatTensor] = None,
620
+ prompt_embeds: Optional[torch.FloatTensor] = None,
621
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
622
+ output_type: str = "pil",
623
+ return_dict: bool = True,
624
+ attention_kwargs: Optional[Dict[str, Any]] = None,
625
+ callback_on_step_end: Optional[
626
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
627
+ ] = None,
628
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
629
+ max_sequence_length: int = 226,
630
+ ) -> Union[CogVideoXPipelineOutput, Tuple]:
631
+ """
632
+ Function invoked when calling the pipeline for generation.
633
+
634
+ Args:
635
+ image (`PipelineImageInput`):
636
+ The input image to condition the generation on. Must be an image, a list of images or a `torch.Tensor`.
637
+ prompt (`str` or `List[str]`, *optional*):
638
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
639
+ instead.
640
+ negative_prompt (`str` or `List[str]`, *optional*):
641
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
642
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
643
+ less than `1`).
644
+ height (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial):
645
+ The height in pixels of the generated image. This is set to 480 by default for the best results.
646
+ width (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial):
647
+ The width in pixels of the generated image. This is set to 720 by default for the best results.
648
+ num_frames (`int`, defaults to `48`):
649
+ Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will
650
+ contain 1 extra frame because CogVideoX is conditioned with (num_seconds * fps + 1) frames where
651
+ num_seconds is 6 and fps is 8. However, since videos can be saved at any fps, the only condition that
652
+ needs to be satisfied is that of divisibility mentioned above.
653
+ num_inference_steps (`int`, *optional*, defaults to 50):
654
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
655
+ expense of slower inference.
656
+ timesteps (`List[int]`, *optional*):
657
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
658
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
659
+ passed will be used. Must be in descending order.
660
+ guidance_scale (`float`, *optional*, defaults to 7.0):
661
+ Guidance scale as defined in [Classifier-Free Diffusion
662
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
663
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
664
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
665
+ the text `prompt`, usually at the expense of lower image quality.
666
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
667
+ The number of videos to generate per prompt.
668
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
669
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
670
+ to make generation deterministic.
671
+ latents (`torch.FloatTensor`, *optional*):
672
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
673
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
674
+ tensor will be generated by sampling using the supplied random `generator`.
675
+ prompt_embeds (`torch.FloatTensor`, *optional*):
676
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
677
+ provided, text embeddings will be generated from `prompt` input argument.
678
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
679
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
680
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
681
+ argument.
682
+ output_type (`str`, *optional*, defaults to `"pil"`):
683
+ The output format of the generate image. Choose between
684
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
685
+ return_dict (`bool`, *optional*, defaults to `True`):
686
+ Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
687
+ of a plain tuple.
688
+ attention_kwargs (`dict`, *optional*):
689
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
690
+ `self.processor` in
691
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
692
+ callback_on_step_end (`Callable`, *optional*):
693
+ A function that calls at the end of each denoising steps during the inference. The function is called
694
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
695
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
696
+ `callback_on_step_end_tensor_inputs`.
697
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
698
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
699
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
700
+ `._callback_tensor_inputs` attribute of your pipeline class.
701
+ max_sequence_length (`int`, defaults to `226`):
702
+ Maximum sequence length in encoded prompt. Must be consistent with
703
+ `self.transformer.config.max_text_seq_length` otherwise may lead to poor results.
704
+
705
+ Examples:
706
+
707
+ Returns:
708
+ [`~pipelines.cogvideo.pipeline_output.CogVideoXPipelineOutput`] or `tuple`:
709
+ [`~pipelines.cogvideo.pipeline_output.CogVideoXPipelineOutput`] if `return_dict` is True, otherwise a
710
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
711
+ """
712
+
713
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
714
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
715
+
716
+ height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial
717
+ width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial
718
+ num_frames = num_frames or self.transformer.config.sample_frames
719
+
720
+ num_videos_per_prompt = 1
721
+
722
+ # 1. Check inputs. Raise error if not correct
723
+ self.check_inputs(
724
+ image=image,
725
+ prompt=prompt,
726
+ height=height,
727
+ width=width,
728
+ negative_prompt=negative_prompt,
729
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
730
+ latents=latents,
731
+ prompt_embeds=prompt_embeds,
732
+ negative_prompt_embeds=negative_prompt_embeds,
733
+ )
734
+ self._guidance_scale = guidance_scale
735
+ self._current_timestep = None
736
+ self._attention_kwargs = attention_kwargs
737
+ self._interrupt = False
738
+
739
+ # 2. Default call parameters
740
+ if prompt is not None and isinstance(prompt, str):
741
+ batch_size = 1
742
+ elif prompt is not None and isinstance(prompt, list):
743
+ batch_size = len(prompt)
744
+ else:
745
+ batch_size = prompt_embeds.shape[0]
746
+
747
+ device = self._execution_device
748
+
749
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
750
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
751
+ # corresponds to doing no classifier free guidance.
752
+ do_classifier_free_guidance = guidance_scale > 1.0
753
+
754
+ # 3. Encode input prompt
755
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
756
+ prompt=prompt,
757
+ negative_prompt=negative_prompt,
758
+ do_classifier_free_guidance=do_classifier_free_guidance,
759
+ num_videos_per_prompt=num_videos_per_prompt,
760
+ prompt_embeds=prompt_embeds,
761
+ negative_prompt_embeds=negative_prompt_embeds,
762
+ max_sequence_length=max_sequence_length,
763
+ device=device,
764
+ )
765
+ if do_classifier_free_guidance:
766
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
767
+
768
+ # 4. Prepare timesteps
769
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
770
+ self._num_timesteps = len(timesteps)
771
+
772
+ # 5. Prepare latents
773
+ latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
774
+
775
+ # For CogVideoX 1.5, the latent frames should be padded to make it divisible by patch_size_t
776
+ patch_size_t = self.transformer.config.patch_size_t
777
+ additional_frames = 0
778
+ if patch_size_t is not None and latent_frames % patch_size_t != 0:
779
+ additional_frames = patch_size_t - latent_frames % patch_size_t
780
+ num_frames += additional_frames * self.vae_scale_factor_temporal
781
+
782
+ image = self.video_processor.preprocess(image, height=height, width=width).to(
783
+ device, dtype=prompt_embeds.dtype
784
+ )
785
+
786
+ latent_channels = self.transformer.config.in_channels // 2
787
+ latents, image_latents = self.prepare_latents(
788
+ image,
789
+ batch_size * num_videos_per_prompt,
790
+ latent_channels,
791
+ num_frames,
792
+ height,
793
+ width,
794
+ prompt_embeds.dtype,
795
+ device,
796
+ generator,
797
+ latents,
798
+ )
799
+
800
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
801
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
802
+
803
+ # 7. Create rotary embeds if required
804
+ image_rotary_emb = (
805
+ self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
806
+ if self.transformer.config.use_rotary_positional_embeddings
807
+ else None
808
+ )
809
+
810
+ # 8. Create ofs embeds if required
811
+ ofs_emb = None if self.transformer.config.ofs_embed_dim is None else latents.new_full((1,), fill_value=2.0)
812
+
813
+ # 8. Denoising loop
814
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
815
+
816
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
817
+ # for DPM-solver++
818
+ old_pred_original_sample = None
819
+ for i, t in enumerate(timesteps):
820
+ if self.interrupt:
821
+ continue
822
+
823
+ self._current_timestep = t
824
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
825
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
826
+
827
+ latent_image_input = torch.cat([image_latents] * 2) if do_classifier_free_guidance else image_latents
828
+ latent_model_input = torch.cat([latent_model_input, latent_image_input], dim=2)
829
+
830
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
831
+ timestep = t.expand(latent_model_input.shape[0])
832
+
833
+ # predict noise model_output
834
+ with self.transformer.cache_context("cond_uncond"):
835
+ noise_pred = self.transformer(
836
+ hidden_states=latent_model_input,
837
+ encoder_hidden_states=prompt_embeds,
838
+ timestep=timestep,
839
+ ofs=ofs_emb,
840
+ image_rotary_emb=image_rotary_emb,
841
+ attention_kwargs=attention_kwargs,
842
+ return_dict=False,
843
+ )[0]
844
+ noise_pred = noise_pred.float()
845
+
846
+ # perform guidance
847
+ if use_dynamic_cfg:
848
+ self._guidance_scale = 1 + guidance_scale * (
849
+ (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
850
+ )
851
+ if do_classifier_free_guidance:
852
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
853
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
854
+
855
+ # compute the previous noisy sample x_t -> x_t-1
856
+ if not isinstance(self.scheduler, CogVideoXDPMScheduler):
857
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
858
+ else:
859
+ latents, old_pred_original_sample = self.scheduler.step(
860
+ noise_pred,
861
+ old_pred_original_sample,
862
+ t,
863
+ timesteps[i - 1] if i > 0 else None,
864
+ latents,
865
+ **extra_step_kwargs,
866
+ return_dict=False,
867
+ )
868
+ latents = latents.to(prompt_embeds.dtype)
869
+
870
+ # call the callback, if provided
871
+ if callback_on_step_end is not None:
872
+ callback_kwargs = {}
873
+ for k in callback_on_step_end_tensor_inputs:
874
+ callback_kwargs[k] = locals()[k]
875
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
876
+
877
+ latents = callback_outputs.pop("latents", latents)
878
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
879
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
880
+
881
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
882
+ progress_bar.update()
883
+
884
+ if XLA_AVAILABLE:
885
+ xm.mark_step()
886
+
887
+ self._current_timestep = None
888
+
889
+ if not output_type == "latent":
890
+ # Discard any padding frames that were added for CogVideoX 1.5
891
+ latents = latents[:, additional_frames:]
892
+ video = self.decode_latents(latents)
893
+ video = self.video_processor.postprocess_video(video=video, output_type=output_type)
894
+ else:
895
+ video = latents
896
+
897
+ # Offload all models
898
+ self.maybe_free_model_hooks()
899
+
900
+ if not return_dict:
901
+ return (video,)
902
+
903
+ return CogVideoXPipelineOutput(frames=video)
pythonProject/diffusers-main/build/lib/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py ADDED
@@ -0,0 +1,868 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import inspect
17
+ import math
18
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
19
+
20
+ import torch
21
+ from PIL import Image
22
+ from transformers import T5EncoderModel, T5Tokenizer
23
+
24
+ from ...callbacks import MultiPipelineCallbacks, PipelineCallback
25
+ from ...loaders import CogVideoXLoraLoaderMixin
26
+ from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
27
+ from ...models.embeddings import get_3d_rotary_pos_embed
28
+ from ...pipelines.pipeline_utils import DiffusionPipeline
29
+ from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
30
+ from ...utils import is_torch_xla_available, logging, replace_example_docstring
31
+ from ...utils.torch_utils import randn_tensor
32
+ from ...video_processor import VideoProcessor
33
+ from .pipeline_output import CogVideoXPipelineOutput
34
+
35
+
36
+ if is_torch_xla_available():
37
+ import torch_xla.core.xla_model as xm
38
+
39
+ XLA_AVAILABLE = True
40
+ else:
41
+ XLA_AVAILABLE = False
42
+
43
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
44
+
45
+
46
+ EXAMPLE_DOC_STRING = """
47
+ Examples:
48
+ ```python
49
+ >>> import torch
50
+ >>> from diffusers import CogVideoXDPMScheduler, CogVideoXVideoToVideoPipeline
51
+ >>> from diffusers.utils import export_to_video, load_video
52
+
53
+ >>> # Models: "THUDM/CogVideoX-2b" or "THUDM/CogVideoX-5b"
54
+ >>> pipe = CogVideoXVideoToVideoPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
55
+ >>> pipe.to("cuda")
56
+ >>> pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config)
57
+
58
+ >>> input_video = load_video(
59
+ ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/hiker.mp4"
60
+ ... )
61
+ >>> prompt = (
62
+ ... "An astronaut stands triumphantly at the peak of a towering mountain. Panorama of rugged peaks and "
63
+ ... "valleys. Very futuristic vibe and animated aesthetic. Highlights of purple and golden colors in "
64
+ ... "the scene. The sky is looks like an animated/cartoonish dream of galaxies, nebulae, stars, planets, "
65
+ ... "moons, but the remainder of the scene is mostly realistic."
66
+ ... )
67
+
68
+ >>> video = pipe(
69
+ ... video=input_video, prompt=prompt, strength=0.8, guidance_scale=6, num_inference_steps=50
70
+ ... ).frames[0]
71
+ >>> export_to_video(video, "output.mp4", fps=8)
72
+ ```
73
+ """
74
+
75
+
76
+ # Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
77
+ def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
78
+ tw = tgt_width
79
+ th = tgt_height
80
+ h, w = src
81
+ r = h / w
82
+ if r > (th / tw):
83
+ resize_height = th
84
+ resize_width = int(round(th / h * w))
85
+ else:
86
+ resize_width = tw
87
+ resize_height = int(round(tw / w * h))
88
+
89
+ crop_top = int(round((th - resize_height) / 2.0))
90
+ crop_left = int(round((tw - resize_width) / 2.0))
91
+
92
+ return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
93
+
94
+
95
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
96
+ def retrieve_timesteps(
97
+ scheduler,
98
+ num_inference_steps: Optional[int] = None,
99
+ device: Optional[Union[str, torch.device]] = None,
100
+ timesteps: Optional[List[int]] = None,
101
+ sigmas: Optional[List[float]] = None,
102
+ **kwargs,
103
+ ):
104
+ r"""
105
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
106
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
107
+
108
+ Args:
109
+ scheduler (`SchedulerMixin`):
110
+ The scheduler to get timesteps from.
111
+ num_inference_steps (`int`):
112
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
113
+ must be `None`.
114
+ device (`str` or `torch.device`, *optional*):
115
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
116
+ timesteps (`List[int]`, *optional*):
117
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
118
+ `num_inference_steps` and `sigmas` must be `None`.
119
+ sigmas (`List[float]`, *optional*):
120
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
121
+ `num_inference_steps` and `timesteps` must be `None`.
122
+
123
+ Returns:
124
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
125
+ second element is the number of inference steps.
126
+ """
127
+ if timesteps is not None and sigmas is not None:
128
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
129
+ if timesteps is not None:
130
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
131
+ if not accepts_timesteps:
132
+ raise ValueError(
133
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
134
+ f" timestep schedules. Please check whether you are using the correct scheduler."
135
+ )
136
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
137
+ timesteps = scheduler.timesteps
138
+ num_inference_steps = len(timesteps)
139
+ elif sigmas is not None:
140
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
141
+ if not accept_sigmas:
142
+ raise ValueError(
143
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
144
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
145
+ )
146
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
147
+ timesteps = scheduler.timesteps
148
+ num_inference_steps = len(timesteps)
149
+ else:
150
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
151
+ timesteps = scheduler.timesteps
152
+ return timesteps, num_inference_steps
153
+
154
+
155
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
156
+ def retrieve_latents(
157
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
158
+ ):
159
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
160
+ return encoder_output.latent_dist.sample(generator)
161
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
162
+ return encoder_output.latent_dist.mode()
163
+ elif hasattr(encoder_output, "latents"):
164
+ return encoder_output.latents
165
+ else:
166
+ raise AttributeError("Could not access latents of provided encoder_output")
167
+
168
+
169
+ class CogVideoXVideoToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
170
+ r"""
171
+ Pipeline for video-to-video generation using CogVideoX.
172
+
173
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
174
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
175
+
176
+ Args:
177
+ vae ([`AutoencoderKL`]):
178
+ Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
179
+ text_encoder ([`T5EncoderModel`]):
180
+ Frozen text-encoder. CogVideoX uses
181
+ [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the
182
+ [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
183
+ tokenizer (`T5Tokenizer`):
184
+ Tokenizer of class
185
+ [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
186
+ transformer ([`CogVideoXTransformer3DModel`]):
187
+ A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents.
188
+ scheduler ([`SchedulerMixin`]):
189
+ A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
190
+ """
191
+
192
+ _optional_components = []
193
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
194
+
195
+ _callback_tensor_inputs = [
196
+ "latents",
197
+ "prompt_embeds",
198
+ "negative_prompt_embeds",
199
+ ]
200
+
201
+ def __init__(
202
+ self,
203
+ tokenizer: T5Tokenizer,
204
+ text_encoder: T5EncoderModel,
205
+ vae: AutoencoderKLCogVideoX,
206
+ transformer: CogVideoXTransformer3DModel,
207
+ scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
208
+ ):
209
+ super().__init__()
210
+
211
+ self.register_modules(
212
+ tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
213
+ )
214
+
215
+ self.vae_scale_factor_spatial = (
216
+ 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
217
+ )
218
+ self.vae_scale_factor_temporal = (
219
+ self.vae.config.temporal_compression_ratio if getattr(self, "vae", None) else 4
220
+ )
221
+ self.vae_scaling_factor_image = self.vae.config.scaling_factor if getattr(self, "vae", None) else 0.7
222
+
223
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
224
+
225
+ # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._get_t5_prompt_embeds
226
+ def _get_t5_prompt_embeds(
227
+ self,
228
+ prompt: Union[str, List[str]] = None,
229
+ num_videos_per_prompt: int = 1,
230
+ max_sequence_length: int = 226,
231
+ device: Optional[torch.device] = None,
232
+ dtype: Optional[torch.dtype] = None,
233
+ ):
234
+ device = device or self._execution_device
235
+ dtype = dtype or self.text_encoder.dtype
236
+
237
+ prompt = [prompt] if isinstance(prompt, str) else prompt
238
+ batch_size = len(prompt)
239
+
240
+ text_inputs = self.tokenizer(
241
+ prompt,
242
+ padding="max_length",
243
+ max_length=max_sequence_length,
244
+ truncation=True,
245
+ add_special_tokens=True,
246
+ return_tensors="pt",
247
+ )
248
+ text_input_ids = text_inputs.input_ids
249
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
250
+
251
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
252
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
253
+ logger.warning(
254
+ "The following part of your input was truncated because `max_sequence_length` is set to "
255
+ f" {max_sequence_length} tokens: {removed_text}"
256
+ )
257
+
258
+ prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
259
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
260
+
261
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
262
+ _, seq_len, _ = prompt_embeds.shape
263
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
264
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
265
+
266
+ return prompt_embeds
267
+
268
+ # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.encode_prompt
269
+ def encode_prompt(
270
+ self,
271
+ prompt: Union[str, List[str]],
272
+ negative_prompt: Optional[Union[str, List[str]]] = None,
273
+ do_classifier_free_guidance: bool = True,
274
+ num_videos_per_prompt: int = 1,
275
+ prompt_embeds: Optional[torch.Tensor] = None,
276
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
277
+ max_sequence_length: int = 226,
278
+ device: Optional[torch.device] = None,
279
+ dtype: Optional[torch.dtype] = None,
280
+ ):
281
+ r"""
282
+ Encodes the prompt into text encoder hidden states.
283
+
284
+ Args:
285
+ prompt (`str` or `List[str]`, *optional*):
286
+ prompt to be encoded
287
+ negative_prompt (`str` or `List[str]`, *optional*):
288
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
289
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
290
+ less than `1`).
291
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
292
+ Whether to use classifier free guidance or not.
293
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
294
+ Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
295
+ prompt_embeds (`torch.Tensor`, *optional*):
296
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
297
+ provided, text embeddings will be generated from `prompt` input argument.
298
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
299
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
300
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
301
+ argument.
302
+ device: (`torch.device`, *optional*):
303
+ torch device
304
+ dtype: (`torch.dtype`, *optional*):
305
+ torch dtype
306
+ """
307
+ device = device or self._execution_device
308
+
309
+ prompt = [prompt] if isinstance(prompt, str) else prompt
310
+ if prompt is not None:
311
+ batch_size = len(prompt)
312
+ else:
313
+ batch_size = prompt_embeds.shape[0]
314
+
315
+ if prompt_embeds is None:
316
+ prompt_embeds = self._get_t5_prompt_embeds(
317
+ prompt=prompt,
318
+ num_videos_per_prompt=num_videos_per_prompt,
319
+ max_sequence_length=max_sequence_length,
320
+ device=device,
321
+ dtype=dtype,
322
+ )
323
+
324
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
325
+ negative_prompt = negative_prompt or ""
326
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
327
+
328
+ if prompt is not None and type(prompt) is not type(negative_prompt):
329
+ raise TypeError(
330
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
331
+ f" {type(prompt)}."
332
+ )
333
+ elif batch_size != len(negative_prompt):
334
+ raise ValueError(
335
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
336
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
337
+ " the batch size of `prompt`."
338
+ )
339
+
340
+ negative_prompt_embeds = self._get_t5_prompt_embeds(
341
+ prompt=negative_prompt,
342
+ num_videos_per_prompt=num_videos_per_prompt,
343
+ max_sequence_length=max_sequence_length,
344
+ device=device,
345
+ dtype=dtype,
346
+ )
347
+
348
+ return prompt_embeds, negative_prompt_embeds
349
+
350
+ def prepare_latents(
351
+ self,
352
+ video: Optional[torch.Tensor] = None,
353
+ batch_size: int = 1,
354
+ num_channels_latents: int = 16,
355
+ height: int = 60,
356
+ width: int = 90,
357
+ dtype: Optional[torch.dtype] = None,
358
+ device: Optional[torch.device] = None,
359
+ generator: Optional[torch.Generator] = None,
360
+ latents: Optional[torch.Tensor] = None,
361
+ timestep: Optional[torch.Tensor] = None,
362
+ ):
363
+ if isinstance(generator, list) and len(generator) != batch_size:
364
+ raise ValueError(
365
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
366
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
367
+ )
368
+
369
+ num_frames = (video.size(2) - 1) // self.vae_scale_factor_temporal + 1 if latents is None else latents.size(1)
370
+
371
+ shape = (
372
+ batch_size,
373
+ num_frames,
374
+ num_channels_latents,
375
+ height // self.vae_scale_factor_spatial,
376
+ width // self.vae_scale_factor_spatial,
377
+ )
378
+
379
+ if latents is None:
380
+ if isinstance(generator, list):
381
+ init_latents = [
382
+ retrieve_latents(self.vae.encode(video[i].unsqueeze(0)), generator[i]) for i in range(batch_size)
383
+ ]
384
+ else:
385
+ init_latents = [retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator) for vid in video]
386
+
387
+ init_latents = torch.cat(init_latents, dim=0).to(dtype).permute(0, 2, 1, 3, 4) # [B, F, C, H, W]
388
+ init_latents = self.vae_scaling_factor_image * init_latents
389
+
390
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
391
+ latents = self.scheduler.add_noise(init_latents, noise, timestep)
392
+ else:
393
+ latents = latents.to(device)
394
+
395
+ # scale the initial noise by the standard deviation required by the scheduler
396
+ latents = latents * self.scheduler.init_noise_sigma
397
+ return latents
398
+
399
+ # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.decode_latents
400
+ def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
401
+ latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width]
402
+ latents = 1 / self.vae_scaling_factor_image * latents
403
+
404
+ frames = self.vae.decode(latents).sample
405
+ return frames
406
+
407
+ # Copied from diffusers.pipelines.animatediff.pipeline_animatediff_video2video.AnimateDiffVideoToVideoPipeline.get_timesteps
408
+ def get_timesteps(self, num_inference_steps, timesteps, strength, device):
409
+ # get the original timestep using init_timestep
410
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
411
+
412
+ t_start = max(num_inference_steps - init_timestep, 0)
413
+ timesteps = timesteps[t_start * self.scheduler.order :]
414
+
415
+ return timesteps, num_inference_steps - t_start
416
+
417
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
418
+ def prepare_extra_step_kwargs(self, generator, eta):
419
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
420
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
421
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
422
+ # and should be between [0, 1]
423
+
424
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
425
+ extra_step_kwargs = {}
426
+ if accepts_eta:
427
+ extra_step_kwargs["eta"] = eta
428
+
429
+ # check if the scheduler accepts generator
430
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
431
+ if accepts_generator:
432
+ extra_step_kwargs["generator"] = generator
433
+ return extra_step_kwargs
434
+
435
    def check_inputs(
        self,
        prompt,
        height,
        width,
        strength,
        negative_prompt,
        callback_on_step_end_tensor_inputs,
        video=None,
        latents=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
    ):
        """
        Validate user-facing `__call__` arguments before any heavy work is done.

        Raises `ValueError` on the first invalid combination found, in this order:
        dimensions, strength range, callback tensor names, prompt/prompt_embeds
        exclusivity and typing, negative-prompt conflicts, embedding shape mismatch,
        and video/latents exclusivity. Returns `None` when everything is consistent.
        """
        # Spatial dims must align with the 8x VAE downsampling factor.
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        if strength < 0 or strength > 1:
            raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")

        # Callback tensor names must be ones this pipeline actually exposes.
        if callback_on_step_end_tensor_inputs is not None and not all(
            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
        ):
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
            )
        # Exactly one of `prompt` / `prompt_embeds` must be provided.
        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        # A raw prompt cannot be mixed with precomputed negative embeddings
        # (the negative branch would be re-encoded and conflict).
        if prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        # CFG concatenates the two embedding tensors, so their shapes must match exactly.
        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if prompt_embeds.shape != negative_prompt_embeds.shape:
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}."
                )

        # The conditioning source is either a pixel video or precomputed latents, never both.
        if video is not None and latents is not None:
            raise ValueError("Only one of `video` or `latents` should be provided")
494
+
495
+ # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.fuse_qkv_projections
496
+ def fuse_qkv_projections(self) -> None:
497
+ r"""Enables fused QKV projections."""
498
+ self.fusing_transformer = True
499
+ self.transformer.fuse_qkv_projections()
500
+
501
+ # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.unfuse_qkv_projections
502
+ def unfuse_qkv_projections(self) -> None:
503
+ r"""Disable QKV projection fusion if enabled."""
504
+ if not self.fusing_transformer:
505
+ logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.")
506
+ else:
507
+ self.transformer.unfuse_qkv_projections()
508
+ self.fusing_transformer = False
509
+
510
+ # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._prepare_rotary_positional_embeddings
511
+ def _prepare_rotary_positional_embeddings(
512
+ self,
513
+ height: int,
514
+ width: int,
515
+ num_frames: int,
516
+ device: torch.device,
517
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
518
+ grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
519
+ grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
520
+
521
+ p = self.transformer.config.patch_size
522
+ p_t = self.transformer.config.patch_size_t
523
+
524
+ base_size_width = self.transformer.config.sample_width // p
525
+ base_size_height = self.transformer.config.sample_height // p
526
+
527
+ if p_t is None:
528
+ # CogVideoX 1.0
529
+ grid_crops_coords = get_resize_crop_region_for_grid(
530
+ (grid_height, grid_width), base_size_width, base_size_height
531
+ )
532
+ freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
533
+ embed_dim=self.transformer.config.attention_head_dim,
534
+ crops_coords=grid_crops_coords,
535
+ grid_size=(grid_height, grid_width),
536
+ temporal_size=num_frames,
537
+ device=device,
538
+ )
539
+ else:
540
+ # CogVideoX 1.5
541
+ base_num_frames = (num_frames + p_t - 1) // p_t
542
+
543
+ freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
544
+ embed_dim=self.transformer.config.attention_head_dim,
545
+ crops_coords=None,
546
+ grid_size=(grid_height, grid_width),
547
+ temporal_size=base_num_frames,
548
+ grid_type="slice",
549
+ max_size=(base_size_height, base_size_width),
550
+ device=device,
551
+ )
552
+
553
+ return freqs_cos, freqs_sin
554
+
555
    @property
    def guidance_scale(self):
        # CFG weight for the current call; updated per-step when `use_dynamic_cfg=True`.
        return self._guidance_scale

    @property
    def num_timesteps(self):
        # Length of the (strength-trimmed) timestep schedule of the current call.
        return self._num_timesteps

    @property
    def attention_kwargs(self):
        # Extra kwargs forwarded to the attention processors, as passed to `__call__`.
        return self._attention_kwargs

    @property
    def current_timestep(self):
        # Timestep being denoised right now; `None` outside the denoising loop.
        return self._current_timestep

    @property
    def interrupt(self):
        # Cooperative-cancellation flag checked at the top of each denoising step.
        return self._interrupt
574
+
575
    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        video: List[Image.Image] = None,
        prompt: Optional[Union[str, List[str]]] = None,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        timesteps: Optional[List[int]] = None,
        strength: float = 0.8,
        guidance_scale: float = 6,
        use_dynamic_cfg: bool = False,
        num_videos_per_prompt: int = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: str = "pil",
        return_dict: bool = True,
        attention_kwargs: Optional[Dict[str, Any]] = None,
        callback_on_step_end: Optional[
            Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
        ] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        max_sequence_length: int = 226,
    ) -> Union[CogVideoXPipelineOutput, Tuple]:
        """
        Function invoked when calling the pipeline for generation.

        Args:
            video (`List[PIL.Image.Image]`):
                The input video to condition the generation on. Must be a list of images/frames of the video.
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
                instead.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            height (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial):
                The height in pixels of the generated image. This is set to 480 by default for the best results.
            width (`int`, *optional*, defaults to self.transformer.config.sample_width * self.vae_scale_factor_spatial):
                The width in pixels of the generated image. This is set to 720 by default for the best results.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            timesteps (`List[int]`, *optional*):
                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
                passed will be used. Must be in descending order.
            strength (`float`, *optional*, defaults to 0.8):
                Higher strength leads to more differences between original video and generated video.
            guidance_scale (`float`, *optional*, defaults to 6):
                Guidance scale as defined in [Classifier-Free Diffusion
                Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
                of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
                `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
                the text `prompt`, usually at the expense of lower image quality.
            use_dynamic_cfg (`bool`, *optional*, defaults to `False`):
                If `True`, the guidance scale is re-derived from a cosine schedule at every denoising step.
            num_videos_per_prompt (`int`, *optional*, defaults to 1):
                The number of videos to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                η as used by DDIM-style schedulers; ignored by schedulers whose `step` does not accept it.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generate image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.cogvideo.pipeline_output.CogVideoXPipelineOutput`] instead
                of a plain tuple.
            attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            callback_on_step_end (`Callable`, *optional*):
                A function that calls at the end of each denoising steps during the inference. The function is called
                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
                `callback_on_step_end_tensor_inputs`.
            callback_on_step_end_tensor_inputs (`List`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                `._callback_tensor_inputs` attribute of your pipeline class.
            max_sequence_length (`int`, defaults to `226`):
                Maximum sequence length in encoded prompt. Must be consistent with
                `self.transformer.config.max_text_seq_length` otherwise may lead to poor results.

        Examples:

        Returns:
            [`~pipelines.cogvideo.pipeline_output.CogVideoXPipelineOutput`] or `tuple`:
            [`~pipelines.cogvideo.pipeline_output.CogVideoXPipelineOutput`] if `return_dict` is True, otherwise a
            `tuple`. When returning a tuple, the first element is a list with the generated images.
        """

        # Callback objects carry their own declared tensor inputs; honor them over the argument.
        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

        height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial
        width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial
        num_frames = len(video) if latents is None else latents.size(1)

        # NOTE(review): the `num_videos_per_prompt` argument is forcibly overridden to 1 here,
        # so the public parameter currently has no effect — looks intentional but worth confirming.
        num_videos_per_prompt = 1

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt=prompt,
            height=height,
            width=width,
            strength=strength,
            negative_prompt=negative_prompt,
            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
            video=video,
            latents=latents,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
        )
        self._guidance_scale = guidance_scale
        self._attention_kwargs = attention_kwargs
        self._current_timestep = None
        self._interrupt = False

        # 2. Default call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
            prompt,
            negative_prompt,
            do_classifier_free_guidance,
            num_videos_per_prompt=num_videos_per_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            max_sequence_length=max_sequence_length,
            device=device,
        )
        if do_classifier_free_guidance:
            # Single forward pass for CFG: batch = [uncond; cond].
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)

        # 4. Prepare timesteps
        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, timesteps, strength, device)
        # Timestep at which the input video's latents are noised before denoising starts.
        latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt)
        self._num_timesteps = len(timesteps)

        # 5. Prepare latents
        latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1

        # For CogVideoX 1.5, the latent frames should be padded to make it divisible by patch_size_t
        patch_size_t = self.transformer.config.patch_size_t
        if patch_size_t is not None and latent_frames % patch_size_t != 0:
            raise ValueError(
                f"The number of latent frames must be divisible by `{patch_size_t=}` but the given video "
                f"contains {latent_frames=}, which is not divisible."
            )

        if latents is None:
            video = self.video_processor.preprocess_video(video, height=height, width=width)
            video = video.to(device=device, dtype=prompt_embeds.dtype)

        latent_channels = self.transformer.config.in_channels
        latents = self.prepare_latents(
            video,
            batch_size * num_videos_per_prompt,
            latent_channels,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
            latent_timestep,
        )

        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7. Create rotary embeds if required
        image_rotary_emb = (
            self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
            if self.transformer.config.use_rotary_positional_embeddings
            else None
        )

        # 8. Denoising loop
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            # for DPM-solver++
            old_pred_original_sample = None
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                self._current_timestep = t
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                timestep = t.expand(latent_model_input.shape[0])

                # predict noise model_output
                with self.transformer.cache_context("cond_uncond"):
                    noise_pred = self.transformer(
                        hidden_states=latent_model_input,
                        encoder_hidden_states=prompt_embeds,
                        timestep=timestep,
                        image_rotary_emb=image_rotary_emb,
                        attention_kwargs=attention_kwargs,
                        return_dict=False,
                    )[0]
                noise_pred = noise_pred.float()

                # perform guidance
                if use_dynamic_cfg:
                    # Cosine ramp of the CFG weight over the remaining schedule.
                    self._guidance_scale = 1 + guidance_scale * (
                        (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
                    )
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                if not isinstance(self.scheduler, CogVideoXDPMScheduler):
                    latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
                else:
                    # CogVideoX DPM scheduler additionally threads the previous x0 prediction through.
                    latents, old_pred_original_sample = self.scheduler.step(
                        noise_pred,
                        old_pred_original_sample,
                        t,
                        timesteps[i - 1] if i > 0 else None,
                        latents,
                        **extra_step_kwargs,
                        return_dict=False,
                    )
                latents = latents.to(prompt_embeds.dtype)

                # call the callback, if provided
                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    # The callback may replace any of these tensors in-flight.
                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

                if XLA_AVAILABLE:
                    xm.mark_step()

        self._current_timestep = None

        if not output_type == "latent":
            video = self.decode_latents(latents)
            video = self.video_processor.postprocess_video(video=video, output_type=output_type)
        else:
            video = latents

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (video,)

        return CogVideoXPipelineOutput(frames=video)
pythonProject/diffusers-main/build/lib/diffusers/pipelines/cogvideo/pipeline_output.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+
3
+ import torch
4
+
5
+ from diffusers.utils import BaseOutput
6
+
7
+
8
@dataclass
class CogVideoXPipelineOutput(BaseOutput):
    r"""
    Output class for CogVideo pipelines.

    Args:
        frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
            List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
            denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
            `(batch_size, num_frames, channels, height, width)`.
    """

    # Generated frames; may hold raw latents when the pipeline was called with output_type="latent".
    frames: torch.Tensor
pythonProject/diffusers-main/build/lib/diffusers/pipelines/cogview3/__init__.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from typing import TYPE_CHECKING

from ...utils import (
    DIFFUSERS_SLOW_IMPORT,
    OptionalDependencyNotAvailable,
    _LazyModule,
    get_objects_from_module,
    is_torch_available,
    is_transformers_available,
)


# Dummy stand-ins exported when torch/transformers are missing, plus any extra re-exports.
_dummy_objects = {}
_additional_imports = {}
# Map of submodule name -> public names; consumed by `_LazyModule` below.
_import_structure = {"pipeline_output": ["CogView3PlusPipelineOutput"]}

try:
    if not (is_transformers_available() and is_torch_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    # Backends missing: register dummy objects that raise a helpful error on use.
    from ...utils import dummy_torch_and_transformers_objects  # noqa F403

    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
    _import_structure["pipeline_cogview3plus"] = ["CogView3PlusPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    # Eager path: used by static type checkers or when lazy imports are disabled.
    try:
        if not (is_transformers_available() and is_torch_available()):
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        from ...utils.dummy_torch_and_transformers_objects import *  # noqa F403
    else:
        from .pipeline_cogview3plus import CogView3PlusPipeline
else:
    import sys

    # Lazy path: replace this module with a proxy that imports submodules on first attribute access.
    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )

    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)
    for name, value in _additional_imports.items():
        setattr(sys.modules[__name__], name, value)
pythonProject/diffusers-main/build/lib/diffusers/pipelines/cogview3/pipeline_cogview3plus.py ADDED
@@ -0,0 +1,682 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import inspect
17
+ from typing import Callable, Dict, List, Optional, Tuple, Union
18
+
19
+ import torch
20
+ from transformers import T5EncoderModel, T5Tokenizer
21
+
22
+ from ...callbacks import MultiPipelineCallbacks, PipelineCallback
23
+ from ...image_processor import VaeImageProcessor
24
+ from ...models import AutoencoderKL, CogView3PlusTransformer2DModel
25
+ from ...pipelines.pipeline_utils import DiffusionPipeline
26
+ from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
27
+ from ...utils import is_torch_xla_available, logging, replace_example_docstring
28
+ from ...utils.torch_utils import randn_tensor
29
+ from .pipeline_output import CogView3PipelineOutput
30
+
31
+
32
# Optional TPU support: `xm.mark_step()` is only called when torch_xla is importable.
if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


# Usage example spliced into `__call__`'s docstring via @replace_example_docstring.
EXAMPLE_DOC_STRING = """
    Examples:
        ```python
        >>> import torch
        >>> from diffusers import CogView3PlusPipeline

        >>> pipe = CogView3PlusPipeline.from_pretrained("THUDM/CogView3-Plus-3B", torch_dtype=torch.bfloat16)
        >>> pipe.to("cuda")

        >>> prompt = "A photo of an astronaut riding a horse on mars"
        >>> image = pipe(prompt).images[0]
        >>> image.save("output.png")
        ```
"""
56
+
57
+
58
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    r"""
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
    if timesteps is not None:
        # Custom timestep schedules are only valid if the scheduler's `set_timesteps` accepts them.
        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    elif sigmas is not None:
        accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accept_sigmas:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps
116
+
117
+
118
class CogView3PlusPipeline(DiffusionPipeline):
    r"""
    Pipeline for text-to-image generation using CogView3Plus.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        text_encoder ([`T5EncoderModel`]):
            Frozen text-encoder. CogView3Plus uses
            [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the
            [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
        tokenizer (`T5Tokenizer`):
            Tokenizer of class
            [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
        transformer ([`CogView3PlusTransformer2DModel`]):
            A text conditioned `CogView3PlusTransformer2DModel` to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
    """

    _optional_components = []
    model_cpu_offload_seq = "text_encoder->transformer->vae"

    # Tensors that `callback_on_step_end` is allowed to read/replace each step.
    _callback_tensor_inputs = [
        "latents",
        "prompt_embeds",
        "negative_prompt_embeds",
    ]

    def __init__(
        self,
        tokenizer: T5Tokenizer,
        text_encoder: T5EncoderModel,
        vae: AutoencoderKL,
        transformer: CogView3PlusTransformer2DModel,
        scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
    ):
        super().__init__()

        self.register_modules(
            tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
        )
        # Spatial downsampling factor of the VAE; fall back to 8 when no VAE is registered.
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8

        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)

    # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._get_t5_prompt_embeds with num_videos_per_prompt->num_images_per_prompt
    def _get_t5_prompt_embeds(
        self,
        prompt: Union[str, List[str]] = None,
        num_images_per_prompt: int = 1,
        max_sequence_length: int = 226,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        device = device or self._execution_device
        dtype = dtype or self.text_encoder.dtype

        prompt = [prompt] if isinstance(prompt, str) else prompt
        batch_size = len(prompt)

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because `max_sequence_length` is set to "
                f" {max_sequence_length} tokens: {removed_text}"
            )

        prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

        # duplicate text embeddings for each generation per prompt, using mps friendly method
        _, seq_len, _ = prompt_embeds.shape
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

        return prompt_embeds

    def encode_prompt(
        self,
        prompt: Union[str, List[str]],
        negative_prompt: Optional[Union[str, List[str]]] = None,
        do_classifier_free_guidance: bool = True,
        num_images_per_prompt: int = 1,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        max_sequence_length: int = 224,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
                Whether to use classifier free guidance or not.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                Number of images that should be generated per prompt. torch device to place the resulting embeddings on
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            max_sequence_length (`int`, defaults to `224`):
                Maximum sequence length in encoded prompt. Can be set to other values but may lead to poorer results.
            device: (`torch.device`, *optional*):
                torch device
            dtype: (`torch.dtype`, *optional*):
                torch dtype
        """
        device = device or self._execution_device

        prompt = [prompt] if isinstance(prompt, str) else prompt
        if prompt is not None:
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            prompt_embeds = self._get_t5_prompt_embeds(
                prompt=prompt,
                num_images_per_prompt=num_images_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
                dtype=dtype,
            )

        # When no negative prompt is given, CFG uses all-zero embeddings as the unconditional branch.
        if do_classifier_free_guidance and negative_prompt is None:
            negative_prompt_embeds = prompt_embeds.new_zeros(prompt_embeds.shape)

        if do_classifier_free_guidance and negative_prompt_embeds is None:
            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt

            if prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )

            negative_prompt_embeds = self._get_t5_prompt_embeds(
                prompt=negative_prompt,
                num_images_per_prompt=num_images_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
                dtype=dtype,
            )

        return prompt_embeds, negative_prompt_embeds

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
        shape = (
            batch_size,
            num_channels_latents,
            int(height) // self.vae_scale_factor,
            int(width) // self.vae_scale_factor,
        )
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            latents = latents.to(device)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma
        return latents

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
    def prepare_extra_step_kwargs(self, generator, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
        # and should be between [0, 1]

        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

    # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
    def check_inputs(
        self,
        prompt,
        height,
        width,
        negative_prompt,
        callback_on_step_end_tensor_inputs,
        prompt_embeds=None,
        negative_prompt_embeds=None,
    ):
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        if callback_on_step_end_tensor_inputs is not None and not all(
            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
        ):
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
            )
        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if prompt_embeds.shape != negative_prompt_embeds.shape:
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}."
                )

    @property
    def guidance_scale(self):
        return self._guidance_scale

    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    @property
    def do_classifier_free_guidance(self):
        return self._guidance_scale > 1

    @property
    def num_timesteps(self):
        return self._num_timesteps

    @property
    def interrupt(self):
        return self._interrupt

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Optional[Union[str, List[str]]] = None,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        timesteps: Optional[List[int]] = None,
        guidance_scale: float = 5.0,
        num_images_per_prompt: int = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        original_size: Optional[Tuple[int, int]] = None,
        crops_coords_top_left: Tuple[int, int] = (0, 0),
        output_type: str = "pil",
        return_dict: bool = True,
        callback_on_step_end: Optional[
            Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
        ] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        max_sequence_length: int = 224,
    ) -> Union[CogView3PipelineOutput, Tuple]:
        """
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            height (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
                The height in pixels of the generated image. If not provided, it is set to 1024.
            width (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
                The width in pixels of the generated image. If not provided it is set to 1024.
            num_inference_steps (`int`, *optional*, defaults to `50`):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            timesteps (`List[int]`, *optional*):
                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
                passed will be used. Must be in descending order.
            guidance_scale (`float`, *optional*, defaults to `5.0`):
                Guidance scale as defined in [Classifier-Free Diffusion
                Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
                of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
                `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
                the text `prompt`, usually at the expense of lower image quality.
            num_images_per_prompt (`int`, *optional*, defaults to `1`):
                The number of images to generate per prompt.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
                If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
                `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
                explained in section 2.2 of
                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
            crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
                `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
                `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
                `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generate image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
                of a plain tuple.
            callback_on_step_end (`Callable`, *optional*):
                A function that calls at the end of each denoising steps during the inference. The function is called
                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
                `callback_on_step_end_tensor_inputs`.
            callback_on_step_end_tensor_inputs (`List`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                `._callback_tensor_inputs` attribute of your pipeline class.
            max_sequence_length (`int`, defaults to `224`):
                Maximum sequence length in encoded prompt. Can be set to other values but may lead to poorer results.

        Examples:

        Returns:
            [`~pipelines.cogview3.pipeline_cogview3plus.CogView3PipelineOutput`] or `tuple`:
            [`~pipelines.cogview3.pipeline_cogview3plus.CogView3PipelineOutput`] if `return_dict` is True, otherwise a
            `tuple`. When returning a tuple, the first element is a list with the generated images.
        """

        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

        height = height or self.transformer.config.sample_size * self.vae_scale_factor
        width = width or self.transformer.config.sample_size * self.vae_scale_factor

        original_size = original_size or (height, width)
        target_size = (height, width)

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            height,
            width,
            negative_prompt,
            callback_on_step_end_tensor_inputs,
            prompt_embeds,
            negative_prompt_embeds,
        )
        self._guidance_scale = guidance_scale
        self._interrupt = False

        # 2. Default call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        # 3. Encode input prompt. Whether classifier-free guidance is applied is decided by the
        # `do_classifier_free_guidance` property (`guidance_scale > 1`).
        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
            prompt,
            negative_prompt,
            self.do_classifier_free_guidance,
            num_images_per_prompt=num_images_per_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            max_sequence_length=max_sequence_length,
            device=device,
        )
        if self.do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)

        # 4. Prepare timesteps
        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
        self._num_timesteps = len(timesteps)

        # 5. Prepare latents.
        latent_channels = self.transformer.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            latent_channels,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7. Prepare additional timestep conditions (SDXL-style micro-conditioning)
        original_size = torch.tensor([original_size], dtype=prompt_embeds.dtype)
        target_size = torch.tensor([target_size], dtype=prompt_embeds.dtype)
        crops_coords_top_left = torch.tensor([crops_coords_top_left], dtype=prompt_embeds.dtype)

        if self.do_classifier_free_guidance:
            original_size = torch.cat([original_size, original_size])
            target_size = torch.cat([target_size, target_size])
            crops_coords_top_left = torch.cat([crops_coords_top_left, crops_coords_top_left])

        original_size = original_size.to(device).repeat(batch_size * num_images_per_prompt, 1)
        target_size = target_size.to(device).repeat(batch_size * num_images_per_prompt, 1)
        crops_coords_top_left = crops_coords_top_left.to(device).repeat(batch_size * num_images_per_prompt, 1)

        # 8. Denoising loop
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            # for DPM-solver++
            old_pred_original_sample = None
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                timestep = t.expand(latent_model_input.shape[0])

                # predict noise model_output
                noise_pred = self.transformer(
                    hidden_states=latent_model_input,
                    encoder_hidden_states=prompt_embeds,
                    timestep=timestep,
                    original_size=original_size,
                    target_size=target_size,
                    crop_coords=crops_coords_top_left,
                    return_dict=False,
                )[0]
                noise_pred = noise_pred.float()

                # perform guidance
                if self.do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                if not isinstance(self.scheduler, CogVideoXDPMScheduler):
                    latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
                else:
                    # CogVideoXDPMScheduler carries the previous x0 prediction between steps
                    latents, old_pred_original_sample = self.scheduler.step(
                        noise_pred,
                        old_pred_original_sample,
                        t,
                        timesteps[i - 1] if i > 0 else None,
                        latents,
                        **extra_step_kwargs,
                        return_dict=False,
                    )
                latents = latents.to(prompt_embeds.dtype)

                # call the callback, if provided
                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

                if XLA_AVAILABLE:
                    xm.mark_step()

        if not output_type == "latent":
            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
                0
            ]
        else:
            image = latents

        image = self.image_processor.postprocess(image, output_type=output_type)

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image,)

        return CogView3PipelineOutput(images=image)
pythonProject/diffusers-main/build/lib/diffusers/pipelines/cogview3/pipeline_output.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from dataclasses import dataclass
from typing import List, Union

import numpy as np
import PIL.Image

from ...utils import BaseOutput


@dataclass
class CogView3PipelineOutput(BaseOutput):
    """
    Output class for CogView3 pipelines.

    Args:
        images (`List[PIL.Image.Image]` or `np.ndarray`)
            List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
            num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
    """

    images: Union[List[PIL.Image.Image], np.ndarray]
pythonProject/diffusers-main/build/lib/diffusers/pipelines/cogview4/__init__.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from typing import TYPE_CHECKING

from ...utils import (
    DIFFUSERS_SLOW_IMPORT,
    OptionalDependencyNotAvailable,
    _LazyModule,
    get_objects_from_module,
    is_torch_available,
    is_transformers_available,
)


_dummy_objects = {}
_additional_imports = {}
_import_structure = {"pipeline_output": ["CogView4PlusPipelineOutput"]}

# Register the real pipelines only when both torch and transformers are importable;
# otherwise expose dummy placeholder objects that raise on use.
try:
    if not (is_transformers_available() and is_torch_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ...utils import dummy_torch_and_transformers_objects  # noqa F403

    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
    _import_structure["pipeline_cogview4"] = ["CogView4Pipeline"]
    _import_structure["pipeline_cogview4_control"] = ["CogView4ControlPipeline"]

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    try:
        if not (is_transformers_available() and is_torch_available()):
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        from ...utils.dummy_torch_and_transformers_objects import *  # noqa F403
    else:
        from .pipeline_cogview4 import CogView4Pipeline
        from .pipeline_cogview4_control import CogView4ControlPipeline
else:
    import sys

    # Defer heavy imports until an attribute is actually accessed.
    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )

    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)
    for name, value in _additional_imports.items():
        setattr(sys.modules[__name__], name, value)
pythonProject/diffusers-main/build/lib/diffusers/pipelines/cogview4/pipeline_cogview4.py ADDED
@@ -0,0 +1,685 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import inspect
17
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
18
+
19
+ import numpy as np
20
+ import torch
21
+ from transformers import AutoTokenizer, GlmModel
22
+
23
+ from ...callbacks import MultiPipelineCallbacks, PipelineCallback
24
+ from ...image_processor import VaeImageProcessor
25
+ from ...loaders import CogView4LoraLoaderMixin
26
+ from ...models import AutoencoderKL, CogView4Transformer2DModel
27
+ from ...pipelines.pipeline_utils import DiffusionPipeline
28
+ from ...schedulers import FlowMatchEulerDiscreteScheduler
29
+ from ...utils import is_torch_xla_available, logging, replace_example_docstring
30
+ from ...utils.torch_utils import randn_tensor
31
+ from .pipeline_output import CogView4PipelineOutput
32
+
33
+
34
# XLA (TPU) support is optional: when torch_xla is importable we record the fact
# and call `xm.mark_step()` inside the denoising loop so the lazy XLA graph runs.
if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

# Usage example spliced into `CogView4Pipeline.__call__`'s docstring via the
# `replace_example_docstring` decorator.
EXAMPLE_DOC_STRING = """
    Examples:
        ```python
        >>> import torch
        >>> from diffusers import CogView4Pipeline

        >>> pipe = CogView4Pipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16)
        >>> pipe.to("cuda")

        >>> prompt = "A photo of an astronaut riding a horse on mars"
        >>> image = pipe(prompt).images[0]
        >>> image.save("output.png")
        ```
"""
57
+
58
+
59
def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    base_shift: float = 0.25,
    max_shift: float = 0.75,
) -> float:
    """Compute the flow-matching timestep-shift parameter ``mu``.

    The shift grows with the square root of the ratio between the actual image
    token count and the base sequence length, scaled by ``max_shift`` and
    offset by ``base_shift``.
    """
    seq_ratio_root = (image_seq_len / base_seq_len) ** 0.5
    return seq_ratio_root * max_shift + base_shift
68
+
69
+
70
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    r"""
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
    # Inspect the scheduler once to learn which custom-schedule kwargs it accepts.
    set_timesteps_params = set(inspect.signature(scheduler.set_timesteps).parameters.keys())
    supports_timesteps = "timesteps" in set_timesteps_params
    supports_sigmas = "sigmas" in set_timesteps_params

    if timesteps is not None and sigmas is not None:
        if not supports_timesteps and not supports_sigmas:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep or sigma schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, sigmas=sigmas, device=device, **kwargs)
    elif timesteps is not None:
        if not supports_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    elif sigmas is not None:
        if not supports_sigmas:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
    else:
        # No custom schedule: let the scheduler derive one from the step count.
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    # A custom schedule was installed; its length defines the step count.
    timesteps = scheduler.timesteps
    return timesteps, len(timesteps)
135
+
136
+
137
class CogView4Pipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
    r"""
    Pipeline for text-to-image generation using CogView4.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        text_encoder ([`GLMModel`]):
            Frozen text-encoder. CogView4 uses [glm-4-9b-hf](https://huggingface.co/THUDM/glm-4-9b-hf).
        tokenizer (`PreTrainedTokenizer`):
            Tokenizer of class
            [PreTrainedTokenizer](https://huggingface.co/docs/transformers/main/en/main_classes/tokenizer#transformers.PreTrainedTokenizer).
        transformer ([`CogView4Transformer2DModel`]):
            A text conditioned `CogView4Transformer2DModel` to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
    """

    # All registered modules are required; none may be omitted at load time.
    _optional_components = []
    # Order used by sequential model CPU offload.
    model_cpu_offload_seq = "text_encoder->transformer->vae"
    # Tensors that `callback_on_step_end` may inspect and override during denoising.
    _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]

    def __init__(
        self,
        tokenizer: AutoTokenizer,
        text_encoder: GlmModel,
        vae: AutoencoderKL,
        transformer: CogView4Transformer2DModel,
        scheduler: FlowMatchEulerDiscreteScheduler,
    ):
        super().__init__()

        self.register_modules(
            tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
        )
        # Spatial downsampling factor of the VAE (2 ** number of downsample
        # stages); falls back to 8 when no VAE is attached.
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
177
+
178
    def _get_glm_embeds(
        self,
        prompt: Union[str, List[str]] = None,
        max_sequence_length: int = 1024,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        """Tokenize `prompt` and encode it with the GLM text encoder.

        Returns the encoder's second-to-last hidden state, cast to `dtype` on
        `device`. Token sequences are left-padded with the pad token so their
        length is a multiple of 16.
        """
        device = device or self._execution_device
        dtype = dtype or self.text_encoder.dtype

        # Normalize a single prompt to a one-element batch.
        prompt = [prompt] if isinstance(prompt, str) else prompt

        text_inputs = self.tokenizer(
            prompt,
            padding="longest",  # not use max length
            max_length=max_sequence_length,
            truncation=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        # Re-tokenize without truncation to detect (and warn about) dropped text.
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because `max_sequence_length` is set to "
                f" {max_sequence_length} tokens: {removed_text}"
            )
        # Left-pad to a multiple of 16 tokens (presumably an alignment
        # requirement of the transformer's text branch — TODO confirm).
        current_length = text_input_ids.shape[1]
        pad_length = (16 - (current_length % 16)) % 16
        if pad_length > 0:
            pad_ids = torch.full(
                (text_input_ids.shape[0], pad_length),
                fill_value=self.tokenizer.pad_token_id,
                dtype=text_input_ids.dtype,
                device=text_input_ids.device,
            )
            # Padding is prepended, i.e. sequences are left-padded.
            text_input_ids = torch.cat([pad_ids, text_input_ids], dim=1)
        # Use the penultimate hidden state as the conditioning embedding.
        prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=True).hidden_states[-2]

        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
        return prompt_embeds
220
+
221
    def encode_prompt(
        self,
        prompt: Union[str, List[str]],
        negative_prompt: Optional[Union[str, List[str]]] = None,
        do_classifier_free_guidance: bool = True,
        num_images_per_prompt: int = 1,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
        max_sequence_length: int = 1024,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
                Whether to use classifier free guidance or not.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                Number of images that should be generated per prompt. torch device to place the resulting embeddings on
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            device: (`torch.device`, *optional*):
                torch device
            dtype: (`torch.dtype`, *optional*):
                torch dtype
            max_sequence_length (`int`, defaults to `1024`):
                Maximum sequence length in encoded prompt. Can be set to other values but may lead to poorer results.
        """
        device = device or self._execution_device

        prompt = [prompt] if isinstance(prompt, str) else prompt
        if prompt is not None:
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            prompt_embeds = self._get_glm_embeds(prompt, max_sequence_length, device, dtype)

        # Duplicate the embeddings for each requested image per prompt, then
        # flatten so the batch dimension is (batch_size * num_images_per_prompt).
        seq_len = prompt_embeds.size(1)
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

        # Negative embeddings are only built when CFG is active and the caller
        # did not supply them precomputed.
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            negative_prompt = negative_prompt or ""
            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt

            if prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )

            negative_prompt_embeds = self._get_glm_embeds(negative_prompt, max_sequence_length, device, dtype)

            # Same per-image duplication as for the positive embeddings.
            seq_len = negative_prompt_embeds.size(1)
            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

        # negative_prompt_embeds stays None when CFG is disabled and none were passed in.
        return prompt_embeds, negative_prompt_embeds
299
+
300
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
301
+ if latents is not None:
302
+ return latents.to(device)
303
+
304
+ shape = (
305
+ batch_size,
306
+ num_channels_latents,
307
+ int(height) // self.vae_scale_factor,
308
+ int(width) // self.vae_scale_factor,
309
+ )
310
+ if isinstance(generator, list) and len(generator) != batch_size:
311
+ raise ValueError(
312
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
313
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
314
+ )
315
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
316
+ return latents
317
+
318
    def check_inputs(
        self,
        prompt,
        height,
        width,
        negative_prompt,
        callback_on_step_end_tensor_inputs,
        prompt_embeds=None,
        negative_prompt_embeds=None,
    ):
        """Validate call arguments, raising ValueError/TypeError on misuse.

        Enforces: 16-divisible image size, known callback tensor names, exactly
        one of `prompt`/`prompt_embeds`, no conflicting negative-prompt inputs,
        and shape-compatible positive/negative embeddings.
        """
        if height % 16 != 0 or width % 16 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")

        # Only tensors the pipeline actually tracks can be exposed to callbacks.
        if callback_on_step_end_tensor_inputs is not None and not all(
            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
        ):
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
            )
        # `prompt` and `prompt_embeds` are mutually exclusive; one is required.
        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        # Precomputed embedding pairs must agree on batch size and hidden dim.
        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if prompt_embeds.shape[0] != negative_prompt_embeds.shape[0]:
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same batch size when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} and `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}."
                )
            if prompt_embeds.shape[-1] != negative_prompt_embeds.shape[-1]:
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same dimension when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} and `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}."
                )
375
    @property
    def guidance_scale(self):
        # Classifier-free guidance weight set at the start of `__call__`.
        return self._guidance_scale

    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    @property
    def do_classifier_free_guidance(self):
        return self._guidance_scale > 1

    @property
    def num_timesteps(self):
        # Length of the timestep schedule of the current/most recent run.
        return self._num_timesteps

    @property
    def attention_kwargs(self):
        # Extra kwargs forwarded to the transformer's attention processors.
        return self._attention_kwargs

    @property
    def current_timestep(self):
        # Timestep being denoised right now; None outside the denoising loop.
        return self._current_timestep

    @property
    def interrupt(self):
        # When True, the denoising loop skips its remaining steps.
        return self._interrupt
401
+
402
    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Optional[Union[str, List[str]]] = None,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        timesteps: Optional[List[int]] = None,
        sigmas: Optional[List[float]] = None,
        guidance_scale: float = 5.0,
        num_images_per_prompt: int = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        original_size: Optional[Tuple[int, int]] = None,
        crops_coords_top_left: Tuple[int, int] = (0, 0),
        output_type: str = "pil",
        return_dict: bool = True,
        attention_kwargs: Optional[Dict[str, Any]] = None,
        callback_on_step_end: Optional[
            Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
        ] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        max_sequence_length: int = 1024,
    ) -> Union[CogView4PipelineOutput, Tuple]:
        """
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            height (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
                The height in pixels of the generated image. If not provided, it is set to 1024.
            width (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
                The width in pixels of the generated image. If not provided it is set to 1024.
            num_inference_steps (`int`, *optional*, defaults to `50`):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            timesteps (`List[int]`, *optional*):
                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
                passed will be used. Must be in descending order.
            sigmas (`List[float]`, *optional*):
                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
                their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
                will be used.
            guidance_scale (`float`, *optional*, defaults to `5.0`):
                Guidance scale as defined in [Classifier-Free Diffusion
                Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
                of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
                `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
                the text `prompt`, usually at the expense of lower image quality.
            num_images_per_prompt (`int`, *optional*, defaults to `1`):
                The number of images to generate per prompt.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
                If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
                `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
                explained in section 2.2 of
                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
            crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
                `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
                `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
                `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generate image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
                of a plain tuple.
            attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            callback_on_step_end (`Callable`, *optional*):
                A function that calls at the end of each denoising steps during the inference. The function is called
                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
                `callback_on_step_end_tensor_inputs`.
            callback_on_step_end_tensor_inputs (`List`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                `._callback_tensor_inputs` attribute of your pipeline class.
            max_sequence_length (`int`, defaults to `1024`):
                Maximum sequence length in encoded prompt. Can be set to other values but may lead to poorer results.

        Examples:

        Returns:
            [`~pipelines.cogview4.pipeline_CogView4.CogView4PipelineOutput`] or `tuple`:
                [`~pipelines.cogview4.pipeline_CogView4.CogView4PipelineOutput`] if `return_dict` is True, otherwise a
                `tuple`. When returning a tuple, the first element is a list with the generated images.
        """

        # Callback objects carry their own declared tensor inputs.
        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

        height = height or self.transformer.config.sample_size * self.vae_scale_factor
        width = width or self.transformer.config.sample_size * self.vae_scale_factor

        original_size = original_size or (height, width)
        target_size = (height, width)

        # Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            height,
            width,
            negative_prompt,
            callback_on_step_end_tensor_inputs,
            prompt_embeds,
            negative_prompt_embeds,
        )
        self._guidance_scale = guidance_scale
        self._attention_kwargs = attention_kwargs
        self._current_timestep = None
        self._interrupt = False

        # Default call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        # Encode input prompt
        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
            prompt,
            negative_prompt,
            self.do_classifier_free_guidance,
            num_images_per_prompt=num_images_per_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            max_sequence_length=max_sequence_length,
            device=device,
        )

        # Prepare latents (kept in float32; cast to the transformer dtype per step)
        latent_channels = self.transformer.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            latent_channels,
            height,
            width,
            torch.float32,
            device,
            generator,
            latents,
        )

        # Prepare additional timestep conditions (SDXL-style micro-conditioning)
        original_size = torch.tensor([original_size], dtype=prompt_embeds.dtype, device=device)
        target_size = torch.tensor([target_size], dtype=prompt_embeds.dtype, device=device)
        crops_coords_top_left = torch.tensor([crops_coords_top_left], dtype=prompt_embeds.dtype, device=device)

        original_size = original_size.repeat(batch_size * num_images_per_prompt, 1)
        target_size = target_size.repeat(batch_size * num_images_per_prompt, 1)
        crops_coords_top_left = crops_coords_top_left.repeat(batch_size * num_images_per_prompt, 1)

        # Prepare timesteps: a descending linspace over the training range,
        # rounded to integers, with sigmas derived by normalization unless given.
        image_seq_len = ((height // self.vae_scale_factor) * (width // self.vae_scale_factor)) // (
            self.transformer.config.patch_size**2
        )
        timesteps = (
            np.linspace(self.scheduler.config.num_train_timesteps, 1.0, num_inference_steps)
            if timesteps is None
            else np.array(timesteps)
        )
        timesteps = timesteps.astype(np.int64).astype(np.float32)
        sigmas = timesteps / self.scheduler.config.num_train_timesteps if sigmas is None else sigmas
        # `mu` shifts the schedule based on the image token count.
        mu = calculate_shift(
            image_seq_len,
            self.scheduler.config.get("base_image_seq_len", 256),
            self.scheduler.config.get("base_shift", 0.25),
            self.scheduler.config.get("max_shift", 0.75),
        )
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler, num_inference_steps, device, timesteps, sigmas, mu=mu
        )
        self._num_timesteps = len(timesteps)

        # Denoising loop
        transformer_dtype = self.transformer.dtype
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                self._current_timestep = t
                latent_model_input = latents.to(transformer_dtype)

                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                timestep = t.expand(latents.shape[0])

                # Conditional forward pass (separate cache context from uncond).
                with self.transformer.cache_context("cond"):
                    noise_pred_cond = self.transformer(
                        hidden_states=latent_model_input,
                        encoder_hidden_states=prompt_embeds,
                        timestep=timestep,
                        original_size=original_size,
                        target_size=target_size,
                        crop_coords=crops_coords_top_left,
                        attention_kwargs=attention_kwargs,
                        return_dict=False,
                    )[0]

                # perform guidance
                if self.do_classifier_free_guidance:
                    with self.transformer.cache_context("uncond"):
                        noise_pred_uncond = self.transformer(
                            hidden_states=latent_model_input,
                            encoder_hidden_states=negative_prompt_embeds,
                            timestep=timestep,
                            original_size=original_size,
                            target_size=target_size,
                            crop_coords=crops_coords_top_left,
                            attention_kwargs=attention_kwargs,
                            return_dict=False,
                        )[0]
                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
                else:
                    noise_pred = noise_pred_cond

                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

                # call the callback, if provided
                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    # NOTE(review): the "timestep" argument passed here is
                    # `self.scheduler.sigmas[i]`, not `t` — confirm callbacks
                    # expect a sigma in that slot.
                    callback_outputs = callback_on_step_end(self, i, self.scheduler.sigmas[i], callback_kwargs)
                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

                if XLA_AVAILABLE:
                    # Force execution of the lazily-built XLA graph each step.
                    xm.mark_step()

        self._current_timestep = None

        if not output_type == "latent":
            # Un-scale and decode the latents back to image space via the VAE.
            latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor
            image = self.vae.decode(latents, return_dict=False, generator=generator)[0]
        else:
            image = latents

        image = self.image_processor.postprocess(image, output_type=output_type)

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image,)

        return CogView4PipelineOutput(images=image)
pythonProject/diffusers-main/build/lib/diffusers/pipelines/cogview4/pipeline_cogview4_control.py ADDED
@@ -0,0 +1,732 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import inspect
17
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
18
+
19
+ import numpy as np
20
+ import torch
21
+ from transformers import AutoTokenizer, GlmModel
22
+
23
+ from ...callbacks import MultiPipelineCallbacks, PipelineCallback
24
+ from ...image_processor import PipelineImageInput, VaeImageProcessor
25
+ from ...models import AutoencoderKL, CogView4Transformer2DModel
26
+ from ...pipelines.pipeline_utils import DiffusionPipeline
27
+ from ...schedulers import FlowMatchEulerDiscreteScheduler
28
+ from ...utils import is_torch_xla_available, logging, replace_example_docstring
29
+ from ...utils.torch_utils import randn_tensor
30
+ from .pipeline_output import CogView4PipelineOutput
31
+
32
+
33
+ if is_torch_xla_available():
34
+ import torch_xla.core.xla_model as xm
35
+
36
+ XLA_AVAILABLE = True
37
+ else:
38
+ XLA_AVAILABLE = False
39
+
40
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
41
+
42
+ EXAMPLE_DOC_STRING = """
43
+ Examples:
44
+ ```python
45
+ >>> import torch
46
+ >>> from diffusers import CogView4ControlPipeline
47
+
48
+ >>> pipe = CogView4ControlPipeline.from_pretrained("THUDM/CogView4-6B-Control", torch_dtype=torch.bfloat16)
49
+ >>> control_image = load_image(
50
+ ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
51
+ ... )
52
+ >>> prompt = "A bird in space"
53
+ >>> image = pipe(prompt, control_image=control_image, height=1024, width=1024, guidance_scale=3.5).images[0]
54
+ >>> image.save("cogview4-control.png")
55
+ ```
56
+ """
57
+
58
+
59
def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    base_shift: float = 0.25,
    max_shift: float = 0.75,
) -> float:
    """Compute the timestep-shift parameter `mu` for flow-matching schedules.

    The shift scales with the square root of the relative token count, so larger
    images (longer latent sequences) receive a proportionally larger shift.

    Args:
        image_seq_len: Number of latent patch tokens for the image being generated.
        base_seq_len: Reference sequence length the shift constants were tuned for.
        base_shift: Additive shift floor.
        max_shift: Multiplier applied to the sqrt-scaled sequence ratio.

    Returns:
        The shift value `mu` passed to the scheduler's `set_timesteps`.
    """
    scale = (image_seq_len / base_seq_len) ** 0.5
    return base_shift + scale * max_shift
69
+
70
+
71
# Adapted from diffusers.pipelines.cogview4.pipeline_cogview4.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    r"""
    Call `scheduler.set_timesteps` and return the resulting timestep schedule.

    Exactly one of `num_inference_steps`, or a custom `timesteps` / `sigmas` schedule
    (or both together, for schedulers that accept both) drives the scheduler. Any
    extra keyword arguments are forwarded to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`, *optional*):
            Number of diffusion steps; ignored when `timesteps` or `sigmas` is given.
        device (`str` or `torch.device`, *optional*):
            Device the timesteps should be moved to. If `None`, they are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timestep schedule; requires scheduler support.
        sigmas (`List[float]`, *optional*):
            Custom sigma schedule; requires scheduler support.

    Returns:
        `Tuple[torch.Tensor, int]`: The timestep schedule and the number of inference steps.

    Raises:
        ValueError: If a custom schedule is requested but the scheduler's
            `set_timesteps` does not accept the corresponding argument.
    """
    # Feature-detect which custom-schedule arguments this scheduler supports.
    supported = set(inspect.signature(scheduler.set_timesteps).parameters.keys())
    accepts_timesteps = "timesteps" in supported
    accepts_sigmas = "sigmas" in supported

    if timesteps is not None and sigmas is not None:
        if not accepts_timesteps and not accepts_sigmas:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep or sigma schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, sigmas=sigmas, device=device, **kwargs)
    elif timesteps is not None:
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    elif sigmas is not None:
        if not accepts_sigmas:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
    else:
        # Default path: honor the caller-supplied step count as-is (the scheduler
        # may in principle produce a different-length schedule for custom inputs,
        # which is why the branches above recompute the count from `timesteps`).
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    schedule = scheduler.timesteps
    return schedule, len(schedule)
137
+
138
+
139
class CogView4ControlPipeline(DiffusionPipeline):
    r"""
    Pipeline for controlled text-to-image generation using CogView4, where a control image
    (e.g. an edge map) is VAE-encoded and concatenated channel-wise with the latents.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        text_encoder ([`GLMModel`]):
            Frozen text-encoder. CogView4 uses [glm-4-9b-hf](https://huggingface.co/THUDM/glm-4-9b-hf).
        tokenizer (`PreTrainedTokenizer`):
            Tokenizer of class
            [PreTrainedTokenizer](https://huggingface.co/docs/transformers/main/en/main_classes/tokenizer#transformers.PreTrainedTokenizer).
        transformer ([`CogView4Transformer2DModel`]):
            A text conditioned `CogView4Transformer2DModel` to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
    """

    _optional_components = []
    # CPU-offload order: each module is moved to GPU only while it runs.
    model_cpu_offload_seq = "text_encoder->transformer->vae"
    # Tensors that `callback_on_step_end` callbacks may read/override per step.
    _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]

    def __init__(
        self,
        tokenizer: AutoTokenizer,
        text_encoder: GlmModel,
        vae: AutoencoderKL,
        transformer: CogView4Transformer2DModel,
        scheduler: FlowMatchEulerDiscreteScheduler,
    ):
        """Register all sub-models and derive the VAE spatial downscale factor."""
        super().__init__()

        self.register_modules(
            tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
        )
        # Spatial downscale of the VAE: 2^(levels-1); falls back to 8 when no VAE
        # is registered (e.g. partially-loaded pipelines).
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
179
+
180
    # Copied from diffusers.pipelines.cogview4.pipeline_cogview4.CogView4Pipeline._get_glm_embeds
    def _get_glm_embeds(
        self,
        prompt: Union[str, List[str]] = None,
        max_sequence_length: int = 1024,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        """Tokenize `prompt` and return GLM text-encoder hidden states.

        Pads the token sequence on the left (with the pad token) up to a multiple
        of 16 and returns the second-to-last hidden-state layer of the encoder,
        cast to `dtype` on `device`.
        """
        device = device or self._execution_device
        dtype = dtype or self.text_encoder.dtype

        prompt = [prompt] if isinstance(prompt, str) else prompt

        text_inputs = self.tokenizer(
            prompt,
            padding="longest",  # pad to the longest prompt in the batch, not to max_length
            max_length=max_sequence_length,
            truncation=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        # Re-tokenize without truncation to detect (and warn about) dropped text.
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because `max_sequence_length` is set to "
                f" {max_sequence_length} tokens: {removed_text}"
            )
        # Left-pad with pad tokens so the sequence length is a multiple of 16
        # (presumably a transformer block-size requirement — TODO confirm).
        current_length = text_input_ids.shape[1]
        pad_length = (16 - (current_length % 16)) % 16
        if pad_length > 0:
            pad_ids = torch.full(
                (text_input_ids.shape[0], pad_length),
                fill_value=self.tokenizer.pad_token_id,
                dtype=text_input_ids.dtype,
                device=text_input_ids.device,
            )
            text_input_ids = torch.cat([pad_ids, text_input_ids], dim=1)
        # Use the penultimate hidden layer as the conditioning representation.
        prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=True).hidden_states[-2]

        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
        return prompt_embeds
223
+
224
    # Copied from diffusers.pipelines.cogview4.pipeline_cogview4.CogView4Pipeline.encode_prompt
    def encode_prompt(
        self,
        prompt: Union[str, List[str]],
        negative_prompt: Optional[Union[str, List[str]]] = None,
        do_classifier_free_guidance: bool = True,
        num_images_per_prompt: int = 1,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
        max_sequence_length: int = 1024,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
                Whether to use classifier free guidance or not.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                Number of images that should be generated per prompt. torch device to place the resulting embeddings on
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            device: (`torch.device`, *optional*):
                torch device
            dtype: (`torch.dtype`, *optional*):
                torch dtype
            max_sequence_length (`int`, defaults to `1024`):
                Maximum sequence length in encoded prompt. Can be set to other values but may lead to poorer results.
        """
        device = device or self._execution_device

        prompt = [prompt] if isinstance(prompt, str) else prompt
        if prompt is not None:
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            prompt_embeds = self._get_glm_embeds(prompt, max_sequence_length, device, dtype)

        # Duplicate embeddings for each requested image per prompt:
        # (batch, seq, dim) -> (batch * num_images_per_prompt, seq, dim).
        seq_len = prompt_embeds.size(1)
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

        # Negative embeddings are only computed when CFG is active and none were supplied.
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            negative_prompt = negative_prompt or ""
            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt

            if prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )

            negative_prompt_embeds = self._get_glm_embeds(negative_prompt, max_sequence_length, device, dtype)

            seq_len = negative_prompt_embeds.size(1)
            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

        # NOTE: negative_prompt_embeds may be None when CFG is disabled.
        return prompt_embeds, negative_prompt_embeds
303
+
304
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
305
+ if latents is not None:
306
+ return latents.to(device)
307
+
308
+ shape = (
309
+ batch_size,
310
+ num_channels_latents,
311
+ int(height) // self.vae_scale_factor,
312
+ int(width) // self.vae_scale_factor,
313
+ )
314
+ if isinstance(generator, list) and len(generator) != batch_size:
315
+ raise ValueError(
316
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
317
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
318
+ )
319
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
320
+ return latents
321
+
322
+ def prepare_image(
323
+ self,
324
+ image,
325
+ width,
326
+ height,
327
+ batch_size,
328
+ num_images_per_prompt,
329
+ device,
330
+ dtype,
331
+ do_classifier_free_guidance=False,
332
+ guess_mode=False,
333
+ ):
334
+ if isinstance(image, torch.Tensor):
335
+ pass
336
+ else:
337
+ image = self.image_processor.preprocess(image, height=height, width=width)
338
+
339
+ image_batch_size = image.shape[0]
340
+
341
+ if image_batch_size == 1:
342
+ repeat_by = batch_size
343
+ else:
344
+ # image batch size is the same as prompt batch size
345
+ repeat_by = num_images_per_prompt
346
+
347
+ image = image.repeat_interleave(repeat_by, dim=0, output_size=image.shape[0] * repeat_by)
348
+
349
+ image = image.to(device=device, dtype=dtype)
350
+
351
+ if do_classifier_free_guidance and not guess_mode:
352
+ image = torch.cat([image] * 2)
353
+
354
+ return image
355
+
356
    def check_inputs(
        self,
        prompt,
        height,
        width,
        negative_prompt,
        callback_on_step_end_tensor_inputs,
        prompt_embeds=None,
        negative_prompt_embeds=None,
    ):
        """Validate call arguments; raises `ValueError` on any inconsistency.

        Enforces: resolution divisible by 16, callback tensor names registered in
        `_callback_tensor_inputs`, exactly one of prompt/prompt_embeds supplied,
        and shape agreement between positive and negative embeddings.
        """
        # The transformer operates on 16-pixel-aligned resolutions.
        if height % 16 != 0 or width % 16 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")

        if callback_on_step_end_tensor_inputs is not None and not all(
            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
        ):
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
            )
        # `prompt` and `prompt_embeds` are mutually exclusive, but one is required.
        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        # CFG stacks the two embedding tensors, so their shapes must match exactly.
        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if prompt_embeds.shape != negative_prompt_embeds.shape:
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}."
                )
406
+
407
    @property
    def guidance_scale(self):
        # Guidance scale set at the start of the current `__call__`.
        return self._guidance_scale

    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    @property
    def do_classifier_free_guidance(self):
        # CFG is active only for scales strictly greater than 1.
        return self._guidance_scale > 1

    @property
    def num_timesteps(self):
        # Number of scheduler timesteps used by the current/last run.
        return self._num_timesteps

    @property
    def attention_kwargs(self):
        # Extra kwargs forwarded to the transformer's attention processors.
        return self._attention_kwargs

    @property
    def current_timestep(self):
        # Timestep being denoised right now; None outside the denoising loop.
        return self._current_timestep

    @property
    def interrupt(self):
        # When True, remaining denoising steps are skipped (set externally).
        return self._interrupt
433
+
434
    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Optional[Union[str, List[str]]] = None,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        control_image: PipelineImageInput = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        timesteps: Optional[List[int]] = None,
        sigmas: Optional[List[float]] = None,
        guidance_scale: float = 5.0,
        num_images_per_prompt: int = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        original_size: Optional[Tuple[int, int]] = None,
        crops_coords_top_left: Tuple[int, int] = (0, 0),
        output_type: str = "pil",
        return_dict: bool = True,
        attention_kwargs: Optional[Dict[str, Any]] = None,
        callback_on_step_end: Optional[
            Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
        ] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        max_sequence_length: int = 1024,
    ) -> Union[CogView4PipelineOutput, Tuple]:
        """
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            control_image (`PipelineImageInput`, *optional*):
                The conditioning image (e.g. an edge map). It is resized to (`height`, `width`), encoded with the
                VAE, and concatenated channel-wise with the noisy latents at every denoising step.
            height (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
                The height in pixels of the generated image. If not provided, it is set to 1024.
            width (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
                The width in pixels of the generated image. If not provided it is set to 1024.
            num_inference_steps (`int`, *optional*, defaults to `50`):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            timesteps (`List[int]`, *optional*):
                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
                passed will be used. Must be in descending order.
            sigmas (`List[float]`, *optional*):
                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
                their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
                will be used.
            guidance_scale (`float`, *optional*, defaults to `5.0`):
                Guidance scale as defined in [Classifier-Free Diffusion
                Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
                of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
                `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
                the text `prompt`, usually at the expense of lower image quality.
            num_images_per_prompt (`int`, *optional*, defaults to `1`):
                The number of images to generate per prompt.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
                If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
                `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
                explained in section 2.2 of
                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
            crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
                `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
                `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
                `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generate image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.pipeline_CogView4.CogView4PipelineOutput`] instead of a plain
                tuple.
            attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            callback_on_step_end (`Callable`, *optional*):
                A function that calls at the end of each denoising steps during the inference. The function is called
                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
                `callback_on_step_end_tensor_inputs`.
            callback_on_step_end_tensor_inputs (`List`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                `._callback_tensor_inputs` attribute of your pipeline class.
            max_sequence_length (`int`, defaults to `1024`):
                Maximum sequence length in encoded prompt. Can be set to other values but may lead to poorer results.
        Examples:

        Returns:
            [`~pipelines.cogview4.pipeline_CogView4.CogView4PipelineOutput`] or `tuple`:
            [`~pipelines.cogview4.pipeline_CogView4.CogView4PipelineOutput`] if `return_dict` is True, otherwise a
            `tuple`. When returning a tuple, the first element is a list with the generated images.
        """

        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

        height = height or self.transformer.config.sample_size * self.vae_scale_factor
        width = width or self.transformer.config.sample_size * self.vae_scale_factor

        original_size = original_size or (height, width)
        target_size = (height, width)

        # Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            height,
            width,
            negative_prompt,
            callback_on_step_end_tensor_inputs,
            prompt_embeds,
            negative_prompt_embeds,
        )
        self._guidance_scale = guidance_scale
        self._attention_kwargs = attention_kwargs
        self._current_timestep = None
        self._interrupt = False

        # Default call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        # Encode input prompt
        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
            prompt,
            negative_prompt,
            self.do_classifier_free_guidance,
            num_images_per_prompt=num_images_per_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            max_sequence_length=max_sequence_length,
            device=device,
        )

        # Prepare latents
        # Half of the transformer's input channels carry the noisy latents; the
        # other half carry the VAE-encoded control image (concatenated below).
        latent_channels = self.transformer.config.in_channels // 2

        control_image = self.prepare_image(
            image=control_image,
            width=width,
            height=height,
            batch_size=batch_size * num_images_per_prompt,
            num_images_per_prompt=num_images_per_prompt,
            device=device,
            dtype=self.vae.dtype,
        )
        height, width = control_image.shape[-2:]

        vae_shift_factor = 0

        control_image = self.vae.encode(control_image).latent_dist.sample()
        control_image = (control_image - vae_shift_factor) * self.vae.config.scaling_factor

        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            latent_channels,
            height,
            width,
            torch.float32,
            device,
            generator,
            latents,
        )

        # Prepare additional timestep conditions
        original_size = torch.tensor([original_size], dtype=prompt_embeds.dtype, device=device)
        target_size = torch.tensor([target_size], dtype=prompt_embeds.dtype, device=device)
        crops_coords_top_left = torch.tensor([crops_coords_top_left], dtype=prompt_embeds.dtype, device=device)

        original_size = original_size.repeat(batch_size * num_images_per_prompt, 1)
        target_size = target_size.repeat(batch_size * num_images_per_prompt, 1)
        crops_coords_top_left = crops_coords_top_left.repeat(batch_size * num_images_per_prompt, 1)

        # Prepare timesteps
        image_seq_len = ((height // self.vae_scale_factor) * (width // self.vae_scale_factor)) // (
            self.transformer.config.patch_size**2
        )

        timesteps = (
            np.linspace(self.scheduler.config.num_train_timesteps, 1.0, num_inference_steps)
            if timesteps is None
            else np.array(timesteps)
        )
        timesteps = timesteps.astype(np.int64).astype(np.float32)
        sigmas = timesteps / self.scheduler.config.num_train_timesteps if sigmas is None else sigmas
        # Resolution-dependent schedule shift: longer latent sequences get a larger mu.
        mu = calculate_shift(
            image_seq_len,
            self.scheduler.config.get("base_image_seq_len", 256),
            self.scheduler.config.get("base_shift", 0.25),
            self.scheduler.config.get("max_shift", 0.75),
        )
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler, num_inference_steps, device, timesteps, sigmas, mu=mu
        )
        self._num_timesteps = len(timesteps)
        # Denoising loop
        transformer_dtype = self.transformer.dtype
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                self._current_timestep = t
                # Channel-wise concat of noisy latents and control latents is the
                # control-conditioning mechanism for this pipeline.
                latent_model_input = torch.cat([latents, control_image], dim=1).to(transformer_dtype)

                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                timestep = t.expand(latents.shape[0])

                # NOTE(review): unlike CogView4Pipeline, these cond/uncond forwards are
                # not wrapped in `self.transformer.cache_context(...)` — confirm whether
                # caching is intentionally disabled for the control pipeline.
                noise_pred_cond = self.transformer(
                    hidden_states=latent_model_input,
                    encoder_hidden_states=prompt_embeds,
                    timestep=timestep,
                    original_size=original_size,
                    target_size=target_size,
                    crop_coords=crops_coords_top_left,
                    attention_kwargs=attention_kwargs,
                    return_dict=False,
                )[0]

                # perform guidance
                if self.do_classifier_free_guidance:
                    noise_pred_uncond = self.transformer(
                        hidden_states=latent_model_input,
                        encoder_hidden_states=negative_prompt_embeds,
                        timestep=timestep,
                        original_size=original_size,
                        target_size=target_size,
                        crop_coords=crops_coords_top_left,
                        attention_kwargs=attention_kwargs,
                        return_dict=False,
                    )[0]

                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
                else:
                    noise_pred = noise_pred_cond
                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

                # call the callback, if provided
                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, self.scheduler.sigmas[i], callback_kwargs)
                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

                if XLA_AVAILABLE:
                    xm.mark_step()

        self._current_timestep = None

        if not output_type == "latent":
            # Undo the VAE latent scaling before decoding back to pixel space.
            latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor
            image = self.vae.decode(latents, return_dict=False, generator=generator)[0]
        else:
            image = latents

        image = self.image_processor.postprocess(image, output_type=output_type)

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image,)

        return CogView4PipelineOutput(images=image)
pythonProject/diffusers-main/build/lib/diffusers/pipelines/cogview4/pipeline_output.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import List, Union
3
+
4
+ import numpy as np
5
+ import PIL.Image
6
+
7
+ from ...utils import BaseOutput
8
+
9
+
10
@dataclass
class CogView4PipelineOutput(BaseOutput):
    """
    Output class for CogView4 pipelines.

    Args:
        images (`List[PIL.Image.Image]` or `np.ndarray`)
            List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
            num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
    """

    # Final decoded images; concrete type depends on the `output_type` requested from the pipeline.
    images: Union[List[PIL.Image.Image], np.ndarray]
pythonProject/diffusers-main/build/lib/diffusers/pipelines/consisid/__init__.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from typing import TYPE_CHECKING

from ...utils import (
    DIFFUSERS_SLOW_IMPORT,
    OptionalDependencyNotAvailable,
    _LazyModule,
    get_objects_from_module,
    is_opencv_available,
    is_torch_available,
    is_transformers_available,
)


# Placeholder objects exported when optional dependencies are missing, and the
# lazy-import table consumed by `_LazyModule` below.
_dummy_objects = {}
_import_structure = {}


try:
    # ConsisID requires transformers, torch and opencv at runtime.
    if not (is_transformers_available() and is_torch_available() and is_opencv_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ...utils import dummy_torch_and_transformers_and_opencv_objects  # noqa F403

    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_and_opencv_objects))
else:
    _import_structure["pipeline_consisid"] = ["ConsisIDPipeline"]

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    try:
        # NOTE(review): unlike the runtime guard above, this branch does not check
        # `is_opencv_available()` and falls back to the non-opencv dummy module —
        # presumably intentional for type checking, but worth confirming.
        if not (is_transformers_available() and is_torch_available()):
            raise OptionalDependencyNotAvailable()

    except OptionalDependencyNotAvailable:
        from ...utils.dummy_torch_and_transformers_objects import *
    else:
        from .pipeline_consisid import ConsisIDPipeline

else:
    import sys

    # Replace this module object with a lazy proxy so the heavy pipeline module is
    # only imported on first attribute access.
    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )

    # Expose the dummy placeholders on the lazy module when dependencies are missing.
    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)
pythonProject/diffusers-main/build/lib/diffusers/pipelines/consisid/consisid_utils.py ADDED
@@ -0,0 +1,357 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import importlib.util
2
+ import os
3
+
4
+ import cv2
5
+ import numpy as np
6
+ import torch
7
+ from PIL import Image, ImageOps
8
+ from torchvision.transforms import InterpolationMode
9
+ from torchvision.transforms.functional import normalize, resize
10
+
11
+ from ...utils import get_logger, load_image
12
+
13
+
14
logger = get_logger(__name__)

# Availability flags for the optional face-processing dependencies.
_insightface_available = importlib.util.find_spec("insightface") is not None
_consisid_eva_clip_available = importlib.util.find_spec("consisid_eva_clip") is not None
_facexlib_available = importlib.util.find_spec("facexlib") is not None

# NOTE(review): each guard below raises at import time when its dependency is
# missing, so this module is only importable with all three packages installed;
# the availability flags above are never False past this point — confirm this
# hard-failure behavior is intended rather than a soft fallback.
if _insightface_available:
    import insightface
    from insightface.app import FaceAnalysis
else:
    raise ImportError("insightface is not available. Please install it using 'pip install insightface'.")

if _consisid_eva_clip_available:
    from consisid_eva_clip import create_model_and_transforms
    from consisid_eva_clip.constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
else:
    raise ImportError("consisid_eva_clip is not available. Please install it using 'pip install consisid_eva_clip'.")

if _facexlib_available:
    from facexlib.parsing import init_parsing_model
    from facexlib.utils.face_restoration_helper import FaceRestoreHelper
else:
    raise ImportError("facexlib is not available. Please install it using 'pip install facexlib'.")
37
+
38
+
39
def resize_numpy_image_long(image, resize_long_edge=768):
    """
    Downscale `image` so that its longer edge equals `resize_long_edge`, keeping the
    aspect ratio. Images whose longer edge is already at or below the target are
    returned unchanged.

    Args:
        image (numpy.ndarray): Input image (H x W x C or H x W).
        resize_long_edge (int): Target size for the long edge. Default is 768.

    Returns:
        numpy.ndarray: The resized image (or the original array when no resize is needed).
    """
    height, width = image.shape[:2]
    longest = max(height, width)
    if longest <= resize_long_edge:
        return image
    scale = resize_long_edge / longest
    new_size = (int(width * scale), int(height * scale))
    # Lanczos resampling preserves detail well when downscaling.
    return cv2.resize(image, new_size, interpolation=cv2.INTER_LANCZOS4)
60
+
61
+
62
def img2tensor(imgs, bgr2rgb=True, float32=True):
    """Convert numpy image(s) to torch tensor(s).

    Args:
        imgs (list[ndarray] | ndarray): Input image(s) in HWC layout.
        bgr2rgb (bool): Whether to convert BGR channel order to RGB.
        float32 (bool): Whether to cast the result to float32.

    Returns:
        list[tensor] | tensor: CHW tensor(s). A bare tensor is returned for a
        single ndarray input, otherwise a list of tensors.
    """

    def _convert(img):
        if bgr2rgb and img.shape[2] == 3:
            # cv2 cannot handle float64 input here; downcast before converting.
            if img.dtype == "float64":
                img = img.astype("float32")
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        tensor = torch.from_numpy(img.transpose(2, 0, 1))
        return tensor.float() if float32 else tensor

    if isinstance(imgs, list):
        return [_convert(img) for img in imgs]
    return _convert(imgs)
88
+
89
+
90
def to_gray(img):
    """
    Convert a batched RGB tensor to grayscale using the standard luminosity
    (Rec. 601) weights, replicating the result across all three channels.

    Args:
        img (torch.Tensor): Input of shape (batch_size, 3, height, width), RGB order.

    Returns:
        torch.Tensor: Grayscale tensor of shape (batch_size, 3, height, width).
    """
    r, g, b = img[:, 0:1], img[:, 1:2], img[:, 2:3]
    luma = 0.299 * r + 0.587 * g + 0.114 * b
    return luma.repeat(1, 3, 1, 1)
105
+
106
+
107
def process_face_embeddings(
    face_helper_1,
    clip_vision_model,
    face_helper_2,
    eva_transform_mean,
    eva_transform_std,
    app,
    device,
    weight_dtype,
    image,
    original_id_image=None,
    is_align_face=True,
):
    """
    Process face embeddings from an image, extracting relevant features such as face embeddings, landmarks, and parsed
    face features using a series of face detection and alignment tools.

    Args:
        face_helper_1: Face helper object (first helper) for alignment and landmark detection.
        clip_vision_model: Pre-trained CLIP vision model used for feature extraction.
        face_helper_2: Face helper object (second helper) for embedding extraction.
        eva_transform_mean: Mean values for image normalization before passing to EVA model.
        eva_transform_std: Standard deviation values for image normalization before passing to EVA model.
        app: Application instance used for face detection.
        device: Device (CPU or GPU) where the computations will be performed.
        weight_dtype: Data type of the weights for precision (e.g., `torch.float32`).
        image: Input image in RGB format with pixel values in the range [0, 255].
        original_id_image: (Optional) Original image for feature extraction if `is_align_face` is False.
        is_align_face: Boolean flag indicating whether face alignment should be performed.

    Returns:
        Tuple:
            - id_cond: Concatenated tensor of Ante face embedding and CLIP vision embedding
            - id_vit_hidden: Hidden state of the CLIP vision model, a list of tensors.
            - return_face_features_image_2: Processed face features image after normalization and parsing.
            - face_kps: Keypoints of the face detected in the image.

    Raises:
        RuntimeError: If facexlib fails to align/crop a face from the image.
    """

    face_helper_1.clean_all()
    # Detectors below expect BGR input.
    image_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    # get antelopev2 embedding
    face_info = app.get(image_bgr)
    if len(face_info) > 0:
        # Keep the face with the largest bounding-box area.
        face_info = sorted(face_info, key=lambda x: (x["bbox"][2] - x["bbox"][0]) * (x["bbox"][3] - x["bbox"][1]))[
            -1
        ]  # only use the maximum face
        id_ante_embedding = face_info["embedding"]  # (512,)
        face_kps = face_info["kps"]
    else:
        id_ante_embedding = None
        face_kps = None

    # using facexlib to detect and align face
    face_helper_1.read_image(image_bgr)
    face_helper_1.get_face_landmarks_5(only_center_face=True)
    if face_kps is None:
        # Fall back to facexlib landmarks when insightface found no face.
        face_kps = face_helper_1.all_landmarks_5[0]
    face_helper_1.align_warp_face()
    if len(face_helper_1.cropped_faces) == 0:
        raise RuntimeError("facexlib align face fail")
    align_face = face_helper_1.cropped_faces[0]  # (512, 512, 3) # RGB

    # in case insightface didn't detect face
    if id_ante_embedding is None:
        logger.warning("Failed to detect face using insightface. Extracting embedding with align face")
        id_ante_embedding = face_helper_2.get_feat(align_face)

    id_ante_embedding = torch.from_numpy(id_ante_embedding).to(device, weight_dtype)  # torch.Size([512])
    if id_ante_embedding.ndim == 1:
        id_ante_embedding = id_ante_embedding.unsqueeze(0)  # torch.Size([1, 512])

    # parsing
    if is_align_face:
        input = img2tensor(align_face, bgr2rgb=True).unsqueeze(0) / 255.0  # torch.Size([1, 3, 512, 512])
        input = input.to(device)
        parsing_out = face_helper_1.face_parse(normalize(input, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]))[0]
        parsing_out = parsing_out.argmax(dim=1, keepdim=True)  # torch.Size([1, 1, 512, 512])
        # Parsing labels treated as background (non-face regions) — presumably the
        # bisenet label set; TODO confirm against the parsing model's label map.
        bg_label = [0, 16, 18, 7, 8, 9, 14, 15]
        bg = sum(parsing_out == i for i in bg_label).bool()
        white_image = torch.ones_like(input)  # torch.Size([1, 3, 512, 512])
        # only keep the face features
        return_face_features_image = torch.where(bg, white_image, to_gray(input))  # torch.Size([1, 3, 512, 512])
        return_face_features_image_2 = torch.where(bg, white_image, input)  # torch.Size([1, 3, 512, 512])
    else:
        original_image_bgr = cv2.cvtColor(original_id_image, cv2.COLOR_RGB2BGR)
        input = img2tensor(original_image_bgr, bgr2rgb=True).unsqueeze(0) / 255.0  # torch.Size([1, 3, 512, 512])
        input = input.to(device)
        return_face_features_image = return_face_features_image_2 = input

    # transform img before sending to eva-clip-vit
    face_features_image = resize(
        return_face_features_image, clip_vision_model.image_size, InterpolationMode.BICUBIC
    )  # torch.Size([1, 3, 336, 336])
    face_features_image = normalize(face_features_image, eva_transform_mean, eva_transform_std)
    id_cond_vit, id_vit_hidden = clip_vision_model(
        face_features_image.to(weight_dtype), return_all_features=False, return_hidden=True, shuffle=False
    )  # torch.Size([1, 768]), list(torch.Size([1, 577, 1024]))
    # L2-normalize the CLIP embedding before concatenation.
    id_cond_vit_norm = torch.norm(id_cond_vit, 2, 1, True)
    id_cond_vit = torch.div(id_cond_vit, id_cond_vit_norm)

    id_cond = torch.cat(
        [id_ante_embedding, id_cond_vit], dim=-1
    )  # torch.Size([1, 512]), torch.Size([1, 768]) -> torch.Size([1, 1280])

    return (
        id_cond,
        id_vit_hidden,
        return_face_features_image_2,
        face_kps,
    )  # torch.Size([1, 1280]), list(torch.Size([1, 577, 1024]))
217
+
218
+
219
def process_face_embeddings_infer(
    face_helper_1,
    clip_vision_model,
    face_helper_2,
    eva_transform_mean,
    eva_transform_std,
    app,
    device,
    weight_dtype,
    img_file_path,
    is_align_face=True,
):
    """
    Process face embeddings from an input image for inference: load, resize, extract
    embeddings and return the aligned face crop as a PIL image.

    Args:
        face_helper_1: Face helper object (first helper) for alignment and landmark detection.
        clip_vision_model: Pre-trained CLIP vision model used for feature extraction.
        face_helper_2: Face helper object (second helper) for embedding extraction.
        eva_transform_mean: Mean values for image normalization before passing to EVA model.
        eva_transform_std: Standard deviation values for image normalization before passing to EVA model.
        app: Application instance used for face detection.
        device: Device (CPU or GPU) where the computations will be performed.
        weight_dtype: Data type of the weights for precision (e.g., `torch.float32`).
        img_file_path: Path/URL to the input image (string) or a numpy array representing an image.
        is_align_face: Boolean flag indicating whether face alignment should be performed (default: True).

    Returns:
        Tuple:
            - id_cond: Concatenated tensor of Ante face embedding and CLIP vision embedding.
            - id_vit_hidden: Hidden state of the CLIP vision model, a list of tensors.
            - image: Aligned/cropped face image as a PIL image.
            - face_kps: Keypoints of the face detected in the image.
    """
    # Load the image either from a path/URL or from an in-memory array.
    if isinstance(img_file_path, str):
        rgb_image = np.array(load_image(image=img_file_path).convert("RGB"))
    else:
        rgb_image = np.array(ImageOps.exif_transpose(Image.fromarray(img_file_path)).convert("RGB"))

    # Cap the longer edge at 1024 pixels before detection.
    rgb_image = resize_numpy_image_long(rgb_image, 1024)
    original_id_image = rgb_image

    # Extract embeddings and the aligned face crop.
    id_cond, id_vit_hidden, align_crop_face_image, face_kps = process_face_embeddings(
        face_helper_1,
        clip_vision_model,
        face_helper_2,
        eva_transform_mean,
        eva_transform_std,
        app,
        device,
        weight_dtype,
        rgb_image,
        original_id_image,
        is_align_face,
    )

    # Convert the aligned crop (CHW float tensor in [0, 1]) back to a uint8 PIL image.
    face_array = align_crop_face_image.cpu().detach().squeeze().permute(1, 2, 0).numpy()
    face_array = (face_array * 255).astype(np.uint8)
    image = ImageOps.exif_transpose(Image.fromarray(face_array))

    return id_cond, id_vit_hidden, image, face_kps
289
+
290
+
291
def prepare_face_models(model_path, device, dtype):
    """
    Prepare all face models for the facial recognition task.

    Parameters:
        - model_path: Path to the directory containing model files.
        - device: The device (e.g., 'cuda', 'xpu', 'cpu') where models will be loaded.
        - dtype: Data type (e.g., torch.float32) for model inference.

    Returns (in order):
        - face_helper_1: Face restoration helper (facexlib detection/alignment + parsing).
        - face_helper_2: insightface glintr100 embedding extractor.
        - face_clip_model: EVA-CLIP vision tower for face feature extraction.
        - face_main_model: Main face analysis model (antelopev2).
        - eva_transform_mean: Per-channel mean for EVA-CLIP image normalization.
        - eva_transform_std: Per-channel std for EVA-CLIP image normalization.
    """
    # get helper model: retinaface detection + 512x512 face alignment.
    face_helper_1 = FaceRestoreHelper(
        upscale_factor=1,
        face_size=512,
        crop_ratio=(1, 1),
        det_model="retinaface_resnet50",
        save_ext="png",
        device=device,
        model_rootpath=os.path.join(model_path, "face_encoder"),
    )
    # Attach the bisenet face-parsing model used to separate face from background.
    face_helper_1.face_parse = init_parsing_model(
        model_name="bisenet", device=device, model_rootpath=os.path.join(model_path, "face_encoder")
    )
    # Embedding extractor used as a fallback when insightface detection fails.
    face_helper_2 = insightface.model_zoo.get_model(
        f"{model_path}/face_encoder/models/antelopev2/glintr100.onnx", providers=["CUDAExecutionProvider"]
    )
    face_helper_2.prepare(ctx_id=0)

    # get local facial extractor part 1
    model, _, _ = create_model_and_transforms(
        "EVA02-CLIP-L-14-336",
        os.path.join(model_path, "face_encoder", "EVA02_CLIP_L_336_psz14_s6B.pt"),
        force_custom_clip=True,
    )
    face_clip_model = model.visual
    eva_transform_mean = getattr(face_clip_model, "image_mean", OPENAI_DATASET_MEAN)
    eva_transform_std = getattr(face_clip_model, "image_std", OPENAI_DATASET_STD)
    # Broadcast scalar normalization stats to the three RGB channels.
    if not isinstance(eva_transform_mean, (list, tuple)):
        eva_transform_mean = (eva_transform_mean,) * 3
    if not isinstance(eva_transform_std, (list, tuple)):
        eva_transform_std = (eva_transform_std,) * 3

    # get local facial extractor part 2
    face_main_model = FaceAnalysis(
        name="antelopev2", root=os.path.join(model_path, "face_encoder"), providers=["CUDAExecutionProvider"]
    )
    face_main_model.prepare(ctx_id=0, det_size=(640, 640))

    # move face models to device
    face_helper_1.face_det.eval()
    face_helper_1.face_parse.eval()
    face_clip_model.eval()
    face_helper_1.face_det.to(device)
    face_helper_1.face_parse.to(device)
    face_clip_model.to(device, dtype=dtype)

    return face_helper_1, face_helper_2, face_clip_model, face_main_model, eva_transform_mean, eva_transform_std
pythonProject/diffusers-main/build/lib/diffusers/pipelines/consisid/pipeline_consisid.py ADDED
@@ -0,0 +1,974 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 ConsisID Authors and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ import math
17
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
18
+
19
+ import numpy as np
20
+ import PIL
21
+ import torch
22
+ from transformers import T5EncoderModel, T5Tokenizer
23
+
24
+ from ...callbacks import MultiPipelineCallbacks, PipelineCallback
25
+ from ...image_processor import PipelineImageInput
26
+ from ...loaders import CogVideoXLoraLoaderMixin
27
+ from ...models import AutoencoderKLCogVideoX, ConsisIDTransformer3DModel
28
+ from ...models.embeddings import get_3d_rotary_pos_embed
29
+ from ...pipelines.pipeline_utils import DiffusionPipeline
30
+ from ...schedulers import CogVideoXDPMScheduler
31
+ from ...utils import is_opencv_available, logging, replace_example_docstring
32
+ from ...utils.torch_utils import randn_tensor
33
+ from ...video_processor import VideoProcessor
34
+ from .pipeline_output import ConsisIDPipelineOutput
35
+
36
+
37
# OpenCV is imported lazily behind the availability check; in this file it is
# used by `draw_kps` below — presumably optional so the module stays importable
# without OpenCV installed.
if is_opencv_available():
    import cv2


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
42
+
43
+
44
# Usage example injected into the pipeline docs via `replace_example_docstring`
# (imported above) — presumably applied to `__call__`; confirm at the decorator site.
EXAMPLE_DOC_STRING = """
    Examples:
        ```python
        >>> import torch
        >>> from diffusers import ConsisIDPipeline
        >>> from diffusers.pipelines.consisid.consisid_utils import prepare_face_models, process_face_embeddings_infer
        >>> from diffusers.utils import export_to_video
        >>> from huggingface_hub import snapshot_download

        >>> snapshot_download(repo_id="BestWishYsh/ConsisID-preview", local_dir="BestWishYsh/ConsisID-preview")
        >>> (
        ...     face_helper_1,
        ...     face_helper_2,
        ...     face_clip_model,
        ...     face_main_model,
        ...     eva_transform_mean,
        ...     eva_transform_std,
        ... ) = prepare_face_models("BestWishYsh/ConsisID-preview", device="cuda", dtype=torch.bfloat16)
        >>> pipe = ConsisIDPipeline.from_pretrained("BestWishYsh/ConsisID-preview", torch_dtype=torch.bfloat16)
        >>> pipe.to("cuda")

        >>> # ConsisID works well with long and well-described prompts. Make sure the face in the image is clearly visible (e.g., preferably half-body or full-body).
        >>> prompt = "The video captures a boy walking along a city street, filmed in black and white on a classic 35mm camera. His expression is thoughtful, his brow slightly furrowed as if he's lost in contemplation. The film grain adds a textured, timeless quality to the image, evoking a sense of nostalgia. Around him, the cityscape is filled with vintage buildings, cobblestone sidewalks, and softly blurred figures passing by, their outlines faint and indistinct. Streetlights cast a gentle glow, while shadows play across the boy's path, adding depth to the scene. The lighting highlights the boy's subtle smile, hinting at a fleeting moment of curiosity. The overall cinematic atmosphere, complete with classic film still aesthetics and dramatic contrasts, gives the scene an evocative and introspective feel."
        >>> image = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/consisid/consisid_input.png?download=true"

        >>> id_cond, id_vit_hidden, image, face_kps = process_face_embeddings_infer(
        ...     face_helper_1,
        ...     face_clip_model,
        ...     face_helper_2,
        ...     eva_transform_mean,
        ...     eva_transform_std,
        ...     face_main_model,
        ...     "cuda",
        ...     torch.bfloat16,
        ...     image,
        ...     is_align_face=True,
        ... )

        >>> video = pipe(
        ...     image=image,
        ...     prompt=prompt,
        ...     num_inference_steps=50,
        ...     guidance_scale=6.0,
        ...     use_dynamic_cfg=False,
        ...     id_vit_hidden=id_vit_hidden,
        ...     id_cond=id_cond,
        ...     kps_cond=face_kps,
        ...     generator=torch.Generator("cuda").manual_seed(42),
        ... )
        >>> export_to_video(video.frames[0], "output.mp4", fps=8)
        ```
"""
96
+
97
+
98
def draw_kps(image_pil, kps, color_list=((255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (255, 0, 255))):
    """
    This function draws keypoints and the limbs connecting them on a black canvas
    sized like the input image.

    Parameters:
    - image_pil (PIL.Image): Input image as a PIL object (used only for its size).
    - kps (list of tuples): A list of keypoints where each keypoint is a tuple of (x, y) coordinates.
    - color_list (sequence of tuples, optional): Colors (in RGB format) for each keypoint. Default is a set of five
      colors. (A tuple default is used deliberately — a mutable list default would be shared across calls.)

    Returns:
    - PIL.Image: Image with the keypoints and limbs drawn.
    """

    stickwidth = 4
    # Limbs connect keypoints 0, 1, 3 and 4 to the central keypoint 2.
    limbSeq = np.array([[0, 2], [1, 2], [3, 2], [4, 2]])
    kps = np.array(kps)

    w, h = image_pil.size
    out_img = np.zeros([h, w, 3])

    for i in range(len(limbSeq)):
        index = limbSeq[i]
        color = color_list[index[0]]

        x = kps[index][:, 0]
        y = kps[index][:, 1]
        length = ((x[0] - x[1]) ** 2 + (y[0] - y[1]) ** 2) ** 0.5
        angle = math.degrees(math.atan2(y[0] - y[1], x[0] - x[1]))
        # Draw each limb as a filled rotated ellipse spanning the two keypoints.
        polygon = cv2.ellipse2Poly(
            (int(np.mean(x)), int(np.mean(y))), (int(length / 2), stickwidth), int(angle), 0, 360, 1
        )
        out_img = cv2.fillConvexPoly(out_img.copy(), polygon, color)
    # Dim the limbs so the keypoint dots drawn below stand out.
    out_img = (out_img * 0.6).astype(np.uint8)

    for idx_kp, kp in enumerate(kps):
        color = color_list[idx_kp]
        x, y = kp
        out_img = cv2.circle(out_img.copy(), (int(x), int(y)), 10, color, -1)

    out_img_pil = PIL.Image.fromarray(out_img.astype(np.uint8))
    return out_img_pil
140
+
141
+
142
# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
    """
    Compute the aspect-ratio-preserving fit of a source image into a target canvas
    and return the centered region it occupies.

    Parameters:
    - src (tuple): A tuple containing the source image's height (h) and width (w).
    - tgt_width (int): The target width to resize the image.
    - tgt_height (int): The target height to resize the image.

    Returns:
    - tuple: Two tuples representing the crop region:
        1. The top-left coordinates of the crop region.
        2. The bottom-right coordinates of the crop region.
    """

    src_h, src_w = src

    # Choose which dimension to match so the resized image fits the target.
    if src_h / src_w > tgt_height / tgt_width:
        # Source is taller (relative to the target): match heights, shrink width.
        resize_h = tgt_height
        resize_w = int(round(tgt_height / src_h * src_w))
    else:
        # Source is wider (or same ratio): match widths, shrink height.
        resize_w = tgt_width
        resize_h = int(round(tgt_width / src_w * src_h))

    # Center the resized image within the target canvas.
    top = int(round((tgt_height - resize_h) / 2.0))
    left = int(round((tgt_width - resize_w) / 2.0))

    return (top, left), (top + resize_h, left + resize_w)
174
+
175
+
176
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    r"""
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.

    Raises:
        ValueError: If both `timesteps` and `sigmas` are given, or if the scheduler does not
            support the custom schedule that was requested.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
    if timesteps is not None:
        # Custom timestep schedule: only valid if the scheduler's `set_timesteps` accepts it.
        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    elif sigmas is not None:
        # Custom sigma schedule: same capability check as above, for `sigmas`.
        accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accept_sigmas:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        # Default path: let the scheduler derive the schedule from the step count.
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps
234
+
235
+
236
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
237
+ def retrieve_latents(
238
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
239
+ ):
240
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
241
+ return encoder_output.latent_dist.sample(generator)
242
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
243
+ return encoder_output.latent_dist.mode()
244
+ elif hasattr(encoder_output, "latents"):
245
+ return encoder_output.latents
246
+ else:
247
+ raise AttributeError("Could not access latents of provided encoder_output")
248
+
249
+
250
class ConsisIDPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
    r"""
    Pipeline for image-to-video generation using ConsisID.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
        text_encoder ([`T5EncoderModel`]):
            Frozen text-encoder. ConsisID uses
            [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the
            [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
        tokenizer (`T5Tokenizer`):
            Tokenizer of class
            [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
        transformer ([`ConsisIDTransformer3DModel`]):
            A text conditioned `ConsisIDTransformer3DModel` to denoise the encoded video latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
    """

    # No components are optional for this pipeline.
    _optional_components = []
    # Order in which components are moved to the accelerator when CPU offloading is enabled.
    model_cpu_offload_seq = "text_encoder->transformer->vae"

    # Tensor names that step-end callbacks are allowed to inspect/override.
    _callback_tensor_inputs = [
        "latents",
        "prompt_embeds",
        "negative_prompt_embeds",
    ]
281
+
282
+ def __init__(
283
+ self,
284
+ tokenizer: T5Tokenizer,
285
+ text_encoder: T5EncoderModel,
286
+ vae: AutoencoderKLCogVideoX,
287
+ transformer: ConsisIDTransformer3DModel,
288
+ scheduler: CogVideoXDPMScheduler,
289
+ ):
290
+ super().__init__()
291
+
292
+ self.register_modules(
293
+ tokenizer=tokenizer,
294
+ text_encoder=text_encoder,
295
+ vae=vae,
296
+ transformer=transformer,
297
+ scheduler=scheduler,
298
+ )
299
+ self.vae_scale_factor_spatial = (
300
+ 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
301
+ )
302
+ self.vae_scale_factor_temporal = (
303
+ self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4
304
+ )
305
+ self.vae_scaling_factor_image = (
306
+ self.vae.config.scaling_factor if hasattr(self, "vae") and self.vae is not None else 0.7
307
+ )
308
+
309
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
310
+
311
+ # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._get_t5_prompt_embeds
312
+ def _get_t5_prompt_embeds(
313
+ self,
314
+ prompt: Union[str, List[str]] = None,
315
+ num_videos_per_prompt: int = 1,
316
+ max_sequence_length: int = 226,
317
+ device: Optional[torch.device] = None,
318
+ dtype: Optional[torch.dtype] = None,
319
+ ):
320
+ device = device or self._execution_device
321
+ dtype = dtype or self.text_encoder.dtype
322
+
323
+ prompt = [prompt] if isinstance(prompt, str) else prompt
324
+ batch_size = len(prompt)
325
+
326
+ text_inputs = self.tokenizer(
327
+ prompt,
328
+ padding="max_length",
329
+ max_length=max_sequence_length,
330
+ truncation=True,
331
+ add_special_tokens=True,
332
+ return_tensors="pt",
333
+ )
334
+ text_input_ids = text_inputs.input_ids
335
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
336
+
337
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
338
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
339
+ logger.warning(
340
+ "The following part of your input was truncated because `max_sequence_length` is set to "
341
+ f" {max_sequence_length} tokens: {removed_text}"
342
+ )
343
+
344
+ prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
345
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
346
+
347
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
348
+ _, seq_len, _ = prompt_embeds.shape
349
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
350
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
351
+
352
+ return prompt_embeds
353
+
354
+ # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.encode_prompt
355
+ def encode_prompt(
356
+ self,
357
+ prompt: Union[str, List[str]],
358
+ negative_prompt: Optional[Union[str, List[str]]] = None,
359
+ do_classifier_free_guidance: bool = True,
360
+ num_videos_per_prompt: int = 1,
361
+ prompt_embeds: Optional[torch.Tensor] = None,
362
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
363
+ max_sequence_length: int = 226,
364
+ device: Optional[torch.device] = None,
365
+ dtype: Optional[torch.dtype] = None,
366
+ ):
367
+ r"""
368
+ Encodes the prompt into text encoder hidden states.
369
+
370
+ Args:
371
+ prompt (`str` or `List[str]`, *optional*):
372
+ prompt to be encoded
373
+ negative_prompt (`str` or `List[str]`, *optional*):
374
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
375
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
376
+ less than `1`).
377
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
378
+ Whether to use classifier free guidance or not.
379
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
380
+ Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
381
+ prompt_embeds (`torch.Tensor`, *optional*):
382
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
383
+ provided, text embeddings will be generated from `prompt` input argument.
384
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
385
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
386
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
387
+ argument.
388
+ device: (`torch.device`, *optional*):
389
+ torch device
390
+ dtype: (`torch.dtype`, *optional*):
391
+ torch dtype
392
+ """
393
+ device = device or self._execution_device
394
+
395
+ prompt = [prompt] if isinstance(prompt, str) else prompt
396
+ if prompt is not None:
397
+ batch_size = len(prompt)
398
+ else:
399
+ batch_size = prompt_embeds.shape[0]
400
+
401
+ if prompt_embeds is None:
402
+ prompt_embeds = self._get_t5_prompt_embeds(
403
+ prompt=prompt,
404
+ num_videos_per_prompt=num_videos_per_prompt,
405
+ max_sequence_length=max_sequence_length,
406
+ device=device,
407
+ dtype=dtype,
408
+ )
409
+
410
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
411
+ negative_prompt = negative_prompt or ""
412
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
413
+
414
+ if prompt is not None and type(prompt) is not type(negative_prompt):
415
+ raise TypeError(
416
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
417
+ f" {type(prompt)}."
418
+ )
419
+ elif batch_size != len(negative_prompt):
420
+ raise ValueError(
421
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
422
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
423
+ " the batch size of `prompt`."
424
+ )
425
+
426
+ negative_prompt_embeds = self._get_t5_prompt_embeds(
427
+ prompt=negative_prompt,
428
+ num_videos_per_prompt=num_videos_per_prompt,
429
+ max_sequence_length=max_sequence_length,
430
+ device=device,
431
+ dtype=dtype,
432
+ )
433
+
434
+ return prompt_embeds, negative_prompt_embeds
435
+
436
+ def prepare_latents(
437
+ self,
438
+ image: torch.Tensor,
439
+ batch_size: int = 1,
440
+ num_channels_latents: int = 16,
441
+ num_frames: int = 13,
442
+ height: int = 60,
443
+ width: int = 90,
444
+ dtype: Optional[torch.dtype] = None,
445
+ device: Optional[torch.device] = None,
446
+ generator: Optional[torch.Generator] = None,
447
+ latents: Optional[torch.Tensor] = None,
448
+ kps_cond: Optional[torch.Tensor] = None,
449
+ ):
450
+ if isinstance(generator, list) and len(generator) != batch_size:
451
+ raise ValueError(
452
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
453
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
454
+ )
455
+
456
+ num_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
457
+ shape = (
458
+ batch_size,
459
+ num_frames,
460
+ num_channels_latents,
461
+ height // self.vae_scale_factor_spatial,
462
+ width // self.vae_scale_factor_spatial,
463
+ )
464
+
465
+ image = image.unsqueeze(2) # [B, C, F, H, W]
466
+
467
+ if isinstance(generator, list):
468
+ image_latents = [
469
+ retrieve_latents(self.vae.encode(image[i].unsqueeze(0)), generator[i]) for i in range(batch_size)
470
+ ]
471
+ if kps_cond is not None:
472
+ kps_cond = kps_cond.unsqueeze(2)
473
+ kps_cond_latents = [
474
+ retrieve_latents(self.vae.encode(kps_cond[i].unsqueeze(0)), generator[i])
475
+ for i in range(batch_size)
476
+ ]
477
+ else:
478
+ image_latents = [retrieve_latents(self.vae.encode(img.unsqueeze(0)), generator) for img in image]
479
+ if kps_cond is not None:
480
+ kps_cond = kps_cond.unsqueeze(2)
481
+ kps_cond_latents = [retrieve_latents(self.vae.encode(img.unsqueeze(0)), generator) for img in kps_cond]
482
+
483
+ image_latents = torch.cat(image_latents, dim=0).to(dtype).permute(0, 2, 1, 3, 4) # [B, F, C, H, W]
484
+ image_latents = self.vae_scaling_factor_image * image_latents
485
+
486
+ if kps_cond is not None:
487
+ kps_cond_latents = torch.cat(kps_cond_latents, dim=0).to(dtype).permute(0, 2, 1, 3, 4) # [B, F, C, H, W]
488
+ kps_cond_latents = self.vae_scaling_factor_image * kps_cond_latents
489
+
490
+ padding_shape = (
491
+ batch_size,
492
+ num_frames - 2,
493
+ num_channels_latents,
494
+ height // self.vae_scale_factor_spatial,
495
+ width // self.vae_scale_factor_spatial,
496
+ )
497
+ else:
498
+ padding_shape = (
499
+ batch_size,
500
+ num_frames - 1,
501
+ num_channels_latents,
502
+ height // self.vae_scale_factor_spatial,
503
+ width // self.vae_scale_factor_spatial,
504
+ )
505
+
506
+ latent_padding = torch.zeros(padding_shape, device=device, dtype=dtype)
507
+ if kps_cond is not None:
508
+ image_latents = torch.cat([image_latents, kps_cond_latents, latent_padding], dim=1)
509
+ else:
510
+ image_latents = torch.cat([image_latents, latent_padding], dim=1)
511
+
512
+ if latents is None:
513
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
514
+ else:
515
+ latents = latents.to(device)
516
+
517
+ # scale the initial noise by the standard deviation required by the scheduler
518
+ latents = latents * self.scheduler.init_noise_sigma
519
+ return latents, image_latents
520
+
521
+ # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.decode_latents
522
+ def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
523
+ latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width]
524
+ latents = 1 / self.vae_scaling_factor_image * latents
525
+
526
+ frames = self.vae.decode(latents).sample
527
+ return frames
528
+
529
+ # Copied from diffusers.pipelines.animatediff.pipeline_animatediff_video2video.AnimateDiffVideoToVideoPipeline.get_timesteps
530
+ def get_timesteps(self, num_inference_steps, timesteps, strength, device):
531
+ # get the original timestep using init_timestep
532
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
533
+
534
+ t_start = max(num_inference_steps - init_timestep, 0)
535
+ timesteps = timesteps[t_start * self.scheduler.order :]
536
+
537
+ return timesteps, num_inference_steps - t_start
538
+
539
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
540
+ def prepare_extra_step_kwargs(self, generator, eta):
541
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
542
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
543
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
544
+ # and should be between [0, 1]
545
+
546
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
547
+ extra_step_kwargs = {}
548
+ if accepts_eta:
549
+ extra_step_kwargs["eta"] = eta
550
+
551
+ # check if the scheduler accepts generator
552
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
553
+ if accepts_generator:
554
+ extra_step_kwargs["generator"] = generator
555
+ return extra_step_kwargs
556
+
557
+ def check_inputs(
558
+ self,
559
+ image,
560
+ prompt,
561
+ height,
562
+ width,
563
+ negative_prompt,
564
+ callback_on_step_end_tensor_inputs,
565
+ latents=None,
566
+ prompt_embeds=None,
567
+ negative_prompt_embeds=None,
568
+ ):
569
+ if (
570
+ not isinstance(image, torch.Tensor)
571
+ and not isinstance(image, PIL.Image.Image)
572
+ and not isinstance(image, list)
573
+ ):
574
+ raise ValueError(
575
+ "`image` has to be of type `torch.Tensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
576
+ f" {type(image)}"
577
+ )
578
+
579
+ if height % 8 != 0 or width % 8 != 0:
580
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
581
+
582
+ if callback_on_step_end_tensor_inputs is not None and not all(
583
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
584
+ ):
585
+ raise ValueError(
586
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
587
+ )
588
+ if prompt is not None and prompt_embeds is not None:
589
+ raise ValueError(
590
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
591
+ " only forward one of the two."
592
+ )
593
+ elif prompt is None and prompt_embeds is None:
594
+ raise ValueError(
595
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
596
+ )
597
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
598
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
599
+
600
+ if prompt is not None and negative_prompt_embeds is not None:
601
+ raise ValueError(
602
+ f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
603
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
604
+ )
605
+
606
+ if negative_prompt is not None and negative_prompt_embeds is not None:
607
+ raise ValueError(
608
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
609
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
610
+ )
611
+
612
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
613
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
614
+ raise ValueError(
615
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
616
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
617
+ f" {negative_prompt_embeds.shape}."
618
+ )
619
+
620
+ def _prepare_rotary_positional_embeddings(
621
+ self,
622
+ height: int,
623
+ width: int,
624
+ num_frames: int,
625
+ device: torch.device,
626
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
627
+ grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
628
+ grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
629
+ base_size_width = self.transformer.config.sample_width // self.transformer.config.patch_size
630
+ base_size_height = self.transformer.config.sample_height // self.transformer.config.patch_size
631
+
632
+ grid_crops_coords = get_resize_crop_region_for_grid(
633
+ (grid_height, grid_width), base_size_width, base_size_height
634
+ )
635
+ freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
636
+ embed_dim=self.transformer.config.attention_head_dim,
637
+ crops_coords=grid_crops_coords,
638
+ grid_size=(grid_height, grid_width),
639
+ temporal_size=num_frames,
640
+ device=device,
641
+ )
642
+
643
+ return freqs_cos, freqs_sin
644
+
645
+ @property
646
+ def guidance_scale(self):
647
+ return self._guidance_scale
648
+
649
+ @property
650
+ def num_timesteps(self):
651
+ return self._num_timesteps
652
+
653
+ @property
654
+ def attention_kwargs(self):
655
+ return self._attention_kwargs
656
+
657
+ @property
658
+ def interrupt(self):
659
+ return self._interrupt
660
+
661
+ @torch.no_grad()
662
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
663
+ def __call__(
664
+ self,
665
+ image: PipelineImageInput,
666
+ prompt: Optional[Union[str, List[str]]] = None,
667
+ negative_prompt: Optional[Union[str, List[str]]] = None,
668
+ height: int = 480,
669
+ width: int = 720,
670
+ num_frames: int = 49,
671
+ num_inference_steps: int = 50,
672
+ guidance_scale: float = 6.0,
673
+ use_dynamic_cfg: bool = False,
674
+ num_videos_per_prompt: int = 1,
675
+ eta: float = 0.0,
676
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
677
+ latents: Optional[torch.FloatTensor] = None,
678
+ prompt_embeds: Optional[torch.FloatTensor] = None,
679
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
680
+ output_type: str = "pil",
681
+ return_dict: bool = True,
682
+ attention_kwargs: Optional[Dict[str, Any]] = None,
683
+ callback_on_step_end: Optional[
684
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
685
+ ] = None,
686
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
687
+ max_sequence_length: int = 226,
688
+ id_vit_hidden: Optional[torch.Tensor] = None,
689
+ id_cond: Optional[torch.Tensor] = None,
690
+ kps_cond: Optional[torch.Tensor] = None,
691
+ ) -> Union[ConsisIDPipelineOutput, Tuple]:
692
+ """
693
+ Function invoked when calling the pipeline for generation.
694
+
695
+ Args:
696
+ image (`PipelineImageInput`):
697
+ The input image to condition the generation on. Must be an image, a list of images or a `torch.Tensor`.
698
+ prompt (`str` or `List[str]`, *optional*):
699
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
700
+ instead.
701
+ negative_prompt (`str` or `List[str]`, *optional*):
702
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
703
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
704
+ less than `1`).
705
+ height (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial):
706
+ The height in pixels of the generated image. This is set to 480 by default for the best results.
707
+ width (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial):
708
+ The width in pixels of the generated image. This is set to 720 by default for the best results.
709
+ num_frames (`int`, defaults to `49`):
710
+ Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will
711
+ contain 1 extra frame because ConsisID is conditioned with (num_seconds * fps + 1) frames where
712
+ num_seconds is 6 and fps is 4. However, since videos can be saved at any fps, the only condition that
713
+ needs to be satisfied is that of divisibility mentioned above.
714
+ num_inference_steps (`int`, *optional*, defaults to 50):
715
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
716
+ expense of slower inference.
717
+ guidance_scale (`float`, *optional*, defaults to 6):
718
+ Guidance scale as defined in [Classifier-Free Diffusion
719
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
720
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
721
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
722
+ the text `prompt`, usually at the expense of lower image quality.
723
+ use_dynamic_cfg (`bool`, *optional*, defaults to `False`):
724
+ If True, dynamically adjusts the guidance scale during inference. This allows the model to use a
725
+ progressive guidance scale, improving the balance between text-guided generation and image quality over
726
+ the course of the inference steps. Typically, early inference steps use a higher guidance scale for
727
+ more faithful image generation, while later steps reduce it for more diverse and natural results.
728
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
729
+ The number of videos to generate per prompt.
730
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
731
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
732
+ to make generation deterministic.
733
+ latents (`torch.FloatTensor`, *optional*):
734
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
735
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
736
+ tensor will be generated by sampling using the supplied random `generator`.
737
+ prompt_embeds (`torch.FloatTensor`, *optional*):
738
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
739
+ provided, text embeddings will be generated from `prompt` input argument.
740
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
741
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
742
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
743
+ argument.
744
+ output_type (`str`, *optional*, defaults to `"pil"`):
745
+ The output format of the generate image. Choose between
746
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
747
+ return_dict (`bool`, *optional*, defaults to `True`):
748
+ Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
749
+ of a plain tuple.
750
+ attention_kwargs (`dict`, *optional*):
751
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
752
+ `self.processor` in
753
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
754
+ callback_on_step_end (`Callable`, *optional*):
755
+ A function that calls at the end of each denoising steps during the inference. The function is called
756
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
757
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
758
+ `callback_on_step_end_tensor_inputs`.
759
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
760
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
761
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
762
+ `._callback_tensor_inputs` attribute of your pipeline class.
763
+ max_sequence_length (`int`, defaults to `226`):
764
+ Maximum sequence length in encoded prompt. Must be consistent with
765
+ `self.transformer.config.max_text_seq_length` otherwise may lead to poor results.
766
+ id_vit_hidden (`Optional[torch.Tensor]`, *optional*):
767
+ The tensor representing the hidden features extracted from the face model, which are used to condition
768
+ the local facial extractor. This is crucial for the model to obtain high-frequency information of the
769
+ face. If not provided, the local facial extractor will not run normally.
770
+ id_cond (`Optional[torch.Tensor]`, *optional*):
771
+ The tensor representing the hidden features extracted from the clip model, which are used to condition
772
+ the local facial extractor. This is crucial for the model to edit facial features If not provided, the
773
+ local facial extractor will not run normally.
774
+ kps_cond (`Optional[torch.Tensor]`, *optional*):
775
+ A tensor that determines whether the global facial extractor use keypoint information for conditioning.
776
+ If provided, this tensor controls whether facial keypoints such as eyes, nose, and mouth landmarks are
777
+ used during the generation process. This helps ensure the model retains more facial low-frequency
778
+ information.
779
+
780
+ Examples:
781
+
782
+ Returns:
783
+ [`~pipelines.consisid.pipeline_output.ConsisIDPipelineOutput`] or `tuple`:
784
+ [`~pipelines.consisid.pipeline_output.ConsisIDPipelineOutput`] if `return_dict` is True, otherwise a
785
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
786
+ """
787
+
788
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
789
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
790
+
791
+ height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial
792
+ width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial
793
+ num_frames = num_frames or self.transformer.config.sample_frames
794
+
795
+ num_videos_per_prompt = 1
796
+
797
+ # 1. Check inputs. Raise error if not correct
798
+ self.check_inputs(
799
+ image=image,
800
+ prompt=prompt,
801
+ height=height,
802
+ width=width,
803
+ negative_prompt=negative_prompt,
804
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
805
+ latents=latents,
806
+ prompt_embeds=prompt_embeds,
807
+ negative_prompt_embeds=negative_prompt_embeds,
808
+ )
809
+ self._guidance_scale = guidance_scale
810
+ self._attention_kwargs = attention_kwargs
811
+ self._interrupt = False
812
+
813
+ # 2. Default call parameters
814
+ if prompt is not None and isinstance(prompt, str):
815
+ batch_size = 1
816
+ elif prompt is not None and isinstance(prompt, list):
817
+ batch_size = len(prompt)
818
+ else:
819
+ batch_size = prompt_embeds.shape[0]
820
+
821
+ device = self._execution_device
822
+
823
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
824
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
825
+ # corresponds to doing no classifier free guidance.
826
+ do_classifier_free_guidance = guidance_scale > 1.0
827
+
828
+ # 3. Encode input prompt
829
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
830
+ prompt=prompt,
831
+ negative_prompt=negative_prompt,
832
+ do_classifier_free_guidance=do_classifier_free_guidance,
833
+ num_videos_per_prompt=num_videos_per_prompt,
834
+ prompt_embeds=prompt_embeds,
835
+ negative_prompt_embeds=negative_prompt_embeds,
836
+ max_sequence_length=max_sequence_length,
837
+ device=device,
838
+ )
839
+ if do_classifier_free_guidance:
840
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
841
+
842
+ # 4. Prepare timesteps
843
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device)
844
+ self._num_timesteps = len(timesteps)
845
+
846
+ # 5. Prepare latents
847
+ is_kps = getattr(self.transformer.config, "is_kps", False)
848
+ kps_cond = kps_cond if is_kps else None
849
+ if kps_cond is not None:
850
+ kps_cond = draw_kps(image, kps_cond)
851
+ kps_cond = self.video_processor.preprocess(kps_cond, height=height, width=width).to(
852
+ device, dtype=prompt_embeds.dtype
853
+ )
854
+
855
+ image = self.video_processor.preprocess(image, height=height, width=width).to(
856
+ device, dtype=prompt_embeds.dtype
857
+ )
858
+
859
+ latent_channels = self.transformer.config.in_channels // 2
860
+ latents, image_latents = self.prepare_latents(
861
+ image,
862
+ batch_size * num_videos_per_prompt,
863
+ latent_channels,
864
+ num_frames,
865
+ height,
866
+ width,
867
+ prompt_embeds.dtype,
868
+ device,
869
+ generator,
870
+ latents,
871
+ kps_cond,
872
+ )
873
+
874
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
875
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
876
+
877
+ # 7. Create rotary embeds if required
878
+ image_rotary_emb = (
879
+ self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
880
+ if self.transformer.config.use_rotary_positional_embeddings
881
+ else None
882
+ )
883
+
884
+ # 8. Denoising loop
885
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
886
+
887
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
888
+ # for DPM-solver++
889
+ old_pred_original_sample = None
890
+ timesteps_cpu = timesteps.cpu()
891
+ for i, t in enumerate(timesteps):
892
+ if self.interrupt:
893
+ continue
894
+
895
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
896
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
897
+
898
+ latent_image_input = torch.cat([image_latents] * 2) if do_classifier_free_guidance else image_latents
899
+ latent_model_input = torch.cat([latent_model_input, latent_image_input], dim=2)
900
+
901
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
902
+ timestep = t.expand(latent_model_input.shape[0])
903
+
904
+ # predict noise model_output
905
+ noise_pred = self.transformer(
906
+ hidden_states=latent_model_input,
907
+ encoder_hidden_states=prompt_embeds,
908
+ timestep=timestep,
909
+ image_rotary_emb=image_rotary_emb,
910
+ attention_kwargs=attention_kwargs,
911
+ return_dict=False,
912
+ id_vit_hidden=id_vit_hidden,
913
+ id_cond=id_cond,
914
+ )[0]
915
+ noise_pred = noise_pred.float()
916
+
917
+ # perform guidance
918
+ if use_dynamic_cfg:
919
+ self._guidance_scale = 1 + guidance_scale * (
920
+ (
921
+ 1
922
+ - math.cos(
923
+ math.pi
924
+ * ((num_inference_steps - timesteps_cpu[i].item()) / num_inference_steps) ** 5.0
925
+ )
926
+ )
927
+ / 2
928
+ )
929
+ if do_classifier_free_guidance:
930
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
931
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
932
+
933
+ # compute the previous noisy sample x_t -> x_t-1
934
+ if not isinstance(self.scheduler, CogVideoXDPMScheduler):
935
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
936
+ else:
937
+ latents, old_pred_original_sample = self.scheduler.step(
938
+ noise_pred,
939
+ old_pred_original_sample,
940
+ t,
941
+ timesteps[i - 1] if i > 0 else None,
942
+ latents,
943
+ **extra_step_kwargs,
944
+ return_dict=False,
945
+ )
946
+ latents = latents.to(prompt_embeds.dtype)
947
+
948
+ # call the callback, if provided
949
+ if callback_on_step_end is not None:
950
+ callback_kwargs = {}
951
+ for k in callback_on_step_end_tensor_inputs:
952
+ callback_kwargs[k] = locals()[k]
953
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
954
+
955
+ latents = callback_outputs.pop("latents", latents)
956
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
957
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
958
+
959
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
960
+ progress_bar.update()
961
+
962
+ if not output_type == "latent":
963
+ video = self.decode_latents(latents)
964
+ video = self.video_processor.postprocess_video(video=video, output_type=output_type)
965
+ else:
966
+ video = latents
967
+
968
+ # Offload all models
969
+ self.maybe_free_model_hooks()
970
+
971
+ if not return_dict:
972
+ return (video,)
973
+
974
+ return ConsisIDPipelineOutput(frames=video)
pythonProject/diffusers-main/build/lib/diffusers/pipelines/consisid/pipeline_output.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from dataclasses import dataclass

import torch

from diffusers.utils import BaseOutput


@dataclass
class ConsisIDPipelineOutput(BaseOutput):
    r"""
    Output class for ConsisID pipelines.

    Args:
        frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
            List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
            denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
            `(batch_size, num_frames, channels, height, width)`.
    """

    # Generated frames; the concrete type depends on the pipeline's `output_type`.
    frames: torch.Tensor
pythonProject/diffusers-main/build/lib/diffusers/pipelines/consistency_models/__init__.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Lazy-import entry point for the consistency-models pipeline subpackage."""

from typing import TYPE_CHECKING

from ...utils import (
    DIFFUSERS_SLOW_IMPORT,
    _LazyModule,
)


# Maps submodule name -> public names it exports; consumed by _LazyModule below.
_import_structure = {
    "pipeline_consistency_models": ["ConsistencyModelPipeline"],
}

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    # Eager imports: used by static type checkers, or when lazy loading is disabled.
    from .pipeline_consistency_models import ConsistencyModelPipeline

else:
    import sys

    # Replace this module with a lazy proxy so heavy submodules are only
    # imported on first attribute access.
    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )
pythonProject/diffusers-main/build/lib/diffusers/pipelines/consistency_models/pipeline_consistency_models.py ADDED
@@ -0,0 +1,286 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Callable, List, Optional, Union
16
+
17
+ import torch
18
+
19
+ from ...models import UNet2DModel
20
+ from ...schedulers import CMStochasticIterativeScheduler
21
+ from ...utils import (
22
+ is_torch_xla_available,
23
+ logging,
24
+ replace_example_docstring,
25
+ )
26
+ from ...utils.torch_utils import randn_tensor
27
+ from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
28
+
29
+
30
# Optional torch_xla support: mark XLA steps inside the denoising loop when a
# TPU/XLA backend is present; otherwise the flag disables those calls.
if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
38
+
39
+
40
+ EXAMPLE_DOC_STRING = """
41
+ Examples:
42
+ ```py
43
+ >>> import torch
44
+
45
+ >>> from diffusers import ConsistencyModelPipeline
46
+
47
+ >>> device = "cuda"
48
+ >>> # Load the cd_imagenet64_l2 checkpoint.
49
+ >>> model_id_or_path = "openai/diffusers-cd_imagenet64_l2"
50
+ >>> pipe = ConsistencyModelPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16)
51
+ >>> pipe.to(device)
52
+
53
+ >>> # Onestep Sampling
54
+ >>> image = pipe(num_inference_steps=1).images[0]
55
+ >>> image.save("cd_imagenet64_l2_onestep_sample.png")
56
+
57
+ >>> # Onestep sampling, class-conditional image generation
58
+ >>> # ImageNet-64 class label 145 corresponds to king penguins
59
+ >>> image = pipe(num_inference_steps=1, class_labels=145).images[0]
60
+ >>> image.save("cd_imagenet64_l2_onestep_sample_penguin.png")
61
+
62
+ >>> # Multistep sampling, class-conditional image generation
63
+ >>> # Timesteps can be explicitly specified; the particular timesteps below are from the original GitHub repo:
64
+ >>> # https://github.com/openai/consistency_models/blob/main/scripts/launch.sh#L77
65
+ >>> image = pipe(num_inference_steps=None, timesteps=[22, 0], class_labels=145).images[0]
66
+ >>> image.save("cd_imagenet64_l2_multistep_sample_penguin.png")
67
+ ```
68
+ """
69
+
70
+
71
class ConsistencyModelPipeline(DiffusionPipeline):
    r"""
    Pipeline for unconditional or class-conditional image generation.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
    implemented for all pipelines (downloading, saving, running on a particular device, etc.).

    Args:
        unet ([`UNet2DModel`]):
            A `UNet2DModel` to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Currently only
            compatible with [`CMStochasticIterativeScheduler`].
    """

    # Only the UNet participates in model CPU offload; this pipeline has no VAE/text encoder.
    model_cpu_offload_seq = "unet"

    def __init__(self, unet: UNet2DModel, scheduler: CMStochasticIterativeScheduler) -> None:
        super().__init__()

        self.register_modules(
            unet=unet,
            scheduler=scheduler,
        )

        # No safety checker for consistency models; attribute kept for API parity
        # with other diffusers pipelines that expect it to exist.
        self.safety_checker = None

    def prepare_latents(self, batch_size, num_channels, height, width, dtype, device, generator, latents=None):
        """Sample (or move) initial latents and scale them by the scheduler's init noise sigma."""
        shape = (batch_size, num_channels, height, width)
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            # Caller-supplied latents: only move to the target device/dtype.
            latents = latents.to(device=device, dtype=dtype)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma
        return latents

    # Follows diffusers.VaeImageProcessor.postprocess
    def postprocess_image(self, sample: torch.Tensor, output_type: str = "pil"):
        """Denormalize a sample from [-1, 1] to [0, 1] and convert to `pt`, `np`, or `pil`."""
        if output_type not in ["pt", "np", "pil"]:
            raise ValueError(
                f"output_type={output_type} is not supported. Make sure to choose one of ['pt', 'np', or 'pil']"
            )

        # Equivalent to diffusers.VaeImageProcessor.denormalize
        sample = (sample / 2 + 0.5).clamp(0, 1)
        if output_type == "pt":
            return sample

        # Equivalent to diffusers.VaeImageProcessor.pt_to_numpy
        sample = sample.cpu().permute(0, 2, 3, 1).numpy()
        if output_type == "np":
            return sample

        # Output_type must be 'pil'
        sample = self.numpy_to_pil(sample)
        return sample

    def prepare_class_labels(self, batch_size, device, class_labels=None):
        """Normalize `class_labels` to a device tensor, or `None` if the UNet is not class-conditional."""
        if self.unet.config.num_class_embeds is not None:
            if isinstance(class_labels, list):
                class_labels = torch.tensor(class_labels, dtype=torch.int)
            elif isinstance(class_labels, int):
                assert batch_size == 1, "Batch size must be 1 if classes is an int"
                class_labels = torch.tensor([class_labels], dtype=torch.int)
            elif class_labels is None:
                # Randomly generate batch_size class labels
                # TODO: should use generator here? int analogue of randn_tensor is not exposed in ...utils
                class_labels = torch.randint(0, self.unet.config.num_class_embeds, size=(batch_size,))
            class_labels = class_labels.to(device)
        else:
            # Unconditional model: any provided labels are deliberately discarded.
            class_labels = None
        return class_labels

    def check_inputs(self, num_inference_steps, timesteps, latents, batch_size, img_size, callback_steps):
        """Validate call arguments; raises `ValueError` on invalid combinations."""
        if num_inference_steps is None and timesteps is None:
            raise ValueError("Exactly one of `num_inference_steps` or `timesteps` must be supplied.")

        if num_inference_steps is not None and timesteps is not None:
            # Both given is allowed but ambiguous; warn and let `timesteps` win (see __call__).
            logger.warning(
                f"Both `num_inference_steps`: {num_inference_steps} and `timesteps`: {timesteps} are supplied;"
                " `timesteps` will be used over `num_inference_steps`."
            )

        if latents is not None:
            # Latents must match the square RGB sample the UNet expects.
            expected_shape = (batch_size, 3, img_size, img_size)
            if latents.shape != expected_shape:
                raise ValueError(f"The shape of latents is {latents.shape} but is expected to be {expected_shape}.")

        if (callback_steps is None) or (
            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
        ):
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        batch_size: int = 1,
        class_labels: Optional[Union[torch.Tensor, List[int], int]] = None,
        num_inference_steps: int = 1,
        timesteps: Optional[List[int]] = None,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
        callback_steps: int = 1,
    ):
        r"""
        Args:
            batch_size (`int`, *optional*, defaults to 1):
                The number of images to generate.
            class_labels (`torch.Tensor` or `List[int]` or `int`, *optional*):
                Optional class labels for conditioning class-conditional consistency models. Not used if the model is
                not class-conditional.
            num_inference_steps (`int`, *optional*, defaults to 1):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            timesteps (`List[int]`, *optional*):
                Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
                timesteps are used. Must be in descending order.
            generator (`torch.Generator`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
            latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor is generated by sampling using the supplied random `generator`.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between `PIL.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
            callback (`Callable`, *optional*):
                A function that calls every `callback_steps` steps during inference. The function is called with the
                following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function is called. If not specified, the callback is called at
                every step.

        Examples:

        Returns:
            [`~pipelines.ImagePipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
                returned where the first element is a list with the generated images.
        """
        # 0. Prepare call parameters
        img_size = self.unet.config.sample_size
        device = self._execution_device

        # 1. Check inputs
        self.check_inputs(num_inference_steps, timesteps, latents, batch_size, img_size, callback_steps)

        # 2. Prepare image latents
        # Sample image latents x_0 ~ N(0, sigma_0^2 * I)
        sample = self.prepare_latents(
            batch_size=batch_size,
            num_channels=self.unet.config.in_channels,
            height=img_size,
            width=img_size,
            dtype=self.unet.dtype,
            device=device,
            generator=generator,
            latents=latents,
        )

        # 3. Handle class_labels for class-conditional models
        class_labels = self.prepare_class_labels(batch_size, device, class_labels=class_labels)

        # 4. Prepare timesteps
        if timesteps is not None:
            # Explicit schedule overrides num_inference_steps (warned about in check_inputs).
            self.scheduler.set_timesteps(timesteps=timesteps, device=device)
            timesteps = self.scheduler.timesteps
            num_inference_steps = len(timesteps)
        else:
            self.scheduler.set_timesteps(num_inference_steps)
            timesteps = self.scheduler.timesteps

        # 5. Denoising loop
        # Multistep sampling: implements Algorithm 1 in the paper
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                scaled_sample = self.scheduler.scale_model_input(sample, t)
                model_output = self.unet(scaled_sample, t, class_labels=class_labels, return_dict=False)[0]

                sample = self.scheduler.step(model_output, t, sample, generator=generator)[0]

                # call the callback, if provided
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    callback(i, t, sample)

                if XLA_AVAILABLE:
                    xm.mark_step()

        # 6. Post-process image sample
        image = self.postprocess_image(sample, output_type=output_type)

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image)
pythonProject/diffusers-main/build/lib/diffusers/pipelines/controlnet/__init__.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Lazy-import entry point for the ControlNet pipeline subpackage.

Torch/transformers and flax/transformers pipelines are registered only when their
optional dependencies are importable; otherwise dummy placeholder objects are exposed.
"""

from typing import TYPE_CHECKING

from ...utils import (
    DIFFUSERS_SLOW_IMPORT,
    OptionalDependencyNotAvailable,
    _LazyModule,
    get_objects_from_module,
    is_flax_available,
    is_torch_available,
    is_transformers_available,
)


_dummy_objects = {}
_import_structure = {}

# Torch + transformers pipelines (dummies substituted when unavailable).
try:
    if not (is_transformers_available() and is_torch_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ...utils import dummy_torch_and_transformers_objects  # noqa F403

    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
    _import_structure["multicontrolnet"] = ["MultiControlNetModel"]
    _import_structure["pipeline_controlnet"] = ["StableDiffusionControlNetPipeline"]
    _import_structure["pipeline_controlnet_blip_diffusion"] = ["BlipDiffusionControlNetPipeline"]
    _import_structure["pipeline_controlnet_img2img"] = ["StableDiffusionControlNetImg2ImgPipeline"]
    _import_structure["pipeline_controlnet_inpaint"] = ["StableDiffusionControlNetInpaintPipeline"]
    _import_structure["pipeline_controlnet_inpaint_sd_xl"] = ["StableDiffusionXLControlNetInpaintPipeline"]
    _import_structure["pipeline_controlnet_sd_xl"] = ["StableDiffusionXLControlNetPipeline"]
    _import_structure["pipeline_controlnet_sd_xl_img2img"] = ["StableDiffusionXLControlNetImg2ImgPipeline"]
    _import_structure["pipeline_controlnet_union_inpaint_sd_xl"] = ["StableDiffusionXLControlNetUnionInpaintPipeline"]
    _import_structure["pipeline_controlnet_union_sd_xl"] = ["StableDiffusionXLControlNetUnionPipeline"]
    _import_structure["pipeline_controlnet_union_sd_xl_img2img"] = ["StableDiffusionXLControlNetUnionImg2ImgPipeline"]
# Flax + transformers pipeline (dummies substituted when unavailable).
try:
    if not (is_transformers_available() and is_flax_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ...utils import dummy_flax_and_transformers_objects  # noqa F403

    _dummy_objects.update(get_objects_from_module(dummy_flax_and_transformers_objects))
else:
    _import_structure["pipeline_flax_controlnet"] = ["FlaxStableDiffusionControlNetPipeline"]


if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    # Eager imports: used by static type checkers, or when lazy loading is disabled.
    try:
        if not (is_transformers_available() and is_torch_available()):
            raise OptionalDependencyNotAvailable()

    except OptionalDependencyNotAvailable:
        from ...utils.dummy_torch_and_transformers_objects import *
    else:
        from .multicontrolnet import MultiControlNetModel
        from .pipeline_controlnet import StableDiffusionControlNetPipeline
        from .pipeline_controlnet_blip_diffusion import BlipDiffusionControlNetPipeline
        from .pipeline_controlnet_img2img import StableDiffusionControlNetImg2ImgPipeline
        from .pipeline_controlnet_inpaint import StableDiffusionControlNetInpaintPipeline
        from .pipeline_controlnet_inpaint_sd_xl import StableDiffusionXLControlNetInpaintPipeline
        from .pipeline_controlnet_sd_xl import StableDiffusionXLControlNetPipeline
        from .pipeline_controlnet_sd_xl_img2img import StableDiffusionXLControlNetImg2ImgPipeline
        from .pipeline_controlnet_union_inpaint_sd_xl import StableDiffusionXLControlNetUnionInpaintPipeline
        from .pipeline_controlnet_union_sd_xl import StableDiffusionXLControlNetUnionPipeline
        from .pipeline_controlnet_union_sd_xl_img2img import StableDiffusionXLControlNetUnionImg2ImgPipeline

    try:
        if not (is_transformers_available() and is_flax_available()):
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        from ...utils.dummy_flax_and_transformers_objects import *  # noqa F403
    else:
        from .pipeline_flax_controlnet import FlaxStableDiffusionControlNetPipeline


else:
    import sys

    # Replace this module with a lazy proxy; dummy objects for missing optional
    # dependencies are attached afterwards so attribute access never fails.
    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )
    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)
pythonProject/diffusers-main/build/lib/diffusers/pipelines/controlnet/multicontrolnet.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ...models.controlnets.multicontrolnet import MultiControlNetModel
2
+ from ...utils import deprecate, logging
3
+
4
+
5
+ logger = logging.get_logger(__name__)
6
+
7
+
8
# Backward-compatibility shim: intentionally shadows the imported
# `MultiControlNetModel` from ...models.controlnets.multicontrolnet so that the
# old import path keeps working while emitting a deprecation warning.
class MultiControlNetModel(MultiControlNetModel):
    def __init__(self, *args, **kwargs):
        deprecation_message = "Importing `MultiControlNetModel` from `diffusers.pipelines.controlnet.multicontrolnet` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.multicontrolnet import MultiControlNetModel`, instead."
        # Scheduled for removal in diffusers 0.34 (per the deprecate() call).
        deprecate("diffusers.pipelines.controlnet.multicontrolnet.MultiControlNetModel", "0.34", deprecation_message)
        super().__init__(*args, **kwargs)
pythonProject/diffusers-main/build/lib/diffusers/pipelines/controlnet/pipeline_controlnet.py ADDED
@@ -0,0 +1,1366 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ import inspect
17
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
18
+
19
+ import numpy as np
20
+ import PIL.Image
21
+ import torch
22
+ import torch.nn.functional as F
23
+ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
24
+
25
+ from ...callbacks import MultiPipelineCallbacks, PipelineCallback
26
+ from ...image_processor import PipelineImageInput, VaeImageProcessor
27
+ from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
28
+ from ...models import AutoencoderKL, ControlNetModel, ImageProjection, MultiControlNetModel, UNet2DConditionModel
29
+ from ...models.lora import adjust_lora_scale_text_encoder
30
+ from ...schedulers import KarrasDiffusionSchedulers
31
+ from ...utils import (
32
+ USE_PEFT_BACKEND,
33
+ deprecate,
34
+ is_torch_xla_available,
35
+ logging,
36
+ replace_example_docstring,
37
+ scale_lora_layers,
38
+ unscale_lora_layers,
39
+ )
40
+ from ...utils.torch_utils import empty_device_cache, is_compiled_module, is_torch_version, randn_tensor
41
+ from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
42
+ from ..stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
43
+ from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
44
+
45
+
46
# Optional torch_xla support: enables xm.mark_step() inside the denoising loop
# when running on a TPU/XLA backend.
if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
54
+
55
+
56
+ EXAMPLE_DOC_STRING = """
57
+ Examples:
58
+ ```py
59
+ >>> # !pip install opencv-python transformers accelerate
60
+ >>> from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler
61
+ >>> from diffusers.utils import load_image
62
+ >>> import numpy as np
63
+ >>> import torch
64
+
65
+ >>> import cv2
66
+ >>> from PIL import Image
67
+
68
+ >>> # download an image
69
+ >>> image = load_image(
70
+ ... "https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png"
71
+ ... )
72
+ >>> image = np.array(image)
73
+
74
+ >>> # get canny image
75
+ >>> image = cv2.Canny(image, 100, 200)
76
+ >>> image = image[:, :, None]
77
+ >>> image = np.concatenate([image, image, image], axis=2)
78
+ >>> canny_image = Image.fromarray(image)
79
+
80
+ >>> # load control net and stable diffusion v1-5
81
+ >>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
82
+ >>> pipe = StableDiffusionControlNetPipeline.from_pretrained(
83
+ ... "stable-diffusion-v1-5/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
84
+ ... )
85
+
86
+ >>> # speed up diffusion process with faster scheduler and memory optimization
87
+ >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
88
+ >>> # remove following line if xformers is not installed
89
+ >>> pipe.enable_xformers_memory_efficient_attention()
90
+
91
+ >>> pipe.enable_model_cpu_offload()
92
+
93
+ >>> # generate image
94
+ >>> generator = torch.manual_seed(0)
95
+ >>> image = pipe(
96
+ ... "futuristic-looking woman", num_inference_steps=20, generator=generator, image=canny_image
97
+ ... ).images[0]
98
+ ```
99
+ """
100
+
101
+
102
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
103
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    r"""
    Configure `scheduler` and return its resulting timestep schedule.

    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
    # A custom schedule may be given either as timesteps or as sigmas, never both.
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    if timesteps is not None:
        # Custom timestep schedule: only usable if the scheduler accepts a `timesteps` kwarg.
        if "timesteps" not in inspect.signature(scheduler.set_timesteps).parameters:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        num_inference_steps = len(scheduler.timesteps)
    elif sigmas is not None:
        # Custom sigma schedule: only usable if the scheduler accepts a `sigmas` kwarg.
        if "sigmas" not in inspect.signature(scheduler.set_timesteps).parameters:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        num_inference_steps = len(scheduler.timesteps)
    else:
        # Default path: let the scheduler space `num_inference_steps` timesteps itself.
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)

    return scheduler.timesteps, num_inference_steps
160
+
161
+
162
+ class StableDiffusionControlNetPipeline(
163
+ DiffusionPipeline,
164
+ StableDiffusionMixin,
165
+ TextualInversionLoaderMixin,
166
+ StableDiffusionLoraLoaderMixin,
167
+ IPAdapterMixin,
168
+ FromSingleFileMixin,
169
+ ):
170
+ r"""
171
+ Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance.
172
+
173
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
174
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
175
+
176
+ The pipeline also inherits the following loading methods:
177
+ - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
178
+ - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
179
+ - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
180
+ - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
181
+ - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
182
+
183
+ Args:
184
+ vae ([`AutoencoderKL`]):
185
+ Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
186
+ text_encoder ([`~transformers.CLIPTextModel`]):
187
+ Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
188
+ tokenizer ([`~transformers.CLIPTokenizer`]):
189
+ A `CLIPTokenizer` to tokenize text.
190
+ unet ([`UNet2DConditionModel`]):
191
+ A `UNet2DConditionModel` to denoise the encoded image latents.
192
+ controlnet ([`ControlNetModel`] or `List[ControlNetModel]`):
193
+ Provides additional conditioning to the `unet` during the denoising process. If you set multiple
194
+ ControlNets as a list, the outputs from each ControlNet are added together to create one combined
195
+ additional conditioning.
196
+ scheduler ([`SchedulerMixin`]):
197
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
198
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
199
+ safety_checker ([`StableDiffusionSafetyChecker`]):
200
+ Classification module that estimates whether generated images could be considered offensive or harmful.
201
+ Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for
202
+ more details about a model's potential harms.
203
+ feature_extractor ([`~transformers.CLIPImageProcessor`]):
204
+ A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
205
+ """
206
+
207
    # Order used by model CPU offload (consumed by the DiffusionPipeline base class).
    model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
    # Components the pipeline may be constructed without.
    _optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
    # The safety checker is excluded from CPU offloading.
    _exclude_from_cpu_offload = ["safety_checker"]
    # Tensor names that step-end callbacks are allowed to read/override.
    _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "image"]
211
+
212
+ def __init__(
213
+ self,
214
+ vae: AutoencoderKL,
215
+ text_encoder: CLIPTextModel,
216
+ tokenizer: CLIPTokenizer,
217
+ unet: UNet2DConditionModel,
218
+ controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
219
+ scheduler: KarrasDiffusionSchedulers,
220
+ safety_checker: StableDiffusionSafetyChecker,
221
+ feature_extractor: CLIPImageProcessor,
222
+ image_encoder: CLIPVisionModelWithProjection = None,
223
+ requires_safety_checker: bool = True,
224
+ ):
225
+ super().__init__()
226
+
227
+ if safety_checker is None and requires_safety_checker:
228
+ logger.warning(
229
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
230
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
231
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
232
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
233
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
234
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
235
+ )
236
+
237
+ if safety_checker is not None and feature_extractor is None:
238
+ raise ValueError(
239
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
240
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
241
+ )
242
+
243
+ if isinstance(controlnet, (list, tuple)):
244
+ controlnet = MultiControlNetModel(controlnet)
245
+
246
+ self.register_modules(
247
+ vae=vae,
248
+ text_encoder=text_encoder,
249
+ tokenizer=tokenizer,
250
+ unet=unet,
251
+ controlnet=controlnet,
252
+ scheduler=scheduler,
253
+ safety_checker=safety_checker,
254
+ feature_extractor=feature_extractor,
255
+ image_encoder=image_encoder,
256
+ )
257
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
258
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
259
+ self.control_image_processor = VaeImageProcessor(
260
+ vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
261
+ )
262
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
263
+
264
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
265
+ def _encode_prompt(
266
+ self,
267
+ prompt,
268
+ device,
269
+ num_images_per_prompt,
270
+ do_classifier_free_guidance,
271
+ negative_prompt=None,
272
+ prompt_embeds: Optional[torch.Tensor] = None,
273
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
274
+ lora_scale: Optional[float] = None,
275
+ **kwargs,
276
+ ):
277
+ deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
278
+ deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
279
+
280
+ prompt_embeds_tuple = self.encode_prompt(
281
+ prompt=prompt,
282
+ device=device,
283
+ num_images_per_prompt=num_images_per_prompt,
284
+ do_classifier_free_guidance=do_classifier_free_guidance,
285
+ negative_prompt=negative_prompt,
286
+ prompt_embeds=prompt_embeds,
287
+ negative_prompt_embeds=negative_prompt_embeds,
288
+ lora_scale=lora_scale,
289
+ **kwargs,
290
+ )
291
+
292
+ # concatenate for backwards comp
293
+ prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
294
+
295
+ return prompt_embeds
296
+
297
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
298
    def encode_prompt(
        self,
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt=None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        lora_scale: Optional[float] = None,
        clip_skip: Optional[int] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            device: (`torch.device`):
                torch device
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            lora_scale (`float`, *optional*):
                A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
            clip_skip (`int`, *optional*):
                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                the output of the pre-final layer will be used for computing the prompt embeddings.

        Returns:
            A `(prompt_embeds, negative_prompt_embeds)` tuple; `negative_prompt_embeds`
            is left untouched (possibly `None`) when classifier-free guidance is off.
        """
        # set lora scale so that monkey patched LoRA
        # function of text encoder can correctly access it
        if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):
            self._lora_scale = lora_scale

            # dynamically adjust the LoRA scale
            if not USE_PEFT_BACKEND:
                adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
            else:
                scale_lora_layers(self.text_encoder, lora_scale)

        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            # No prompt supplied: infer the batch size from the precomputed embeddings.
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

            # Warn the user when the prompt had to be truncated to the model's max length.
            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
                removed_text = self.tokenizer.batch_decode(
                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
                )
                logger.warning(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
                )

            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = text_inputs.attention_mask.to(device)
            else:
                attention_mask = None

            if clip_skip is None:
                prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
                prompt_embeds = prompt_embeds[0]
            else:
                prompt_embeds = self.text_encoder(
                    text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
                )
                # Access the `hidden_states` first, that contains a tuple of
                # all the hidden states from the encoder layers. Then index into
                # the tuple to access the hidden states from the desired layer.
                prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
                # We also need to apply the final LayerNorm here to not mess with the
                # representations. The `last_hidden_states` that we typically use for
                # obtaining the final prompt representations passes through the LayerNorm
                # layer.
                prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)

        if self.text_encoder is not None:
            prompt_embeds_dtype = self.text_encoder.dtype
        elif self.unet is not None:
            prompt_embeds_dtype = self.unet.dtype
        else:
            prompt_embeds_dtype = prompt_embeds.dtype

        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

        bs_embed, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

            # Pad/truncate the negative prompt to the same sequence length as the positive one.
            max_length = prompt_embeds.shape[1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )

            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = uncond_input.attention_mask.to(device)
            else:
                attention_mask = None

            negative_prompt_embeds = self.text_encoder(
                uncond_input.input_ids.to(device),
                attention_mask=attention_mask,
            )
            negative_prompt_embeds = negative_prompt_embeds[0]

        if do_classifier_free_guidance:
            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = negative_prompt_embeds.shape[1]

            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

        if self.text_encoder is not None:
            if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
                # Retrieve the original scale by scaling back the LoRA layers
                unscale_lora_layers(self.text_encoder, lora_scale)

        return prompt_embeds, negative_prompt_embeds
479
+
480
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
481
+ def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
482
+ dtype = next(self.image_encoder.parameters()).dtype
483
+
484
+ if not isinstance(image, torch.Tensor):
485
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
486
+
487
+ image = image.to(device=device, dtype=dtype)
488
+ if output_hidden_states:
489
+ image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
490
+ image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
491
+ uncond_image_enc_hidden_states = self.image_encoder(
492
+ torch.zeros_like(image), output_hidden_states=True
493
+ ).hidden_states[-2]
494
+ uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
495
+ num_images_per_prompt, dim=0
496
+ )
497
+ return image_enc_hidden_states, uncond_image_enc_hidden_states
498
+ else:
499
+ image_embeds = self.image_encoder(image).image_embeds
500
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
501
+ uncond_image_embeds = torch.zeros_like(image_embeds)
502
+
503
+ return image_embeds, uncond_image_embeds
504
+
505
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
506
    def prepare_ip_adapter_image_embeds(
        self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
    ):
        """Build per-adapter IP-Adapter embeddings, tiled for the batch.

        Either encodes `ip_adapter_image` (one entry per loaded IP Adapter) or
        reuses precomputed `ip_adapter_image_embeds`; precomputed embeds are
        expected to already contain the negative/positive halves stacked along
        dim 0 when classifier-free guidance is enabled.
        """
        image_embeds = []
        if do_classifier_free_guidance:
            negative_image_embeds = []
        if ip_adapter_image_embeds is None:
            if not isinstance(ip_adapter_image, list):
                ip_adapter_image = [ip_adapter_image]

            if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
                raise ValueError(
                    f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
                )

            for single_ip_adapter_image, image_proj_layer in zip(
                ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
            ):
                # Plain `ImageProjection` layers consume projected embeds; other
                # projection layers consume hidden states instead.
                output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
                single_image_embeds, single_negative_image_embeds = self.encode_image(
                    single_ip_adapter_image, device, 1, output_hidden_state
                )

                # Add a leading per-adapter axis so batching below is uniform.
                image_embeds.append(single_image_embeds[None, :])
                if do_classifier_free_guidance:
                    negative_image_embeds.append(single_negative_image_embeds[None, :])
        else:
            for single_image_embeds in ip_adapter_image_embeds:
                if do_classifier_free_guidance:
                    # Precomputed embeds carry [negative, positive] stacked on dim 0.
                    single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
                    negative_image_embeds.append(single_negative_image_embeds)
                image_embeds.append(single_image_embeds)

        ip_adapter_image_embeds = []
        for i, single_image_embeds in enumerate(image_embeds):
            # Tile for num_images_per_prompt, then prepend the negative half for CFG.
            single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
            if do_classifier_free_guidance:
                single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
                single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)

            single_image_embeds = single_image_embeds.to(device=device)
            ip_adapter_image_embeds.append(single_image_embeds)

        return ip_adapter_image_embeds
550
+
551
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
552
+ def run_safety_checker(self, image, device, dtype):
553
+ if self.safety_checker is None:
554
+ has_nsfw_concept = None
555
+ else:
556
+ if torch.is_tensor(image):
557
+ feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
558
+ else:
559
+ feature_extractor_input = self.image_processor.numpy_to_pil(image)
560
+ safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
561
+ image, has_nsfw_concept = self.safety_checker(
562
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
563
+ )
564
+ return image, has_nsfw_concept
565
+
566
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
567
+ def decode_latents(self, latents):
568
+ deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
569
+ deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
570
+
571
+ latents = 1 / self.vae.config.scaling_factor * latents
572
+ image = self.vae.decode(latents, return_dict=False)[0]
573
+ image = (image / 2 + 0.5).clamp(0, 1)
574
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
575
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
576
+ return image
577
+
578
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
579
+ def prepare_extra_step_kwargs(self, generator, eta):
580
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
581
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
582
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
583
+ # and should be between [0, 1]
584
+
585
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
586
+ extra_step_kwargs = {}
587
+ if accepts_eta:
588
+ extra_step_kwargs["eta"] = eta
589
+
590
+ # check if the scheduler accepts generator
591
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
592
+ if accepts_generator:
593
+ extra_step_kwargs["generator"] = generator
594
+ return extra_step_kwargs
595
+
596
    def check_inputs(
        self,
        prompt,
        image,
        callback_steps,
        negative_prompt=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
        ip_adapter_image=None,
        ip_adapter_image_embeds=None,
        controlnet_conditioning_scale=1.0,
        control_guidance_start=0.0,
        control_guidance_end=1.0,
        callback_on_step_end_tensor_inputs=None,
    ):
        """Validate `__call__` arguments before running the expensive pipeline.

        Checks prompt/embedding exclusivity, the ControlNet conditioning image(s),
        the conditioning scale type, the control guidance interval(s), and the
        IP-Adapter inputs. Raises `ValueError`/`TypeError` on the first violation.
        """
        if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

        if callback_on_step_end_tensor_inputs is not None and not all(
            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
        ):
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
            )

        # Exactly one of `prompt` / `prompt_embeds` must be given.
        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if prompt_embeds.shape != negative_prompt_embeds.shape:
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}."
                )

        # Check `image`
        # A torch.compile'd ControlNet is wrapped in OptimizedModule; look through
        # `_orig_mod` to find the real model class in that case.
        is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
            self.controlnet, torch._dynamo.eval_frame.OptimizedModule
        )
        if (
            isinstance(self.controlnet, ControlNetModel)
            or is_compiled
            and isinstance(self.controlnet._orig_mod, ControlNetModel)
        ):
            self.check_image(image, prompt, prompt_embeds)
        elif (
            isinstance(self.controlnet, MultiControlNetModel)
            or is_compiled
            and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
        ):
            if not isinstance(image, list):
                raise TypeError("For multiple controlnets: `image` must be type `list`")

            # When `image` is a nested list:
            # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])
            elif any(isinstance(i, list) for i in image):
                # Transpose from per-prompt lists to per-controlnet lists before validating.
                transposed_image = [list(t) for t in zip(*image)]
                if len(transposed_image) != len(self.controlnet.nets):
                    raise ValueError(
                        f"For multiple controlnets: if you pass`image` as a list of list, each sublist must have the same length as the number of controlnets, but the sublists in `image` got {len(transposed_image)} images and {len(self.controlnet.nets)} ControlNets."
                    )
                for image_ in transposed_image:
                    self.check_image(image_, prompt, prompt_embeds)
            elif len(image) != len(self.controlnet.nets):
                raise ValueError(
                    f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets."
                )
            else:
                for image_ in image:
                    self.check_image(image_, prompt, prompt_embeds)
        else:
            # NOTE(review): unreachable by construction; also stripped under `python -O`.
            assert False

        # Check `controlnet_conditioning_scale`
        if (
            isinstance(self.controlnet, ControlNetModel)
            or is_compiled
            and isinstance(self.controlnet._orig_mod, ControlNetModel)
        ):
            if not isinstance(controlnet_conditioning_scale, float):
                raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
        elif (
            isinstance(self.controlnet, MultiControlNetModel)
            or is_compiled
            and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
        ):
            if isinstance(controlnet_conditioning_scale, list):
                if any(isinstance(i, list) for i in controlnet_conditioning_scale):
                    raise ValueError(
                        "A single batch of varying conditioning scale settings (e.g. [[1.0, 0.5], [0.2, 0.8]]) is not supported at the moment. "
                        "The conditioning scale must be fixed across the batch."
                    )
            # NOTE(review): this branch is dead code — the `if` above already matched
            # every list, so the length-vs-nets check never runs. Confirm upstream intent.
            elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(
                self.controlnet.nets
            ):
                raise ValueError(
                    "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
                    " the same length as the number of controlnets"
                )
        else:
            # NOTE(review): unreachable by construction; also stripped under `python -O`.
            assert False

        # Normalize guidance windows to lists so single floats and lists validate uniformly.
        if not isinstance(control_guidance_start, (tuple, list)):
            control_guidance_start = [control_guidance_start]

        if not isinstance(control_guidance_end, (tuple, list)):
            control_guidance_end = [control_guidance_end]

        if len(control_guidance_start) != len(control_guidance_end):
            raise ValueError(
                f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
            )

        if isinstance(self.controlnet, MultiControlNetModel):
            if len(control_guidance_start) != len(self.controlnet.nets):
                raise ValueError(
                    f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}."
                )

        # Each guidance window must be a valid sub-interval of [0, 1].
        for start, end in zip(control_guidance_start, control_guidance_end):
            if start >= end:
                raise ValueError(
                    f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}."
                )
            if start < 0.0:
                raise ValueError(f"control guidance start: {start} can't be smaller than 0.")
            if end > 1.0:
                raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")

        if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
            raise ValueError(
                "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
            )

        if ip_adapter_image_embeds is not None:
            if not isinstance(ip_adapter_image_embeds, list):
                raise ValueError(
                    f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
                )
            elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
                raise ValueError(
                    f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
                )
758
+
759
+ def check_image(self, image, prompt, prompt_embeds):
760
+ image_is_pil = isinstance(image, PIL.Image.Image)
761
+ image_is_tensor = isinstance(image, torch.Tensor)
762
+ image_is_np = isinstance(image, np.ndarray)
763
+ image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
764
+ image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)
765
+ image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray)
766
+
767
+ if (
768
+ not image_is_pil
769
+ and not image_is_tensor
770
+ and not image_is_np
771
+ and not image_is_pil_list
772
+ and not image_is_tensor_list
773
+ and not image_is_np_list
774
+ ):
775
+ raise TypeError(
776
+ f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}"
777
+ )
778
+
779
+ if image_is_pil:
780
+ image_batch_size = 1
781
+ else:
782
+ image_batch_size = len(image)
783
+
784
+ if prompt is not None and isinstance(prompt, str):
785
+ prompt_batch_size = 1
786
+ elif prompt is not None and isinstance(prompt, list):
787
+ prompt_batch_size = len(prompt)
788
+ elif prompt_embeds is not None:
789
+ prompt_batch_size = prompt_embeds.shape[0]
790
+
791
+ if image_batch_size != 1 and image_batch_size != prompt_batch_size:
792
+ raise ValueError(
793
+ f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
794
+ )
795
+
796
+ def prepare_image(
797
+ self,
798
+ image,
799
+ width,
800
+ height,
801
+ batch_size,
802
+ num_images_per_prompt,
803
+ device,
804
+ dtype,
805
+ do_classifier_free_guidance=False,
806
+ guess_mode=False,
807
+ ):
808
+ image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
809
+ image_batch_size = image.shape[0]
810
+
811
+ if image_batch_size == 1:
812
+ repeat_by = batch_size
813
+ else:
814
+ # image batch size is the same as prompt batch size
815
+ repeat_by = num_images_per_prompt
816
+
817
+ image = image.repeat_interleave(repeat_by, dim=0)
818
+
819
+ image = image.to(device=device, dtype=dtype)
820
+
821
+ if do_classifier_free_guidance and not guess_mode:
822
+ image = torch.cat([image] * 2)
823
+
824
+ return image
825
+
826
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
827
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
828
+ shape = (
829
+ batch_size,
830
+ num_channels_latents,
831
+ int(height) // self.vae_scale_factor,
832
+ int(width) // self.vae_scale_factor,
833
+ )
834
+ if isinstance(generator, list) and len(generator) != batch_size:
835
+ raise ValueError(
836
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
837
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
838
+ )
839
+
840
+ if latents is None:
841
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
842
+ else:
843
+ latents = latents.to(device)
844
+
845
+ # scale the initial noise by the standard deviation required by the scheduler
846
+ latents = latents * self.scheduler.init_noise_sigma
847
+ return latents
848
+
849
+ # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
850
+ def get_guidance_scale_embedding(
851
+ self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32
852
+ ) -> torch.Tensor:
853
+ """
854
+ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
855
+
856
+ Args:
857
+ w (`torch.Tensor`):
858
+ Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings.
859
+ embedding_dim (`int`, *optional*, defaults to 512):
860
+ Dimension of the embeddings to generate.
861
+ dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
862
+ Data type of the generated embeddings.
863
+
864
+ Returns:
865
+ `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`.
866
+ """
867
+ assert len(w.shape) == 1
868
+ w = w * 1000.0
869
+
870
+ half_dim = embedding_dim // 2
871
+ emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
872
+ emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
873
+ emb = w.to(dtype)[:, None] * emb[None, :]
874
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
875
+ if embedding_dim % 2 == 1: # zero pad
876
+ emb = torch.nn.functional.pad(emb, (0, 1))
877
+ assert emb.shape == (w.shape[0], embedding_dim)
878
+ return emb
879
+
880
    @property
    def guidance_scale(self):
        # Read-only view of the stored CFG scale (`_guidance_scale`).
        return self._guidance_scale
883
+
884
    @property
    def clip_skip(self):
        # Read-only view of the stored `clip_skip` setting (`_clip_skip`).
        return self._clip_skip
887
+
888
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
889
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
890
+ # corresponds to doing no classifier free guidance.
891
    @property
    def do_classifier_free_guidance(self):
        # CFG is active only when the scale exceeds 1 and the UNet has no
        # `time_cond_proj_dim` (presumably such UNets consume the scale as a
        # timestep condition instead — see `get_guidance_scale_embedding`).
        return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
894
+
895
    @property
    def cross_attention_kwargs(self):
        # Read-only view of the stored cross-attention kwargs (`_cross_attention_kwargs`).
        return self._cross_attention_kwargs
898
+
899
    @property
    def num_timesteps(self):
        # Read-only view of the stored number of denoising timesteps (`_num_timesteps`).
        return self._num_timesteps
902
+
903
    @property
    def interrupt(self):
        # Read-only view of the interruption flag (`_interrupt`).
        return self._interrupt
906
+
907
    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        image: PipelineImageInput = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        timesteps: List[int] = None,
        sigmas: List[float] = None,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        ip_adapter_image: Optional[PipelineImageInput] = None,
        ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
        guess_mode: bool = False,
        control_guidance_start: Union[float, List[float]] = 0.0,
        control_guidance_end: Union[float, List[float]] = 1.0,
        clip_skip: Optional[int] = None,
        callback_on_step_end: Optional[
            Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
        ] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        **kwargs,
    ):
        r"""
        The call function to the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
            image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
                `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
                The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
                specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted
                as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or
                width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`,
                images must be passed as a list such that each element of the list can be correctly batched for input
                to a single ControlNet. When `prompt` is a list, and if a list of images is passed for a single
                ControlNet, each will be paired with each prompt in the `prompt` list. This also applies to multiple
                ControlNets, where a list of image lists can be passed to batch for each prompt and each ControlNet.
            height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            timesteps (`List[int]`, *optional*):
                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
                passed will be used. Must be in descending order.
            sigmas (`List[float]`, *optional*):
                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
                their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
                will be used.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                A higher guidance scale value encourages the model to generate images closely linked to the text
                `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide what to not include in image generation. If not defined, you need to
                pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only
                applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
            latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor is generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
                provided, text embeddings are generated from the `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
                not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
            ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
            ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
                Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
                IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
                contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
                provided, embeddings are computed from the `ip_adapter_image` input argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between `PIL.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            callback (`Callable`, *optional*):
                A function that calls every `callback_steps` steps during inference. The function is called with the
                following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function is called. If not specified, the callback is called at
                every step.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
                The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
                to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set
                the corresponding scale as a list.
            guess_mode (`bool`, *optional*, defaults to `False`):
                The ControlNet encoder tries to recognize the content of the input image even if you remove all
                prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended.
            control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
                The percentage of total steps at which the ControlNet starts applying.
            control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
                The percentage of total steps at which the ControlNet stops applying.
            clip_skip (`int`, *optional*):
                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                the output of the pre-final layer will be used for computing the prompt embeddings.
            callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
                A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
                each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
                DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
                list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
            callback_on_step_end_tensor_inputs (`List`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                `._callback_tensor_inputs` attribute of your pipeline class.

        Examples:

        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
                otherwise a `tuple` is returned where the first element is a list with the generated images and the
                second element is a list of `bool`s indicating whether the corresponding generated image contains
                "not-safe-for-work" (nsfw) content.
        """

        # legacy per-step callback API, superseded by `callback_on_step_end`
        callback = kwargs.pop("callback", None)
        callback_steps = kwargs.pop("callback_steps", None)

        if callback is not None:
            deprecate(
                "callback",
                "1.0.0",
                "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
            )
        if callback_steps is not None:
            deprecate(
                "callback_steps",
                "1.0.0",
                "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
            )

        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

        # unwrap a torch.compile-wrapped ControlNet so the isinstance checks below work
        controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet

        # align format for control guidance
        if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
            control_guidance_start = len(control_guidance_end) * [control_guidance_start]
        elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
            control_guidance_end = len(control_guidance_start) * [control_guidance_end]
        elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
            mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
            control_guidance_start, control_guidance_end = (
                mult * [control_guidance_start],
                mult * [control_guidance_end],
            )

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            image,
            callback_steps,
            negative_prompt,
            prompt_embeds,
            negative_prompt_embeds,
            ip_adapter_image,
            ip_adapter_image_embeds,
            controlnet_conditioning_scale,
            control_guidance_start,
            control_guidance_end,
            callback_on_step_end_tensor_inputs,
        )

        # per-call state exposed through the read-only properties above
        self._guidance_scale = guidance_scale
        self._clip_skip = clip_skip
        self._cross_attention_kwargs = cross_attention_kwargs
        self._interrupt = False

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
            controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)

        # ControlNets trained with global pooling imply guess mode regardless of the flag
        global_pool_conditions = (
            controlnet.config.global_pool_conditions
            if isinstance(controlnet, ControlNetModel)
            else controlnet.nets[0].config.global_pool_conditions
        )
        guess_mode = guess_mode or global_pool_conditions

        # 3. Encode input prompt
        text_encoder_lora_scale = (
            self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
        )
        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            self.do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            lora_scale=text_encoder_lora_scale,
            clip_skip=self.clip_skip,
        )
        # For classifier free guidance, we need to do two forward passes.
        # Here we concatenate the unconditional and text embeddings into a single batch
        # to avoid doing two forward passes
        if self.do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
            image_embeds = self.prepare_ip_adapter_image_embeds(
                ip_adapter_image,
                ip_adapter_image_embeds,
                device,
                batch_size * num_images_per_prompt,
                self.do_classifier_free_guidance,
            )

        # 4. Prepare image
        if isinstance(controlnet, ControlNetModel):
            image = self.prepare_image(
                image=image,
                width=width,
                height=height,
                batch_size=batch_size * num_images_per_prompt,
                num_images_per_prompt=num_images_per_prompt,
                device=device,
                dtype=controlnet.dtype,
                do_classifier_free_guidance=self.do_classifier_free_guidance,
                guess_mode=guess_mode,
            )
            # output resolution follows the (possibly resized) conditioning image
            height, width = image.shape[-2:]
        elif isinstance(controlnet, MultiControlNetModel):
            images = []

            # Nested lists as ControlNet condition
            if isinstance(image[0], list):
                # Transpose the nested image list
                image = [list(t) for t in zip(*image)]

            for image_ in image:
                image_ = self.prepare_image(
                    image=image_,
                    width=width,
                    height=height,
                    batch_size=batch_size * num_images_per_prompt,
                    num_images_per_prompt=num_images_per_prompt,
                    device=device,
                    dtype=controlnet.dtype,
                    do_classifier_free_guidance=self.do_classifier_free_guidance,
                    guess_mode=guess_mode,
                )

                images.append(image_)

            image = images
            height, width = image[0].shape[-2:]
        else:
            assert False  # unreachable: check_inputs validated the controlnet type

        # 5. Prepare timesteps
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler, num_inference_steps, device, timesteps, sigmas
        )
        self._num_timesteps = len(timesteps)

        # 6. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 6.5 Optionally get Guidance Scale Embedding
        # (only for guidance-distilled UNets, which take `w` as a conditioning input)
        timestep_cond = None
        if self.unet.config.time_cond_proj_dim is not None:
            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
            timestep_cond = self.get_guidance_scale_embedding(
                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
            ).to(device=device, dtype=latents.dtype)

        # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7.1 Add image embeds for IP-Adapter
        added_cond_kwargs = (
            {"image_embeds": image_embeds}
            if ip_adapter_image is not None or ip_adapter_image_embeds is not None
            else None
        )

        # 7.2 Create tensor stating which controlnets to keep
        # keeps[j] is 1.0 while step i falls inside [start_j, end_j] of the total
        # schedule and 0.0 outside, implementing control_guidance_start/end.
        controlnet_keep = []
        for i in range(len(timesteps)):
            keeps = [
                1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
                for s, e in zip(control_guidance_start, control_guidance_end)
            ]
            controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)

        # 8. Denoising loop
        # warmup steps account for schedulers whose order > 1 (progress bar bookkeeping)
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        is_unet_compiled = is_compiled_module(self.unet)
        is_controlnet_compiled = is_compiled_module(self.controlnet)
        is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                # Relevant thread:
                # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
                if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
                    torch._inductor.cudagraph_mark_step_begin()
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # controlnet(s) inference
                if guess_mode and self.do_classifier_free_guidance:
                    # Infer ControlNet only for the conditional batch.
                    control_model_input = latents
                    control_model_input = self.scheduler.scale_model_input(control_model_input, t)
                    controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
                else:
                    control_model_input = latent_model_input
                    controlnet_prompt_embeds = prompt_embeds

                if isinstance(controlnet_keep[i], list):
                    cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
                else:
                    controlnet_cond_scale = controlnet_conditioning_scale
                    if isinstance(controlnet_cond_scale, list):
                        controlnet_cond_scale = controlnet_cond_scale[0]
                    cond_scale = controlnet_cond_scale * controlnet_keep[i]

                down_block_res_samples, mid_block_res_sample = self.controlnet(
                    control_model_input,
                    t,
                    encoder_hidden_states=controlnet_prompt_embeds,
                    controlnet_cond=image,
                    conditioning_scale=cond_scale,
                    guess_mode=guess_mode,
                    return_dict=False,
                )

                if guess_mode and self.do_classifier_free_guidance:
                    # Inferred ControlNet only for the conditional batch.
                    # To apply the output of ControlNet to both the unconditional and conditional batches,
                    # add 0 to the unconditional batch to keep it unchanged.
                    down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
                    mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])

                # predict the noise residual
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    timestep_cond=timestep_cond,
                    cross_attention_kwargs=self.cross_attention_kwargs,
                    down_block_additional_residuals=down_block_res_samples,
                    mid_block_additional_residual=mid_block_res_sample,
                    added_cond_kwargs=added_cond_kwargs,
                    return_dict=False,
                )[0]

                # perform guidance
                if self.do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    # the callback may replace any of these tensors in-flight
                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
                    image = callback_outputs.pop("image", image)

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        step_idx = i // getattr(self.scheduler, "order", 1)
                        callback(step_idx, t, latents)

                if XLA_AVAILABLE:
                    xm.mark_step()
        # If we do sequential model offloading, let's offload unet and controlnet
        # manually for max memory savings
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.unet.to("cpu")
            self.controlnet.to("cpu")
            empty_device_cache()

        if not output_type == "latent":
            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
                0
            ]
            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
        else:
            image = latents
            has_nsfw_concept = None

        if has_nsfw_concept is None:
            do_denormalize = [True] * image.shape[0]
        else:
            # do not denormalize images flagged by the safety checker (they were blacked out)
            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
pythonProject/diffusers-main/build/lib/diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py ADDED
@@ -0,0 +1,427 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 Salesforce.com, inc.
2
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ from typing import List, Optional, Union
16
+
17
+ import PIL.Image
18
+ import torch
19
+ from transformers import CLIPTokenizer
20
+
21
+ from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
22
+ from ...schedulers import PNDMScheduler
23
+ from ...utils import (
24
+ is_torch_xla_available,
25
+ logging,
26
+ replace_example_docstring,
27
+ )
28
+ from ...utils.torch_utils import randn_tensor
29
+ from ..blip_diffusion.blip_image_processing import BlipImageProcessor
30
+ from ..blip_diffusion.modeling_blip2 import Blip2QFormerModel
31
+ from ..blip_diffusion.modeling_ctx_clip import ContextCLIPTextModel
32
+ from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, ImagePipelineOutput
33
+
34
+
35
+ if is_torch_xla_available():
36
+ import torch_xla.core.xla_model as xm
37
+
38
+ XLA_AVAILABLE = True
39
+ else:
40
+ XLA_AVAILABLE = False
41
+
42
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
43
+
44
+
45
+ EXAMPLE_DOC_STRING = """
46
+ Examples:
47
+ ```py
48
+ >>> from diffusers.pipelines import BlipDiffusionControlNetPipeline
49
+ >>> from diffusers.utils import load_image
50
+ >>> from controlnet_aux import CannyDetector
51
+ >>> import torch
52
+
53
+ >>> blip_diffusion_pipe = BlipDiffusionControlNetPipeline.from_pretrained(
54
+ ... "Salesforce/blipdiffusion-controlnet", torch_dtype=torch.float16
55
+ ... ).to("cuda")
56
+
57
+ >>> style_subject = "flower"
58
+ >>> tgt_subject = "teapot"
59
+ >>> text_prompt = "on a marble table"
60
+
61
+ >>> cldm_cond_image = load_image(
62
+ ... "https://huggingface.co/datasets/ayushtues/blipdiffusion_images/resolve/main/kettle.jpg"
63
+ ... ).resize((512, 512))
64
+ >>> canny = CannyDetector()
65
+ >>> cldm_cond_image = canny(cldm_cond_image, 30, 70, output_type="pil")
66
+ >>> style_image = load_image(
67
+ ... "https://huggingface.co/datasets/ayushtues/blipdiffusion_images/resolve/main/flower.jpg"
68
+ ... )
69
+ >>> guidance_scale = 7.5
70
+ >>> num_inference_steps = 50
71
+ >>> negative_prompt = "over-exposure, under-exposure, saturated, duplicate, out of frame, lowres, cropped, worst quality, low quality, jpeg artifacts, morbid, mutilated, out of frame, ugly, bad anatomy, bad proportions, deformed, blurry, duplicate"
72
+
73
+
74
+ >>> output = blip_diffusion_pipe(
75
+ ... text_prompt,
76
+ ... style_image,
77
+ ... cldm_cond_image,
78
+ ... style_subject,
79
+ ... tgt_subject,
80
+ ... guidance_scale=guidance_scale,
81
+ ... num_inference_steps=num_inference_steps,
82
+ ... neg_prompt=negative_prompt,
83
+ ... height=512,
84
+ ... width=512,
85
+ ... ).images
86
+ >>> output[0].save("image.png")
87
+ ```
88
+ """
89
+
90
+
91
+ class BlipDiffusionControlNetPipeline(DeprecatedPipelineMixin, DiffusionPipeline):
92
+ """
93
+ Pipeline for Canny Edge based Controlled subject-driven generation using Blip Diffusion.
94
+
95
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
96
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
97
+
98
+ Args:
99
+ tokenizer ([`CLIPTokenizer`]):
100
+ Tokenizer for the text encoder
101
+ text_encoder ([`ContextCLIPTextModel`]):
102
+ Text encoder to encode the text prompt
103
+ vae ([`AutoencoderKL`]):
104
+ VAE model to map the latents to the image
105
+ unet ([`UNet2DConditionModel`]):
106
+ Conditional U-Net architecture to denoise the image embedding.
107
+ scheduler ([`PNDMScheduler`]):
108
+ A scheduler to be used in combination with `unet` to generate image latents.
109
+ qformer ([`Blip2QFormerModel`]):
110
+ QFormer model to get multi-modal embeddings from the text and image.
111
+ controlnet ([`ControlNetModel`]):
112
+ ControlNet model to get the conditioning image embedding.
113
+ image_processor ([`BlipImageProcessor`]):
114
+ Image Processor to preprocess and postprocess the image.
115
+ ctx_begin_pos (int, `optional`, defaults to 2):
116
+ Position of the context token in the text encoder.
117
+ """
118
+
119
+ _last_supported_version = "0.33.1"
120
+ model_cpu_offload_seq = "qformer->text_encoder->unet->vae"
121
+
122
    def __init__(
        self,
        tokenizer: CLIPTokenizer,
        text_encoder: ContextCLIPTextModel,
        vae: AutoencoderKL,
        unet: UNet2DConditionModel,
        scheduler: PNDMScheduler,
        qformer: Blip2QFormerModel,
        controlnet: ControlNetModel,
        image_processor: BlipImageProcessor,
        ctx_begin_pos: int = 2,
        mean: Optional[List[float]] = None,
        std: Optional[List[float]] = None,
    ):
        """Assemble the pipeline; see the class docstring for each component's role."""
        super().__init__()

        # Register sub-models so that `from_pretrained`/`save_pretrained` and the
        # device-placement utilities can discover them.
        self.register_modules(
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            vae=vae,
            unet=unet,
            scheduler=scheduler,
            qformer=qformer,
            controlnet=controlnet,
            image_processor=image_processor,
        )
        # Non-module settings (context-token position and image normalization
        # statistics) are persisted in the pipeline config instead.
        self.register_to_config(ctx_begin_pos=ctx_begin_pos, mean=mean, std=std)
149
+
150
+ def get_query_embeddings(self, input_image, src_subject):
151
+ return self.qformer(image_input=input_image, text_input=src_subject, return_dict=False)
152
+
153
+ # from the original Blip Diffusion code, specifies the target subject and augments the prompt by repeating it
154
+ def _build_prompt(self, prompts, tgt_subjects, prompt_strength=1.0, prompt_reps=20):
155
+ rv = []
156
+ for prompt, tgt_subject in zip(prompts, tgt_subjects):
157
+ prompt = f"a {tgt_subject} {prompt.strip()}"
158
+ # a trick to amplify the prompt
159
+ rv.append(", ".join([prompt] * int(prompt_strength * prompt_reps)))
160
+
161
+ return rv
162
+
163
+ # Copied from diffusers.pipelines.consistency_models.pipeline_consistency_models.ConsistencyModelPipeline.prepare_latents
164
+ def prepare_latents(self, batch_size, num_channels, height, width, dtype, device, generator, latents=None):
165
+ shape = (batch_size, num_channels, height, width)
166
+ if isinstance(generator, list) and len(generator) != batch_size:
167
+ raise ValueError(
168
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
169
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
170
+ )
171
+
172
+ if latents is None:
173
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
174
+ else:
175
+ latents = latents.to(device=device, dtype=dtype)
176
+
177
+ # scale the initial noise by the standard deviation required by the scheduler
178
+ latents = latents * self.scheduler.init_noise_sigma
179
+ return latents
180
+
181
+ def encode_prompt(self, query_embeds, prompt, device=None):
182
+ device = device or self._execution_device
183
+
184
+ # embeddings for prompt, with query_embeds as context
185
+ max_len = self.text_encoder.text_model.config.max_position_embeddings
186
+ max_len -= self.qformer.config.num_query_tokens
187
+
188
+ tokenized_prompt = self.tokenizer(
189
+ prompt,
190
+ padding="max_length",
191
+ truncation=True,
192
+ max_length=max_len,
193
+ return_tensors="pt",
194
+ ).to(device)
195
+
196
+ batch_size = query_embeds.shape[0]
197
+ ctx_begin_pos = [self.config.ctx_begin_pos] * batch_size
198
+
199
+ text_embeddings = self.text_encoder(
200
+ input_ids=tokenized_prompt.input_ids,
201
+ ctx_embeddings=query_embeds,
202
+ ctx_begin_pos=ctx_begin_pos,
203
+ )[0]
204
+
205
+ return text_embeddings
206
+
207
+ # Adapted from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image
208
+ def prepare_control_image(
209
+ self,
210
+ image,
211
+ width,
212
+ height,
213
+ batch_size,
214
+ num_images_per_prompt,
215
+ device,
216
+ dtype,
217
+ do_classifier_free_guidance=False,
218
+ ):
219
+ image = self.image_processor.preprocess(
220
+ image,
221
+ size={"width": width, "height": height},
222
+ do_rescale=True,
223
+ do_center_crop=False,
224
+ do_normalize=False,
225
+ return_tensors="pt",
226
+ )["pixel_values"].to(device)
227
+ image_batch_size = image.shape[0]
228
+
229
+ if image_batch_size == 1:
230
+ repeat_by = batch_size
231
+ else:
232
+ # image batch size is the same as prompt batch size
233
+ repeat_by = num_images_per_prompt
234
+
235
+ image = image.repeat_interleave(repeat_by, dim=0)
236
+
237
+ image = image.to(device=device, dtype=dtype)
238
+
239
+ if do_classifier_free_guidance:
240
+ image = torch.cat([image] * 2)
241
+
242
+ return image
243
+
244
+ @torch.no_grad()
245
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
246
+ def __call__(
247
+ self,
248
+ prompt: List[str],
249
+ reference_image: PIL.Image.Image,
250
+ condtioning_image: PIL.Image.Image,
251
+ source_subject_category: List[str],
252
+ target_subject_category: List[str],
253
+ latents: Optional[torch.Tensor] = None,
254
+ guidance_scale: float = 7.5,
255
+ height: int = 512,
256
+ width: int = 512,
257
+ num_inference_steps: int = 50,
258
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
259
+ neg_prompt: Optional[str] = "",
260
+ prompt_strength: float = 1.0,
261
+ prompt_reps: int = 20,
262
+ output_type: Optional[str] = "pil",
263
+ return_dict: bool = True,
264
+ ):
265
+ """
266
+ Function invoked when calling the pipeline for generation.
267
+
268
+ Args:
269
+ prompt (`List[str]`):
270
+ The prompt or prompts to guide the image generation.
271
+ reference_image (`PIL.Image.Image`):
272
+ The reference image to condition the generation on.
273
+ condtioning_image (`PIL.Image.Image`):
274
+ The conditioning canny edge image to condition the generation on.
275
+ source_subject_category (`List[str]`):
276
+ The source subject category.
277
+ target_subject_category (`List[str]`):
278
+ The target subject category.
279
+ latents (`torch.Tensor`, *optional*):
280
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
281
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
282
+ tensor will be generated by random sampling.
283
+ guidance_scale (`float`, *optional*, defaults to 7.5):
284
+ Guidance scale as defined in [Classifier-Free Diffusion
285
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
286
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
287
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
288
+ the text `prompt`, usually at the expense of lower image quality.
289
+ height (`int`, *optional*, defaults to 512):
290
+ The height of the generated image.
291
+ width (`int`, *optional*, defaults to 512):
292
+ The width of the generated image.
293
+ seed (`int`, *optional*, defaults to 42):
294
+ The seed to use for random generation.
295
+ num_inference_steps (`int`, *optional*, defaults to 50):
296
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
297
+ expense of slower inference.
298
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
299
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
300
+ to make generation deterministic.
301
+ neg_prompt (`str`, *optional*, defaults to ""):
302
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
303
+ if `guidance_scale` is less than `1`).
304
+ prompt_strength (`float`, *optional*, defaults to 1.0):
305
+ The strength of the prompt. Specifies the number of times the prompt is repeated along with prompt_reps
306
+ to amplify the prompt.
307
+ prompt_reps (`int`, *optional*, defaults to 20):
308
+ The number of times the prompt is repeated along with prompt_strength to amplify the prompt.
309
+ Examples:
310
+
311
+ Returns:
312
+ [`~pipelines.ImagePipelineOutput`] or `tuple`
313
+ """
314
+ device = self._execution_device
315
+
316
+ reference_image = self.image_processor.preprocess(
317
+ reference_image, image_mean=self.config.mean, image_std=self.config.std, return_tensors="pt"
318
+ )["pixel_values"]
319
+ reference_image = reference_image.to(device)
320
+
321
+ if isinstance(prompt, str):
322
+ prompt = [prompt]
323
+ if isinstance(source_subject_category, str):
324
+ source_subject_category = [source_subject_category]
325
+ if isinstance(target_subject_category, str):
326
+ target_subject_category = [target_subject_category]
327
+
328
+ batch_size = len(prompt)
329
+
330
+ prompt = self._build_prompt(
331
+ prompts=prompt,
332
+ tgt_subjects=target_subject_category,
333
+ prompt_strength=prompt_strength,
334
+ prompt_reps=prompt_reps,
335
+ )
336
+ query_embeds = self.get_query_embeddings(reference_image, source_subject_category)
337
+ text_embeddings = self.encode_prompt(query_embeds, prompt, device)
338
+ # 3. unconditional embedding
339
+ do_classifier_free_guidance = guidance_scale > 1.0
340
+ if do_classifier_free_guidance:
341
+ max_length = self.text_encoder.text_model.config.max_position_embeddings
342
+
343
+ uncond_input = self.tokenizer(
344
+ [neg_prompt] * batch_size,
345
+ padding="max_length",
346
+ max_length=max_length,
347
+ return_tensors="pt",
348
+ )
349
+ uncond_embeddings = self.text_encoder(
350
+ input_ids=uncond_input.input_ids.to(device),
351
+ ctx_embeddings=None,
352
+ )[0]
353
+ # For classifier free guidance, we need to do two forward passes.
354
+ # Here we concatenate the unconditional and text embeddings into a single batch
355
+ # to avoid doing two forward passes
356
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
357
+ scale_down_factor = 2 ** (len(self.unet.config.block_out_channels) - 1)
358
+ latents = self.prepare_latents(
359
+ batch_size=batch_size,
360
+ num_channels=self.unet.config.in_channels,
361
+ height=height // scale_down_factor,
362
+ width=width // scale_down_factor,
363
+ generator=generator,
364
+ latents=latents,
365
+ dtype=self.unet.dtype,
366
+ device=device,
367
+ )
368
+ # set timesteps
369
+ extra_set_kwargs = {}
370
+ self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
371
+
372
+ cond_image = self.prepare_control_image(
373
+ image=condtioning_image,
374
+ width=width,
375
+ height=height,
376
+ batch_size=batch_size,
377
+ num_images_per_prompt=1,
378
+ device=device,
379
+ dtype=self.controlnet.dtype,
380
+ do_classifier_free_guidance=do_classifier_free_guidance,
381
+ )
382
+
383
+ for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
384
+ # expand the latents if we are doing classifier free guidance
385
+ do_classifier_free_guidance = guidance_scale > 1.0
386
+
387
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
388
+ down_block_res_samples, mid_block_res_sample = self.controlnet(
389
+ latent_model_input,
390
+ t,
391
+ encoder_hidden_states=text_embeddings,
392
+ controlnet_cond=cond_image,
393
+ return_dict=False,
394
+ )
395
+
396
+ noise_pred = self.unet(
397
+ latent_model_input,
398
+ timestep=t,
399
+ encoder_hidden_states=text_embeddings,
400
+ down_block_additional_residuals=down_block_res_samples,
401
+ mid_block_additional_residual=mid_block_res_sample,
402
+ )["sample"]
403
+
404
+ # perform guidance
405
+ if do_classifier_free_guidance:
406
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
407
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
408
+
409
+ latents = self.scheduler.step(
410
+ noise_pred,
411
+ t,
412
+ latents,
413
+ )["prev_sample"]
414
+
415
+ if XLA_AVAILABLE:
416
+ xm.mark_step()
417
+
418
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
419
+ image = self.image_processor.postprocess(image, output_type=output_type)
420
+
421
+ # Offload all models
422
+ self.maybe_free_model_hooks()
423
+
424
+ if not return_dict:
425
+ return (image,)
426
+
427
+ return ImagePipelineOutput(images=image)
pythonProject/diffusers-main/build/lib/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py ADDED
@@ -0,0 +1,1338 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
17
+
18
+ import numpy as np
19
+ import PIL.Image
20
+ import torch
21
+ import torch.nn.functional as F
22
+ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
23
+
24
+ from ...callbacks import MultiPipelineCallbacks, PipelineCallback
25
+ from ...image_processor import PipelineImageInput, VaeImageProcessor
26
+ from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
27
+ from ...models import AutoencoderKL, ControlNetModel, ImageProjection, MultiControlNetModel, UNet2DConditionModel
28
+ from ...models.lora import adjust_lora_scale_text_encoder
29
+ from ...schedulers import KarrasDiffusionSchedulers
30
+ from ...utils import (
31
+ USE_PEFT_BACKEND,
32
+ deprecate,
33
+ is_torch_xla_available,
34
+ logging,
35
+ replace_example_docstring,
36
+ scale_lora_layers,
37
+ unscale_lora_layers,
38
+ )
39
+ from ...utils.torch_utils import empty_device_cache, is_compiled_module, randn_tensor
40
+ from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
41
+ from ..stable_diffusion import StableDiffusionPipelineOutput
42
+ from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
43
+
44
+
45
+ if is_torch_xla_available():
46
+ import torch_xla.core.xla_model as xm
47
+
48
+ XLA_AVAILABLE = True
49
+ else:
50
+ XLA_AVAILABLE = False
51
+
52
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
53
+
54
+
55
+ EXAMPLE_DOC_STRING = """
56
+ Examples:
57
+ ```py
58
+ >>> # !pip install opencv-python transformers accelerate
59
+ >>> from diffusers import StableDiffusionControlNetImg2ImgPipeline, ControlNetModel, UniPCMultistepScheduler
60
+ >>> from diffusers.utils import load_image
61
+ >>> import numpy as np
62
+ >>> import torch
63
+
64
+ >>> import cv2
65
+ >>> from PIL import Image
66
+
67
+ >>> # download an image
68
+ >>> image = load_image(
69
+ ... "https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png"
70
+ ... )
71
+ >>> np_image = np.array(image)
72
+
73
+ >>> # get canny image
74
+ >>> np_image = cv2.Canny(np_image, 100, 200)
75
+ >>> np_image = np_image[:, :, None]
76
+ >>> np_image = np.concatenate([np_image, np_image, np_image], axis=2)
77
+ >>> canny_image = Image.fromarray(np_image)
78
+
79
+ >>> # load control net and stable diffusion v1-5
80
+ >>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
81
+ >>> pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
82
+ ... "stable-diffusion-v1-5/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
83
+ ... )
84
+
85
+ >>> # speed up diffusion process with faster scheduler and memory optimization
86
+ >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
87
+ >>> pipe.enable_model_cpu_offload()
88
+
89
+ >>> # generate image
90
+ >>> generator = torch.manual_seed(0)
91
+ >>> image = pipe(
92
+ ... "futuristic-looking woman",
93
+ ... num_inference_steps=20,
94
+ ... generator=generator,
95
+ ... image=image,
96
+ ... control_image=canny_image,
97
+ ... ).images[0]
98
+ ```
99
+ """
100
+
101
+
102
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
103
+ def retrieve_latents(
104
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
105
+ ):
106
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
107
+ return encoder_output.latent_dist.sample(generator)
108
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
109
+ return encoder_output.latent_dist.mode()
110
+ elif hasattr(encoder_output, "latents"):
111
+ return encoder_output.latents
112
+ else:
113
+ raise AttributeError("Could not access latents of provided encoder_output")
114
+
115
+
116
+ def prepare_image(image):
117
+ if isinstance(image, torch.Tensor):
118
+ # Batch single image
119
+ if image.ndim == 3:
120
+ image = image.unsqueeze(0)
121
+
122
+ image = image.to(dtype=torch.float32)
123
+ else:
124
+ # preprocess image
125
+ if isinstance(image, (PIL.Image.Image, np.ndarray)):
126
+ image = [image]
127
+
128
+ if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
129
+ image = [np.array(i.convert("RGB"))[None, :] for i in image]
130
+ image = np.concatenate(image, axis=0)
131
+ elif isinstance(image, list) and isinstance(image[0], np.ndarray):
132
+ image = np.concatenate([i[None, :] for i in image], axis=0)
133
+
134
+ image = image.transpose(0, 3, 1, 2)
135
+ image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
136
+
137
+ return image
138
+
139
+
140
+ class StableDiffusionControlNetImg2ImgPipeline(
141
+ DiffusionPipeline,
142
+ StableDiffusionMixin,
143
+ TextualInversionLoaderMixin,
144
+ StableDiffusionLoraLoaderMixin,
145
+ IPAdapterMixin,
146
+ FromSingleFileMixin,
147
+ ):
148
+ r"""
149
+ Pipeline for image-to-image generation using Stable Diffusion with ControlNet guidance.
150
+
151
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
152
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
153
+
154
+ The pipeline also inherits the following loading methods:
155
+ - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
156
+ - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
157
+ - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
158
+ - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
159
+ - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
160
+
161
+ Args:
162
+ vae ([`AutoencoderKL`]):
163
+ Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
164
+ text_encoder ([`~transformers.CLIPTextModel`]):
165
+ Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
166
+ tokenizer ([`~transformers.CLIPTokenizer`]):
167
+ A `CLIPTokenizer` to tokenize text.
168
+ unet ([`UNet2DConditionModel`]):
169
+ A `UNet2DConditionModel` to denoise the encoded image latents.
170
+ controlnet ([`ControlNetModel`] or `List[ControlNetModel]`):
171
+ Provides additional conditioning to the `unet` during the denoising process. If you set multiple
172
+ ControlNets as a list, the outputs from each ControlNet are added together to create one combined
173
+ additional conditioning.
174
+ scheduler ([`SchedulerMixin`]):
175
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
176
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
177
+ safety_checker ([`StableDiffusionSafetyChecker`]):
178
+ Classification module that estimates whether generated images could be considered offensive or harmful.
179
+ Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for
180
+ more details about a model's potential harms.
181
+ feature_extractor ([`~transformers.CLIPImageProcessor`]):
182
+ A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
183
+ """
184
+
185
+ model_cpu_offload_seq = "text_encoder->unet->vae"
186
+ _optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
187
+ _exclude_from_cpu_offload = ["safety_checker"]
188
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "control_image"]
189
+
190
+ def __init__(
191
+ self,
192
+ vae: AutoencoderKL,
193
+ text_encoder: CLIPTextModel,
194
+ tokenizer: CLIPTokenizer,
195
+ unet: UNet2DConditionModel,
196
+ controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
197
+ scheduler: KarrasDiffusionSchedulers,
198
+ safety_checker: StableDiffusionSafetyChecker,
199
+ feature_extractor: CLIPImageProcessor,
200
+ image_encoder: CLIPVisionModelWithProjection = None,
201
+ requires_safety_checker: bool = True,
202
+ ):
203
+ super().__init__()
204
+
205
+ if safety_checker is None and requires_safety_checker:
206
+ logger.warning(
207
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
208
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
209
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
210
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
211
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
212
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
213
+ )
214
+
215
+ if safety_checker is not None and feature_extractor is None:
216
+ raise ValueError(
217
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
218
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
219
+ )
220
+
221
+ if isinstance(controlnet, (list, tuple)):
222
+ controlnet = MultiControlNetModel(controlnet)
223
+
224
+ self.register_modules(
225
+ vae=vae,
226
+ text_encoder=text_encoder,
227
+ tokenizer=tokenizer,
228
+ unet=unet,
229
+ controlnet=controlnet,
230
+ scheduler=scheduler,
231
+ safety_checker=safety_checker,
232
+ feature_extractor=feature_extractor,
233
+ image_encoder=image_encoder,
234
+ )
235
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
236
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
237
+ self.control_image_processor = VaeImageProcessor(
238
+ vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
239
+ )
240
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
241
+
242
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
243
+ def _encode_prompt(
244
+ self,
245
+ prompt,
246
+ device,
247
+ num_images_per_prompt,
248
+ do_classifier_free_guidance,
249
+ negative_prompt=None,
250
+ prompt_embeds: Optional[torch.Tensor] = None,
251
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
252
+ lora_scale: Optional[float] = None,
253
+ **kwargs,
254
+ ):
255
+ deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
256
+ deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
257
+
258
+ prompt_embeds_tuple = self.encode_prompt(
259
+ prompt=prompt,
260
+ device=device,
261
+ num_images_per_prompt=num_images_per_prompt,
262
+ do_classifier_free_guidance=do_classifier_free_guidance,
263
+ negative_prompt=negative_prompt,
264
+ prompt_embeds=prompt_embeds,
265
+ negative_prompt_embeds=negative_prompt_embeds,
266
+ lora_scale=lora_scale,
267
+ **kwargs,
268
+ )
269
+
270
+ # concatenate for backwards comp
271
+ prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
272
+
273
+ return prompt_embeds
274
+
275
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
276
+ def encode_prompt(
277
+ self,
278
+ prompt,
279
+ device,
280
+ num_images_per_prompt,
281
+ do_classifier_free_guidance,
282
+ negative_prompt=None,
283
+ prompt_embeds: Optional[torch.Tensor] = None,
284
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
285
+ lora_scale: Optional[float] = None,
286
+ clip_skip: Optional[int] = None,
287
+ ):
288
+ r"""
289
+ Encodes the prompt into text encoder hidden states.
290
+
291
+ Args:
292
+ prompt (`str` or `List[str]`, *optional*):
293
+ prompt to be encoded
294
+ device: (`torch.device`):
295
+ torch device
296
+ num_images_per_prompt (`int`):
297
+ number of images that should be generated per prompt
298
+ do_classifier_free_guidance (`bool`):
299
+ whether to use classifier free guidance or not
300
+ negative_prompt (`str` or `List[str]`, *optional*):
301
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
302
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
303
+ less than `1`).
304
+ prompt_embeds (`torch.Tensor`, *optional*):
305
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
306
+ provided, text embeddings will be generated from `prompt` input argument.
307
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
308
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
309
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
310
+ argument.
311
+ lora_scale (`float`, *optional*):
312
+ A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
313
+ clip_skip (`int`, *optional*):
314
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
315
+ the output of the pre-final layer will be used for computing the prompt embeddings.
316
+ """
317
+ # set lora scale so that monkey patched LoRA
318
+ # function of text encoder can correctly access it
319
+ if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):
320
+ self._lora_scale = lora_scale
321
+
322
+ # dynamically adjust the LoRA scale
323
+ if not USE_PEFT_BACKEND:
324
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
325
+ else:
326
+ scale_lora_layers(self.text_encoder, lora_scale)
327
+
328
+ if prompt is not None and isinstance(prompt, str):
329
+ batch_size = 1
330
+ elif prompt is not None and isinstance(prompt, list):
331
+ batch_size = len(prompt)
332
+ else:
333
+ batch_size = prompt_embeds.shape[0]
334
+
335
+ if prompt_embeds is None:
336
+ # textual inversion: process multi-vector tokens if necessary
337
+ if isinstance(self, TextualInversionLoaderMixin):
338
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
339
+
340
+ text_inputs = self.tokenizer(
341
+ prompt,
342
+ padding="max_length",
343
+ max_length=self.tokenizer.model_max_length,
344
+ truncation=True,
345
+ return_tensors="pt",
346
+ )
347
+ text_input_ids = text_inputs.input_ids
348
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
349
+
350
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
351
+ text_input_ids, untruncated_ids
352
+ ):
353
+ removed_text = self.tokenizer.batch_decode(
354
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
355
+ )
356
+ logger.warning(
357
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
358
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
359
+ )
360
+
361
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
362
+ attention_mask = text_inputs.attention_mask.to(device)
363
+ else:
364
+ attention_mask = None
365
+
366
+ if clip_skip is None:
367
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
368
+ prompt_embeds = prompt_embeds[0]
369
+ else:
370
+ prompt_embeds = self.text_encoder(
371
+ text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
372
+ )
373
+ # Access the `hidden_states` first, that contains a tuple of
374
+ # all the hidden states from the encoder layers. Then index into
375
+ # the tuple to access the hidden states from the desired layer.
376
+ prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
377
+ # We also need to apply the final LayerNorm here to not mess with the
378
+ # representations. The `last_hidden_states` that we typically use for
379
+ # obtaining the final prompt representations passes through the LayerNorm
380
+ # layer.
381
+ prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
382
+
383
+ if self.text_encoder is not None:
384
+ prompt_embeds_dtype = self.text_encoder.dtype
385
+ elif self.unet is not None:
386
+ prompt_embeds_dtype = self.unet.dtype
387
+ else:
388
+ prompt_embeds_dtype = prompt_embeds.dtype
389
+
390
+ prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
391
+
392
+ bs_embed, seq_len, _ = prompt_embeds.shape
393
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
394
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
395
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
396
+
397
+ # get unconditional embeddings for classifier free guidance
398
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
399
+ uncond_tokens: List[str]
400
+ if negative_prompt is None:
401
+ uncond_tokens = [""] * batch_size
402
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
403
+ raise TypeError(
404
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
405
+ f" {type(prompt)}."
406
+ )
407
+ elif isinstance(negative_prompt, str):
408
+ uncond_tokens = [negative_prompt]
409
+ elif batch_size != len(negative_prompt):
410
+ raise ValueError(
411
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
412
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
413
+ " the batch size of `prompt`."
414
+ )
415
+ else:
416
+ uncond_tokens = negative_prompt
417
+
418
+ # textual inversion: process multi-vector tokens if necessary
419
+ if isinstance(self, TextualInversionLoaderMixin):
420
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
421
+
422
+ max_length = prompt_embeds.shape[1]
423
+ uncond_input = self.tokenizer(
424
+ uncond_tokens,
425
+ padding="max_length",
426
+ max_length=max_length,
427
+ truncation=True,
428
+ return_tensors="pt",
429
+ )
430
+
431
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
432
+ attention_mask = uncond_input.attention_mask.to(device)
433
+ else:
434
+ attention_mask = None
435
+
436
+ negative_prompt_embeds = self.text_encoder(
437
+ uncond_input.input_ids.to(device),
438
+ attention_mask=attention_mask,
439
+ )
440
+ negative_prompt_embeds = negative_prompt_embeds[0]
441
+
442
+ if do_classifier_free_guidance:
443
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
444
+ seq_len = negative_prompt_embeds.shape[1]
445
+
446
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
447
+
448
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
449
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
450
+
451
+ if self.text_encoder is not None:
452
+ if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
453
+ # Retrieve the original scale by scaling back the LoRA layers
454
+ unscale_lora_layers(self.text_encoder, lora_scale)
455
+
456
+ return prompt_embeds, negative_prompt_embeds
457
+
458
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
459
+ def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
460
+ dtype = next(self.image_encoder.parameters()).dtype
461
+
462
+ if not isinstance(image, torch.Tensor):
463
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
464
+
465
+ image = image.to(device=device, dtype=dtype)
466
+ if output_hidden_states:
467
+ image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
468
+ image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
469
+ uncond_image_enc_hidden_states = self.image_encoder(
470
+ torch.zeros_like(image), output_hidden_states=True
471
+ ).hidden_states[-2]
472
+ uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
473
+ num_images_per_prompt, dim=0
474
+ )
475
+ return image_enc_hidden_states, uncond_image_enc_hidden_states
476
+ else:
477
+ image_embeds = self.image_encoder(image).image_embeds
478
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
479
+ uncond_image_embeds = torch.zeros_like(image_embeds)
480
+
481
+ return image_embeds, uncond_image_embeds
482
+
483
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
def prepare_ip_adapter_image_embeds(
    self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
):
    """Produce one embedding tensor per IP-Adapter, ready for the UNet.

    Either encodes `ip_adapter_image` (one image per adapter) via
    `encode_image`, or re-uses precomputed `ip_adapter_image_embeds`. Each
    returned tensor is tiled `num_images_per_prompt` times and, under CFG,
    has the negative embeds concatenated in front along dim 0.
    """
    image_embeds = []
    if do_classifier_free_guidance:
        negative_image_embeds = []
    if ip_adapter_image_embeds is None:
        if not isinstance(ip_adapter_image, list):
            ip_adapter_image = [ip_adapter_image]

        # One conditioning image is required per loaded IP-Adapter.
        if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
            raise ValueError(
                f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
            )

        for single_ip_adapter_image, image_proj_layer in zip(
            ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
        ):
            # Plain `ImageProjection` consumes pooled embeds; every other
            # projection type consumes hidden states (see `encode_image`).
            output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
            single_image_embeds, single_negative_image_embeds = self.encode_image(
                single_ip_adapter_image, device, 1, output_hidden_state
            )

            # Add a leading batch axis so each adapter contributes one entry.
            image_embeds.append(single_image_embeds[None, :])
            if do_classifier_free_guidance:
                negative_image_embeds.append(single_negative_image_embeds[None, :])
    else:
        for single_image_embeds in ip_adapter_image_embeds:
            if do_classifier_free_guidance:
                # Precomputed embeds are stored as [negative, positive]
                # concatenated along dim 0 — split them back apart.
                single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
                negative_image_embeds.append(single_negative_image_embeds)
            image_embeds.append(single_image_embeds)

    ip_adapter_image_embeds = []
    for i, single_image_embeds in enumerate(image_embeds):
        # Tile for the requested number of images per prompt.
        single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
        if do_classifier_free_guidance:
            single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
            # CFG convention: negative batch first, then positive.
            single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)

        single_image_embeds = single_image_embeds.to(device=device)
        ip_adapter_image_embeds.append(single_image_embeds)

    return ip_adapter_image_embeds
529
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype):
    """Run the NSFW safety checker over a batch of decoded images.

    Returns `(image, has_nsfw_concept)`. When no safety checker is loaded,
    the images pass through untouched and `has_nsfw_concept` is `None`;
    otherwise the checker may replace flagged images and returns a per-image
    flag list.
    """
    if self.safety_checker is None:
        has_nsfw_concept = None
    else:
        if torch.is_tensor(image):
            # Tensor output: convert to PIL for the CLIP feature extractor.
            feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
        else:
            feature_extractor_input = self.image_processor.numpy_to_pil(image)
        safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
        image, has_nsfw_concept = self.safety_checker(
            images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
        )
    return image, has_nsfw_concept
544
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
def decode_latents(self, latents):
    """Deprecated: decode VAE latents to a float32 NumPy batch in [0, 1].

    Kept only for backward compatibility; `VaeImageProcessor.postprocess`
    is the supported replacement.
    """
    deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
    deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)

    # Undo the scaling applied when the latents were produced by the VAE encoder.
    latents = 1 / self.vae.config.scaling_factor * latents
    image = self.vae.decode(latents, return_dict=False)[0]
    # Map the decoder's [-1, 1] output range to [0, 1].
    image = (image / 2 + 0.5).clamp(0, 1)
    # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
    image = image.cpu().permute(0, 2, 3, 1).float().numpy()
    return image
556
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
    """Assemble the optional kwargs accepted by `self.scheduler.step`.

    Not all schedulers share the same `step` signature, so `eta` and
    `generator` are forwarded only when the scheduler accepts them. `eta`
    corresponds to η in the DDIM paper
    (https://huggingface.co/papers/2010.02502), should be in [0, 1], and is
    only used by `DDIMScheduler` — it is ignored by other schedulers.
    """
    accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
    extra_step_kwargs = {}
    if accepts_eta:
        extra_step_kwargs["eta"] = eta

    # check if the scheduler accepts generator
    accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
    if accepts_generator:
        extra_step_kwargs["generator"] = generator
    return extra_step_kwargs
574
def check_inputs(
    self,
    prompt,
    image,
    callback_steps,
    negative_prompt=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
    ip_adapter_image=None,
    ip_adapter_image_embeds=None,
    controlnet_conditioning_scale=1.0,
    control_guidance_start=0.0,
    control_guidance_end=1.0,
    callback_on_step_end_tensor_inputs=None,
):
    """Validate the combination of arguments passed to `__call__`.

    Checks prompt/embedding exclusivity, ControlNet conditioning image(s) and
    `controlnet_conditioning_scale`, the control guidance start/end ranges
    (assumed already normalized to equal-length lists by the caller), and the
    IP-Adapter inputs. Raises `ValueError` or `TypeError` on the first
    inconsistency found; returns nothing on success.
    """
    if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )

    if callback_on_step_end_tensor_inputs is not None and not all(
        k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
    ):
        raise ValueError(
            f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
        )

    # Exactly one of `prompt` / `prompt_embeds` must be provided.
    if prompt is not None and prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    elif prompt is None and prompt_embeds is None:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    if negative_prompt is not None and negative_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if prompt_embeds is not None and negative_prompt_embeds is not None:
        if prompt_embeds.shape != negative_prompt_embeds.shape:
            raise ValueError(
                "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                f" {negative_prompt_embeds.shape}."
            )

    # `prompt` needs more sophisticated handling when there are multiple
    # conditionings.
    if isinstance(self.controlnet, MultiControlNetModel):
        if isinstance(prompt, list):
            logger.warning(
                f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}"
                " prompts. The conditionings will be fixed across the prompts."
            )

    # Check `image`. `torch.compile` wraps the model in an OptimizedModule;
    # unwrap via `_orig_mod` to inspect the real class.
    is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
        self.controlnet, torch._dynamo.eval_frame.OptimizedModule
    )
    if (
        isinstance(self.controlnet, ControlNetModel)
        or is_compiled
        and isinstance(self.controlnet._orig_mod, ControlNetModel)
    ):
        self.check_image(image, prompt, prompt_embeds)
    elif (
        isinstance(self.controlnet, MultiControlNetModel)
        or is_compiled
        and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
    ):
        if not isinstance(image, list):
            raise TypeError("For multiple controlnets: `image` must be type `list`")

        # When `image` is a nested list:
        # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])
        elif any(isinstance(i, list) for i in image):
            raise ValueError("A single batch of multiple conditionings are supported at the moment.")
        elif len(image) != len(self.controlnet.nets):
            raise ValueError(
                f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets."
            )

        for image_ in image:
            self.check_image(image_, prompt, prompt_embeds)
    else:
        # Unreachable for supported controlnet types. Raise instead of
        # `assert False` so the check survives `python -O`.
        raise ValueError(f"Unsupported controlnet type {type(self.controlnet)}")

    # Check `controlnet_conditioning_scale`
    if (
        isinstance(self.controlnet, ControlNetModel)
        or is_compiled
        and isinstance(self.controlnet._orig_mod, ControlNetModel)
    ):
        if not isinstance(controlnet_conditioning_scale, float):
            raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
    elif (
        isinstance(self.controlnet, MultiControlNetModel)
        or is_compiled
        and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
    ):
        if isinstance(controlnet_conditioning_scale, list):
            if any(isinstance(i, list) for i in controlnet_conditioning_scale):
                raise ValueError("A single batch of multiple conditionings are supported at the moment.")
            # FIX: this length check previously lived in an
            # `elif isinstance(..., list) and len(...) != ...` branch that could
            # never execute, because the preceding `if` already matched every
            # list. Run it whenever a list is supplied.
            if len(controlnet_conditioning_scale) != len(self.controlnet.nets):
                raise ValueError(
                    "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
                    " the same length as the number of controlnets"
                )
    else:
        # Unreachable for supported controlnet types (see above).
        raise ValueError(f"Unsupported controlnet type {type(self.controlnet)}")

    if len(control_guidance_start) != len(control_guidance_end):
        raise ValueError(
            f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
        )

    if isinstance(self.controlnet, MultiControlNetModel):
        if len(control_guidance_start) != len(self.controlnet.nets):
            raise ValueError(
                f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}."
            )

    # Each (start, end) pair must describe a non-empty sub-interval of [0, 1].
    for start, end in zip(control_guidance_start, control_guidance_end):
        if start >= end:
            raise ValueError(
                f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}."
            )
        if start < 0.0:
            raise ValueError(f"control guidance start: {start} can't be smaller than 0.")
        if end > 1.0:
            raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")

    if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
        raise ValueError(
            "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
        )

    if ip_adapter_image_embeds is not None:
        if not isinstance(ip_adapter_image_embeds, list):
            raise ValueError(
                f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
            )
        elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
            raise ValueError(
                f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
            )
731
# Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image
def check_image(self, image, prompt, prompt_embeds):
    """Validate one ControlNet conditioning image against the prompt batch.

    Accepts a PIL image, NumPy array, torch tensor, or a list of any of
    those; raises `TypeError` otherwise. Also raises `ValueError` when the
    image batch size is neither 1 nor equal to the prompt batch size.
    """
    image_is_pil = isinstance(image, PIL.Image.Image)
    image_is_tensor = isinstance(image, torch.Tensor)
    image_is_np = isinstance(image, np.ndarray)
    # For lists, only the first element's type is inspected.
    image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
    image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)
    image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray)

    if (
        not image_is_pil
        and not image_is_tensor
        and not image_is_np
        and not image_is_pil_list
        and not image_is_tensor_list
        and not image_is_np_list
    ):
        raise TypeError(
            f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}"
        )

    if image_is_pil:
        image_batch_size = 1
    else:
        # NOTE(review): for a single tensor/array this is its first dimension;
        # assumes batch-first layout — confirm against callers.
        image_batch_size = len(image)

    if prompt is not None and isinstance(prompt, str):
        prompt_batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        prompt_batch_size = len(prompt)
    elif prompt_embeds is not None:
        prompt_batch_size = prompt_embeds.shape[0]

    # A single conditioning image may be broadcast over any prompt batch;
    # otherwise the batch sizes must match exactly.
    if image_batch_size != 1 and image_batch_size != prompt_batch_size:
        raise ValueError(
            f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
        )
769
# Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image
def prepare_control_image(
    self,
    image,
    width,
    height,
    batch_size,
    num_images_per_prompt,
    device,
    dtype,
    do_classifier_free_guidance=False,
    guess_mode=False,
):
    """Preprocess a ControlNet conditioning image into a model-ready tensor.

    Resizes/normalizes via `control_image_processor`, expands the batch to
    match the effective batch size, moves to `device`/`dtype`, and — under
    classifier-free guidance without guess mode — duplicates the batch so the
    conditioning covers both the unconditional and conditional passes.
    """
    image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
    image_batch_size = image.shape[0]

    if image_batch_size == 1:
        # Broadcast a single conditioning image over the whole batch.
        repeat_by = batch_size
    else:
        # image batch size is the same as prompt batch size
        repeat_by = num_images_per_prompt

    image = image.repeat_interleave(repeat_by, dim=0)

    image = image.to(device=device, dtype=dtype)

    # In guess mode only the conditional pass receives the control image, so
    # no duplication is needed there.
    if do_classifier_free_guidance and not guess_mode:
        image = torch.cat([image] * 2)

    return image
800
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
def get_timesteps(self, num_inference_steps, strength, device):
    """Trim the scheduler's timestep schedule according to img2img `strength`.

    Returns `(timesteps, effective_num_inference_steps)`: the first
    `num_inference_steps - int(num_inference_steps * strength)` steps are
    skipped, so `strength == 1.0` keeps the full schedule and small strengths
    start denoising late (preserving more of the input image).
    """
    # get the original timestep using init_timestep
    init_timestep = min(int(num_inference_steps * strength), num_inference_steps)

    t_start = max(num_inference_steps - init_timestep, 0)
    # Multiply by the scheduler order: higher-order schedulers keep multiple
    # schedule entries per inference step.
    timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
    if hasattr(self.scheduler, "set_begin_index"):
        # Schedulers that track a begin index need to know where denoising
        # actually starts.
        self.scheduler.set_begin_index(t_start * self.scheduler.order)

    return timesteps, num_inference_steps - t_start
812
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.prepare_latents
def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
    """Build the initial noisy latents for img2img from `image`.

    `image` may already be latents (4 channels) or a pixel tensor to encode
    with the VAE; the result is scaled, optionally duplicated to the
    effective batch size, and noised to `timestep` via
    `scheduler.add_noise`.

    NOTE(review): although PIL images and lists pass the isinstance check,
    `.to(device=...)` below requires a tensor — presumably the caller
    preprocesses to a tensor first; confirm against `__call__`.
    """
    if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
        raise ValueError(
            f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
        )

    image = image.to(device=device, dtype=dtype)

    # Effective batch size after per-prompt duplication.
    batch_size = batch_size * num_images_per_prompt

    # 4 channels means `image` is already latents — skip VAE encoding.
    if image.shape[1] == 4:
        init_latents = image

    else:
        # A list of generators must match the effective batch size exactly.
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        elif isinstance(generator, list):
            # Duplicate the image batch when it divides the effective batch
            # size, so each generator encodes its own slice.
            if image.shape[0] < batch_size and batch_size % image.shape[0] == 0:
                image = torch.cat([image] * (batch_size // image.shape[0]), dim=0)
            elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0:
                raise ValueError(
                    f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} "
                )

            # Encode each sample with its corresponding generator.
            init_latents = [
                retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
                for i in range(batch_size)
            ]
            init_latents = torch.cat(init_latents, dim=0)
        else:
            init_latents = retrieve_latents(self.vae.encode(image), generator=generator)

        # Bring VAE output into the scaled latent space used by the UNet.
        init_latents = self.vae.config.scaling_factor * init_latents

    if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
        # expand init_latents for batch_size
        deprecation_message = (
            f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
            " images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
            " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
            " your script to pass as many initial images as text prompts to suppress this warning."
        )
        deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
        additional_image_per_prompt = batch_size // init_latents.shape[0]
        init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
    elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
        raise ValueError(
            f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
        )
    else:
        init_latents = torch.cat([init_latents], dim=0)

    shape = init_latents.shape
    noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)

    # get latents: add noise appropriate for the (possibly trimmed) start timestep
    init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
    latents = init_latents

    return latents
878
@property
def guidance_scale(self):
    """Classifier-free guidance scale most recently set by `__call__`."""
    current_scale = self._guidance_scale
    return current_scale
882
@property
def clip_skip(self):
    """Number of final CLIP layers skipped, as set by `__call__`."""
    return getattr(self, "_clip_skip")
886
# `guidance_scale` mirrors the guidance weight `w` in equation (2) of the
# Imagen paper (https://huggingface.co/papers/2205.11487); a value of 1
# disables classifier-free guidance entirely.
@property
def do_classifier_free_guidance(self):
    """Whether classifier-free guidance is active (guidance scale above 1)."""
    return 1 < self._guidance_scale
893
@property
def cross_attention_kwargs(self):
    """Kwargs forwarded to the attention processors, as set by `__call__`."""
    stored_kwargs = self._cross_attention_kwargs
    return stored_kwargs
897
@property
def num_timesteps(self):
    """Number of denoising timesteps recorded for the current run."""
    return getattr(self, "_num_timesteps")
901
@property
def interrupt(self):
    """Flag indicating whether the denoising loop should stop early."""
    should_stop = self._interrupt
    return should_stop
905
+ @torch.no_grad()
906
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
907
+ def __call__(
908
+ self,
909
+ prompt: Union[str, List[str]] = None,
910
+ image: PipelineImageInput = None,
911
+ control_image: PipelineImageInput = None,
912
+ height: Optional[int] = None,
913
+ width: Optional[int] = None,
914
+ strength: float = 0.8,
915
+ num_inference_steps: int = 50,
916
+ guidance_scale: float = 7.5,
917
+ negative_prompt: Optional[Union[str, List[str]]] = None,
918
+ num_images_per_prompt: Optional[int] = 1,
919
+ eta: float = 0.0,
920
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
921
+ latents: Optional[torch.Tensor] = None,
922
+ prompt_embeds: Optional[torch.Tensor] = None,
923
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
924
+ ip_adapter_image: Optional[PipelineImageInput] = None,
925
+ ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
926
+ output_type: Optional[str] = "pil",
927
+ return_dict: bool = True,
928
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
929
+ controlnet_conditioning_scale: Union[float, List[float]] = 0.8,
930
+ guess_mode: bool = False,
931
+ control_guidance_start: Union[float, List[float]] = 0.0,
932
+ control_guidance_end: Union[float, List[float]] = 1.0,
933
+ clip_skip: Optional[int] = None,
934
+ callback_on_step_end: Optional[
935
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
936
+ ] = None,
937
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
938
+ **kwargs,
939
+ ):
940
+ r"""
941
+ The call function to the pipeline for generation.
942
+
943
+ Args:
944
+ prompt (`str` or `List[str]`, *optional*):
945
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
946
+ image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
947
+ `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
948
+ The initial image to be used as the starting point for the image generation process. Can also accept
949
+ image latents as `image`, and if passing latents directly they are not encoded again.
950
+ control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
951
+ `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
952
+ The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
953
+ specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted
954
+ as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or
955
+ width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`,
956
+ images must be passed as a list such that each element of the list can be correctly batched for input
957
+ to a single ControlNet.
958
+ height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
959
+ The height in pixels of the generated image.
960
+ width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
961
+ The width in pixels of the generated image.
962
+ strength (`float`, *optional*, defaults to 0.8):
963
+ Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
964
+ starting point and more noise is added the higher the `strength`. The number of denoising steps depends
965
+ on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
966
+ process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
967
+ essentially ignores `image`.
968
+ num_inference_steps (`int`, *optional*, defaults to 50):
969
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
970
+ expense of slower inference.
971
+ guidance_scale (`float`, *optional*, defaults to 7.5):
972
+ A higher guidance scale value encourages the model to generate images closely linked to the text
973
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
974
+ negative_prompt (`str` or `List[str]`, *optional*):
975
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
976
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
977
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
978
+ The number of images to generate per prompt.
979
+ eta (`float`, *optional*, defaults to 0.0):
980
+ Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only
981
+ applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
982
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
983
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
984
+ generation deterministic.
985
+ latents (`torch.Tensor`, *optional*):
986
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
987
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
988
+ tensor is generated by sampling using the supplied random `generator`.
989
+ prompt_embeds (`torch.Tensor`, *optional*):
990
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
991
+ provided, text embeddings are generated from the `prompt` input argument.
992
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
993
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
994
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
995
+ ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
996
+ ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
997
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
998
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
999
+ contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
1000
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
1001
+ output_type (`str`, *optional*, defaults to `"pil"`):
1002
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
1003
+ return_dict (`bool`, *optional*, defaults to `True`):
1004
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
1005
+ plain tuple.
1006
+ cross_attention_kwargs (`dict`, *optional*):
1007
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
1008
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
1009
+ controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
1010
+ The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
1011
+ to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set
1012
+ the corresponding scale as a list.
1013
+ guess_mode (`bool`, *optional*, defaults to `False`):
1014
+ The ControlNet encoder tries to recognize the content of the input image even if you remove all
1015
+ prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended.
1016
+ control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
1017
+ The percentage of total steps at which the ControlNet starts applying.
1018
+ control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
1019
+ The percentage of total steps at which the ControlNet stops applying.
1020
+ clip_skip (`int`, *optional*):
1021
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
1022
+ the output of the pre-final layer will be used for computing the prompt embeddings.
1023
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
1024
+ A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
1025
+ each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
1026
+ DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
1027
+ list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
1028
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
1029
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
1030
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
1031
+ `._callback_tensor_inputs` attribute of your pipeline class.
1032
+
1033
+ Examples:
1034
+
1035
+ Returns:
1036
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
1037
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
1038
+ otherwise a `tuple` is returned where the first element is a list with the generated images and the
1039
+ second element is a list of `bool`s indicating whether the corresponding generated image contains
1040
+ "not-safe-for-work" (nsfw) content.
1041
+ """
1042
+
1043
+ callback = kwargs.pop("callback", None)
1044
+ callback_steps = kwargs.pop("callback_steps", None)
1045
+
1046
+ if callback is not None:
1047
+ deprecate(
1048
+ "callback",
1049
+ "1.0.0",
1050
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
1051
+ )
1052
+ if callback_steps is not None:
1053
+ deprecate(
1054
+ "callback_steps",
1055
+ "1.0.0",
1056
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
1057
+ )
1058
+
1059
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
1060
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
1061
+
1062
+ controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
1063
+
1064
+ # align format for control guidance
1065
+ if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
1066
+ control_guidance_start = len(control_guidance_end) * [control_guidance_start]
1067
+ elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
1068
+ control_guidance_end = len(control_guidance_start) * [control_guidance_end]
1069
+ elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
1070
+ mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
1071
+ control_guidance_start, control_guidance_end = (
1072
+ mult * [control_guidance_start],
1073
+ mult * [control_guidance_end],
1074
+ )
1075
+
1076
+ # 1. Check inputs. Raise error if not correct
1077
+ self.check_inputs(
1078
+ prompt,
1079
+ control_image,
1080
+ callback_steps,
1081
+ negative_prompt,
1082
+ prompt_embeds,
1083
+ negative_prompt_embeds,
1084
+ ip_adapter_image,
1085
+ ip_adapter_image_embeds,
1086
+ controlnet_conditioning_scale,
1087
+ control_guidance_start,
1088
+ control_guidance_end,
1089
+ callback_on_step_end_tensor_inputs,
1090
+ )
1091
+
1092
+ self._guidance_scale = guidance_scale
1093
+ self._clip_skip = clip_skip
1094
+ self._cross_attention_kwargs = cross_attention_kwargs
1095
+ self._interrupt = False
1096
+
1097
+ # 2. Define call parameters
1098
+ if prompt is not None and isinstance(prompt, str):
1099
+ batch_size = 1
1100
+ elif prompt is not None and isinstance(prompt, list):
1101
+ batch_size = len(prompt)
1102
+ else:
1103
+ batch_size = prompt_embeds.shape[0]
1104
+
1105
+ device = self._execution_device
1106
+
1107
+ if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
1108
+ controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
1109
+
1110
+ global_pool_conditions = (
1111
+ controlnet.config.global_pool_conditions
1112
+ if isinstance(controlnet, ControlNetModel)
1113
+ else controlnet.nets[0].config.global_pool_conditions
1114
+ )
1115
+ guess_mode = guess_mode or global_pool_conditions
1116
+
1117
+ # 3. Encode input prompt
1118
+ text_encoder_lora_scale = (
1119
+ self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
1120
+ )
1121
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
1122
+ prompt,
1123
+ device,
1124
+ num_images_per_prompt,
1125
+ self.do_classifier_free_guidance,
1126
+ negative_prompt,
1127
+ prompt_embeds=prompt_embeds,
1128
+ negative_prompt_embeds=negative_prompt_embeds,
1129
+ lora_scale=text_encoder_lora_scale,
1130
+ clip_skip=self.clip_skip,
1131
+ )
1132
+ # For classifier free guidance, we need to do two forward passes.
1133
+ # Here we concatenate the unconditional and text embeddings into a single batch
1134
+ # to avoid doing two forward passes
1135
+ if self.do_classifier_free_guidance:
1136
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
1137
+
1138
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
1139
+ image_embeds = self.prepare_ip_adapter_image_embeds(
1140
+ ip_adapter_image,
1141
+ ip_adapter_image_embeds,
1142
+ device,
1143
+ batch_size * num_images_per_prompt,
1144
+ self.do_classifier_free_guidance,
1145
+ )
1146
+
1147
+ # 4. Prepare image
1148
+ image = self.image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
1149
+
1150
+ # 5. Prepare controlnet_conditioning_image
1151
+ if isinstance(controlnet, ControlNetModel):
1152
+ control_image = self.prepare_control_image(
1153
+ image=control_image,
1154
+ width=width,
1155
+ height=height,
1156
+ batch_size=batch_size * num_images_per_prompt,
1157
+ num_images_per_prompt=num_images_per_prompt,
1158
+ device=device,
1159
+ dtype=controlnet.dtype,
1160
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
1161
+ guess_mode=guess_mode,
1162
+ )
1163
+ elif isinstance(controlnet, MultiControlNetModel):
1164
+ control_images = []
1165
+
1166
+ for control_image_ in control_image:
1167
+ control_image_ = self.prepare_control_image(
1168
+ image=control_image_,
1169
+ width=width,
1170
+ height=height,
1171
+ batch_size=batch_size * num_images_per_prompt,
1172
+ num_images_per_prompt=num_images_per_prompt,
1173
+ device=device,
1174
+ dtype=controlnet.dtype,
1175
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
1176
+ guess_mode=guess_mode,
1177
+ )
1178
+
1179
+ control_images.append(control_image_)
1180
+
1181
+ control_image = control_images
1182
+ else:
1183
+ assert False
1184
+
1185
+ # 5. Prepare timesteps
1186
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
1187
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
1188
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
1189
+ self._num_timesteps = len(timesteps)
1190
+
1191
+ # 6. Prepare latent variables
1192
+ if latents is None:
1193
+ latents = self.prepare_latents(
1194
+ image,
1195
+ latent_timestep,
1196
+ batch_size,
1197
+ num_images_per_prompt,
1198
+ prompt_embeds.dtype,
1199
+ device,
1200
+ generator,
1201
+ )
1202
+
1203
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
1204
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
1205
+
1206
+ # 7.1 Add image embeds for IP-Adapter
1207
+ added_cond_kwargs = (
1208
+ {"image_embeds": image_embeds}
1209
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None
1210
+ else None
1211
+ )
1212
+
1213
+ # 7.2 Create tensor stating which controlnets to keep
1214
+ controlnet_keep = []
1215
+ for i in range(len(timesteps)):
1216
+ keeps = [
1217
+ 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
1218
+ for s, e in zip(control_guidance_start, control_guidance_end)
1219
+ ]
1220
+ controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)
1221
+
1222
+ # 8. Denoising loop
1223
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
1224
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
1225
+ for i, t in enumerate(timesteps):
1226
+ if self.interrupt:
1227
+ continue
1228
+
1229
+ # expand the latents if we are doing classifier free guidance
1230
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
1231
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
1232
+
1233
+ # controlnet(s) inference
1234
+ if guess_mode and self.do_classifier_free_guidance:
1235
+ # Infer ControlNet only for the conditional batch.
1236
+ control_model_input = latents
1237
+ control_model_input = self.scheduler.scale_model_input(control_model_input, t)
1238
+ controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
1239
+ else:
1240
+ control_model_input = latent_model_input
1241
+ controlnet_prompt_embeds = prompt_embeds
1242
+
1243
+ if isinstance(controlnet_keep[i], list):
1244
+ cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
1245
+ else:
1246
+ controlnet_cond_scale = controlnet_conditioning_scale
1247
+ if isinstance(controlnet_cond_scale, list):
1248
+ controlnet_cond_scale = controlnet_cond_scale[0]
1249
+ cond_scale = controlnet_cond_scale * controlnet_keep[i]
1250
+
1251
+ down_block_res_samples, mid_block_res_sample = self.controlnet(
1252
+ control_model_input,
1253
+ t,
1254
+ encoder_hidden_states=controlnet_prompt_embeds,
1255
+ controlnet_cond=control_image,
1256
+ conditioning_scale=cond_scale,
1257
+ guess_mode=guess_mode,
1258
+ return_dict=False,
1259
+ )
1260
+
1261
+ if guess_mode and self.do_classifier_free_guidance:
1262
+ # Inferred ControlNet only for the conditional batch.
1263
+ # To apply the output of ControlNet to both the unconditional and conditional batches,
1264
+ # add 0 to the unconditional batch to keep it unchanged.
1265
+ down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
1266
+ mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
1267
+
1268
+ # predict the noise residual
1269
+ noise_pred = self.unet(
1270
+ latent_model_input,
1271
+ t,
1272
+ encoder_hidden_states=prompt_embeds,
1273
+ cross_attention_kwargs=self.cross_attention_kwargs,
1274
+ down_block_additional_residuals=down_block_res_samples,
1275
+ mid_block_additional_residual=mid_block_res_sample,
1276
+ added_cond_kwargs=added_cond_kwargs,
1277
+ return_dict=False,
1278
+ )[0]
1279
+
1280
+ # perform guidance
1281
+ if self.do_classifier_free_guidance:
1282
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1283
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
1284
+
1285
+ # compute the previous noisy sample x_t -> x_t-1
1286
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
1287
+
1288
+ if callback_on_step_end is not None:
1289
+ callback_kwargs = {}
1290
+ for k in callback_on_step_end_tensor_inputs:
1291
+ callback_kwargs[k] = locals()[k]
1292
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
1293
+
1294
+ latents = callback_outputs.pop("latents", latents)
1295
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1296
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
1297
+ control_image = callback_outputs.pop("control_image", control_image)
1298
+
1299
+ # call the callback, if provided
1300
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1301
+ progress_bar.update()
1302
+ if callback is not None and i % callback_steps == 0:
1303
+ step_idx = i // getattr(self.scheduler, "order", 1)
1304
+ callback(step_idx, t, latents)
1305
+
1306
+ if XLA_AVAILABLE:
1307
+ xm.mark_step()
1308
+
1309
+ # If we do sequential model offloading, let's offload unet and controlnet
1310
+ # manually for max memory savings
1311
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
1312
+ self.unet.to("cpu")
1313
+ self.controlnet.to("cpu")
1314
+ empty_device_cache()
1315
+
1316
+ if not output_type == "latent":
1317
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
1318
+ 0
1319
+ ]
1320
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
1321
+ else:
1322
+ image = latents
1323
+ has_nsfw_concept = None
1324
+
1325
+ if has_nsfw_concept is None:
1326
+ do_denormalize = [True] * image.shape[0]
1327
+ else:
1328
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
1329
+
1330
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
1331
+
1332
+ # Offload all models
1333
+ self.maybe_free_model_hooks()
1334
+
1335
+ if not return_dict:
1336
+ return (image, has_nsfw_concept)
1337
+
1338
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)