xiaoanyu123 committed on
Commit
b4e634b
·
verified ·
1 Parent(s): f45df05

Add files using upload-large-folder tool

Browse files
Files changed (50) hide show
  1. pythonProject/.venv/Lib/site-packages/diffusers/pipelines/marigold/__pycache__/__init__.cpython-310.pyc +0 -0
  2. pythonProject/.venv/Lib/site-packages/diffusers/pipelines/mochi/__init__.py +48 -0
  3. pythonProject/.venv/Lib/site-packages/diffusers/pipelines/mochi/__pycache__/__init__.cpython-310.pyc +0 -0
  4. pythonProject/.venv/Lib/site-packages/diffusers/pipelines/mochi/__pycache__/pipeline_mochi.cpython-310.pyc +0 -0
  5. pythonProject/.venv/Lib/site-packages/diffusers/pipelines/mochi/__pycache__/pipeline_output.cpython-310.pyc +0 -0
  6. pythonProject/.venv/Lib/site-packages/diffusers/pipelines/mochi/pipeline_mochi.py +745 -0
  7. pythonProject/.venv/Lib/site-packages/diffusers/pipelines/mochi/pipeline_output.py +20 -0
  8. pythonProject/.venv/Lib/site-packages/diffusers/pipelines/musicldm/__init__.py +49 -0
  9. pythonProject/.venv/Lib/site-packages/diffusers/pipelines/musicldm/__pycache__/__init__.cpython-310.pyc +0 -0
  10. pythonProject/.venv/Lib/site-packages/diffusers/pipelines/musicldm/__pycache__/pipeline_musicldm.cpython-310.pyc +0 -0
  11. pythonProject/.venv/Lib/site-packages/diffusers/pipelines/musicldm/pipeline_musicldm.py +653 -0
  12. pythonProject/.venv/Lib/site-packages/diffusers/pipelines/omnigen/__init__.py +50 -0
  13. pythonProject/.venv/Lib/site-packages/diffusers/pipelines/omnigen/__pycache__/__init__.cpython-310.pyc +0 -0
  14. pythonProject/.venv/Lib/site-packages/diffusers/pipelines/omnigen/__pycache__/pipeline_omnigen.cpython-310.pyc +0 -0
  15. pythonProject/.venv/Lib/site-packages/diffusers/pipelines/omnigen/__pycache__/processor_omnigen.cpython-310.pyc +0 -0
  16. pythonProject/.venv/Lib/site-packages/diffusers/pipelines/omnigen/pipeline_omnigen.py +514 -0
  17. pythonProject/.venv/Lib/site-packages/diffusers/pipelines/omnigen/processor_omnigen.py +332 -0
  18. pythonProject/.venv/Lib/site-packages/diffusers/pipelines/pag/pag_utils.py +243 -0
  19. pythonProject/.venv/Lib/site-packages/diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1343 -0
  20. pythonProject/.venv/Lib/site-packages/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1554 -0
  21. pythonProject/.venv/Lib/site-packages/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1631 -0
  22. pythonProject/.venv/Lib/site-packages/diffusers/pipelines/stable_diffusion/__pycache__/__init__.cpython-310.pyc +0 -0
  23. pythonProject/.venv/Lib/site-packages/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_onnx_stable_diffusion.cpython-310.pyc +0 -0
  24. pythonProject/.venv/Lib/site-packages/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_onnx_stable_diffusion_img2img.cpython-310.pyc +0 -0
  25. pythonProject/.venv/Lib/site-packages/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_onnx_stable_diffusion_inpaint.cpython-310.pyc +0 -0
  26. pythonProject/.venv/Lib/site-packages/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_onnx_stable_diffusion_upscale.cpython-310.pyc +0 -0
  27. pythonProject/.venv/Lib/site-packages/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_output.cpython-310.pyc +0 -0
  28. pythonProject/.venv/Lib/site-packages/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion.cpython-310.pyc +0 -0
  29. pythonProject/.venv/Lib/site-packages/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_depth2img.cpython-310.pyc +0 -0
  30. pythonProject/.venv/Lib/site-packages/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_image_variation.cpython-310.pyc +0 -0
  31. pythonProject/.venv/Lib/site-packages/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_img2img.cpython-310.pyc +0 -0
  32. pythonProject/.venv/Lib/site-packages/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_inpaint.cpython-310.pyc +0 -0
  33. pythonProject/.venv/Lib/site-packages/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_instruct_pix2pix.cpython-310.pyc +0 -0
  34. pythonProject/.venv/Lib/site-packages/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_latent_upscale.cpython-310.pyc +0 -0
  35. pythonProject/.venv/Lib/site-packages/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_upscale.cpython-310.pyc +0 -0
  36. pythonProject/.venv/Lib/site-packages/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_unclip.cpython-310.pyc +0 -0
  37. pythonProject/.venv/Lib/site-packages/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_unclip_img2img.cpython-310.pyc +0 -0
  38. pythonProject/.venv/Lib/site-packages/diffusers/pipelines/stable_diffusion/__pycache__/safety_checker.cpython-310.pyc +0 -0
  39. pythonProject/.venv/Lib/site-packages/diffusers/pipelines/stable_diffusion/__pycache__/safety_checker_flax.cpython-310.pyc +0 -0
  40. pythonProject/.venv/Lib/site-packages/diffusers/pipelines/stable_diffusion/__pycache__/stable_unclip_image_normalizer.cpython-310.pyc +0 -0
  41. pythonProject/.venv/Lib/site-packages/diffusers/pipelines/stable_diffusion_3/__init__.py +54 -0
  42. pythonProject/.venv/Lib/site-packages/diffusers/pipelines/stable_diffusion_3/__pycache__/__init__.cpython-310.pyc +0 -0
  43. pythonProject/.venv/Lib/site-packages/diffusers/pipelines/stable_diffusion_3/__pycache__/pipeline_output.cpython-310.pyc +0 -0
  44. pythonProject/.venv/Lib/site-packages/diffusers/pipelines/stable_diffusion_3/__pycache__/pipeline_stable_diffusion_3.cpython-310.pyc +0 -0
  45. pythonProject/.venv/Lib/site-packages/diffusers/pipelines/stable_diffusion_3/__pycache__/pipeline_stable_diffusion_3_img2img.cpython-310.pyc +0 -0
  46. pythonProject/.venv/Lib/site-packages/diffusers/pipelines/stable_diffusion_3/__pycache__/pipeline_stable_diffusion_3_inpaint.cpython-310.pyc +0 -0
  47. pythonProject/.venv/Lib/site-packages/diffusers/pipelines/stable_diffusion_3/pipeline_output.py +21 -0
  48. pythonProject/.venv/Lib/site-packages/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +1140 -0
  49. pythonProject/.venv/Lib/site-packages/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +1154 -0
  50. pythonProject/.venv/Lib/site-packages/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +1379 -0
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/marigold/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (1.45 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/mochi/__init__.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from typing import TYPE_CHECKING

from ...utils import (
    DIFFUSERS_SLOW_IMPORT,
    OptionalDependencyNotAvailable,
    _LazyModule,
    get_objects_from_module,
    is_torch_available,
    is_transformers_available,
)


# Placeholder objects substituted for the real classes when optional deps are missing.
_dummy_objects = {}
# Maps submodule name -> list of public names; consumed by `_LazyModule` below.
_import_structure = {}


try:
    if not (is_transformers_available() and is_torch_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ...utils import dummy_torch_and_transformers_objects  # noqa F403

    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
    _import_structure["pipeline_mochi"] = ["MochiPipeline"]

# Import eagerly for type checkers or when slow imports are explicitly requested;
# otherwise install a lazy module so importing the package stays cheap.
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    try:
        if not (is_transformers_available() and is_torch_available()):
            raise OptionalDependencyNotAvailable()

    except OptionalDependencyNotAvailable:
        from ...utils.dummy_torch_and_transformers_objects import *
    else:
        from .pipeline_mochi import MochiPipeline

else:
    import sys

    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )

    # Expose the dummy objects on the lazy module so attribute access still works
    # (and raises a helpful error) when torch/transformers are unavailable.
    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/mochi/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (1.03 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/mochi/__pycache__/pipeline_mochi.cpython-310.pyc ADDED
Binary file (25.4 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/mochi/__pycache__/pipeline_output.cpython-310.pyc ADDED
Binary file (978 Bytes). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/mochi/pipeline_mochi.py ADDED
@@ -0,0 +1,745 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 Genmo and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ from typing import Any, Callable, Dict, List, Optional, Union
17
+
18
+ import numpy as np
19
+ import torch
20
+ from transformers import T5EncoderModel, T5TokenizerFast
21
+
22
+ from ...callbacks import MultiPipelineCallbacks, PipelineCallback
23
+ from ...loaders import Mochi1LoraLoaderMixin
24
+ from ...models import AutoencoderKLMochi, MochiTransformer3DModel
25
+ from ...schedulers import FlowMatchEulerDiscreteScheduler
26
+ from ...utils import (
27
+ is_torch_xla_available,
28
+ logging,
29
+ replace_example_docstring,
30
+ )
31
+ from ...utils.torch_utils import randn_tensor
32
+ from ...video_processor import VideoProcessor
33
+ from ..pipeline_utils import DiffusionPipeline
34
+ from .pipeline_output import MochiPipelineOutput
35
+
36
+
37
+ if is_torch_xla_available():
38
+ import torch_xla.core.xla_model as xm
39
+
40
+ XLA_AVAILABLE = True
41
+ else:
42
+ XLA_AVAILABLE = False
43
+
44
+
45
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
46
+
47
+ EXAMPLE_DOC_STRING = """
48
+ Examples:
49
+ ```py
50
+ >>> import torch
51
+ >>> from diffusers import MochiPipeline
52
+ >>> from diffusers.utils import export_to_video
53
+
54
+ >>> pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview", torch_dtype=torch.bfloat16)
55
+ >>> pipe.enable_model_cpu_offload()
56
+ >>> pipe.enable_vae_tiling()
57
+ >>> prompt = "Close-up of a chameleon's eye, with its scaly skin changing color. Ultra high resolution 4k."
58
+ >>> frames = pipe(prompt, num_inference_steps=28, guidance_scale=3.5).frames[0]
59
+ >>> export_to_video(frames, "mochi.mp4")
60
+ ```
61
+ """
62
+
63
+
64
# from: https://github.com/genmoai/models/blob/075b6e36db58f1242921deff83a1066887b9c9e1/src/mochi_preview/infer.py#L77
def linear_quadratic_schedule(num_steps, threshold_noise, linear_steps=None):
    """Build the Mochi sigma schedule: a linear ramp followed by a quadratic tail.

    The first `linear_steps` sigmas rise linearly from 0 toward `threshold_noise`;
    the remaining steps follow a quadratic whose coefficients are chosen so the two
    segments join continuously at `threshold_noise`. The whole schedule is then
    flipped (`1 - x`) so it is returned in descending-noise order, starting at 1.0.

    Args:
        num_steps: Total number of sigmas to produce.
        threshold_noise: Sigma value where the linear segment hands over to the
            quadratic one.
        linear_steps: Length of the linear segment; defaults to `num_steps // 2`.

    Returns:
        A list of `num_steps` floats in descending order.
    """
    if num_steps < 1:
        return []
    if linear_steps is None:
        linear_steps = num_steps // 2
    # Degenerate splits (no quadratic segment, or no linear segment) would divide
    # by zero in the coefficient math below — e.g. `linear_steps == num_steps`, or
    # `num_steps == 1` where the default makes `linear_steps == 0`. Fall back to a
    # purely linear ramp in those cases; normal inputs are unaffected.
    if linear_steps <= 0 or linear_steps >= num_steps:
        return [1.0 - i * threshold_noise / num_steps for i in range(num_steps)]
    linear_sigma_schedule = [i * threshold_noise / linear_steps for i in range(linear_steps)]
    threshold_noise_step_diff = linear_steps - threshold_noise * num_steps
    quadratic_steps = num_steps - linear_steps
    quadratic_coef = threshold_noise_step_diff / (linear_steps * quadratic_steps**2)
    linear_coef = threshold_noise / linear_steps - 2 * threshold_noise_step_diff / (quadratic_steps**2)
    const = quadratic_coef * (linear_steps**2)
    quadratic_sigma_schedule = [
        quadratic_coef * (i**2) + linear_coef * i + const for i in range(linear_steps, num_steps)
    ]
    sigma_schedule = linear_sigma_schedule + quadratic_sigma_schedule
    # Flip so the schedule starts at full noise (1.0) and decreases.
    sigma_schedule = [1.0 - x for x in sigma_schedule]
    return sigma_schedule
80
+
81
+
82
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    r"""
    Configure `scheduler` and return its resulting timestep schedule.

    Calls the scheduler's `set_timesteps` method and reads back `scheduler.timesteps` afterwards. Supports custom
    timestep or sigma schedules; any extra kwargs are forwarded to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    if timesteps is not None:
        # Custom timestep schedules only work with schedulers whose `set_timesteps` accepts them.
        if "timesteps" not in inspect.signature(scheduler.set_timesteps).parameters:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    elif sigmas is not None:
        # Same gate for custom sigma schedules.
        if "sigmas" not in inspect.signature(scheduler.set_timesteps).parameters:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
    else:
        # Default path: the caller-provided step count is authoritative.
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    # For custom schedules the effective step count is whatever the scheduler produced.
    schedule = scheduler.timesteps
    return schedule, len(schedule)
140
+
141
+
142
class MochiPipeline(DiffusionPipeline, Mochi1LoraLoaderMixin):
    r"""
    The mochi pipeline for text-to-video generation.

    Reference: https://github.com/genmoai/models

    Args:
        transformer ([`MochiTransformer3DModel`]):
            Conditional Transformer architecture to denoise the encoded video latents.
        scheduler ([`FlowMatchEulerDiscreteScheduler`]):
            A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
        vae ([`AutoencoderKLMochi`]):
            Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
        text_encoder ([`T5EncoderModel`]):
            [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
            the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
        tokenizer (`T5TokenizerFast`):
            Tokenizer of class
            [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
    """

    # Offload order used by `enable_model_cpu_offload` (matches the order the
    # components are needed during a generation pass).
    model_cpu_offload_seq = "text_encoder->transformer->vae"
    # Every component is required for this pipeline.
    _optional_components = []
    # Tensors that step-end callbacks are allowed to read and override.
    _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
169
+
170
    def __init__(
        self,
        scheduler: FlowMatchEulerDiscreteScheduler,
        vae: AutoencoderKLMochi,
        text_encoder: T5EncoderModel,
        tokenizer: T5TokenizerFast,
        transformer: MochiTransformer3DModel,
        force_zeros_for_empty_prompt: bool = False,
    ):
        """Assemble the pipeline from its sub-models and register its configuration.

        Args:
            scheduler: Flow-matching scheduler used during denoising.
            vae: Video VAE used to decode latents into frames.
            text_encoder: T5 encoder producing prompt embeddings.
            tokenizer: T5 tokenizer matching `text_encoder`.
            transformer: The 3D denoising transformer.
            force_zeros_for_empty_prompt: When True, empty (negative) prompts are
                encoded as all-zero ids/masks instead of being run through the
                text encoder (see `_get_t5_prompt_embeds`).
        """
        super().__init__()

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            transformer=transformer,
            scheduler=scheduler,
        )
        # TODO: determine these scaling factors from model parameters
        self.vae_spatial_scale_factor = 8
        self.vae_temporal_scale_factor = 6
        self.patch_size = 2

        self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_scale_factor)
        # Fall back to 256 tokens when no tokenizer is configured.
        self.tokenizer_max_length = (
            self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 256
        )
        # Default output resolution (pixels) when the caller passes no height/width.
        self.default_height = 480
        self.default_width = 848
        self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
200
+
201
    def _get_t5_prompt_embeds(
        self,
        prompt: Union[str, List[str]] = None,
        num_videos_per_prompt: int = 1,
        max_sequence_length: int = 256,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        """Tokenize `prompt` and encode it with the T5 text encoder.

        Returns a `(prompt_embeds, prompt_attention_mask)` tuple; both are repeated
        `num_videos_per_prompt` times along the batch dimension.
        """
        device = device or self._execution_device
        dtype = dtype or self.text_encoder.dtype

        # Normalize to a list so batch handling below is uniform.
        prompt = [prompt] if isinstance(prompt, str) else prompt
        batch_size = len(prompt)

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            add_special_tokens=True,
            return_tensors="pt",
        )

        text_input_ids = text_inputs.input_ids
        prompt_attention_mask = text_inputs.attention_mask
        prompt_attention_mask = prompt_attention_mask.bool().to(device)

        # The original Mochi implementation zeros out empty negative prompts
        # but this can lead to overflow when placing the entire pipeline under the autocast context
        # adding this here so that we can enable zeroing prompts if necessary
        # NOTE(review): `prompt` is a list by this point, so `prompt == ""` is always False;
        # only the last entry is effectively checked via `prompt[-1] == ""` — confirm intended.
        if self.config.force_zeros_for_empty_prompt and (prompt == "" or prompt[-1] == ""):
            text_input_ids = torch.zeros_like(text_input_ids, device=device)
            prompt_attention_mask = torch.zeros_like(prompt_attention_mask, dtype=torch.bool, device=device)

        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

        # Warn when truncation at `max_sequence_length` dropped part of the prompt.
        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because `max_sequence_length` is set to "
                f" {max_sequence_length} tokens: {removed_text}"
            )

        prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)[0]
        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

        # duplicate text embeddings for each generation per prompt, using mps friendly method
        _, seq_len, _ = prompt_embeds.shape
        prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)

        prompt_attention_mask = prompt_attention_mask.view(batch_size, -1)
        prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1)

        return prompt_embeds, prompt_attention_mask
256
+
257
    # Adapted from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.encode_prompt
    def encode_prompt(
        self,
        prompt: Union[str, List[str]],
        negative_prompt: Optional[Union[str, List[str]]] = None,
        do_classifier_free_guidance: bool = True,
        num_videos_per_prompt: int = 1,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        prompt_attention_mask: Optional[torch.Tensor] = None,
        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
        max_sequence_length: int = 256,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
                Whether to use classifier free guidance or not.
            num_videos_per_prompt (`int`, *optional*, defaults to 1):
                Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            device: (`torch.device`, *optional*):
                torch device
            dtype: (`torch.dtype`, *optional*):
                torch dtype

        Returns:
            A 4-tuple `(prompt_embeds, prompt_attention_mask, negative_prompt_embeds,
            negative_prompt_attention_mask)`; the negative entries are left as passed in (possibly `None`) when
            `do_classifier_free_guidance` is False.
        """
        device = device or self._execution_device

        prompt = [prompt] if isinstance(prompt, str) else prompt
        if prompt is not None:
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        # Only encode the positive prompt when the caller did not supply embeddings.
        if prompt_embeds is None:
            prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds(
                prompt=prompt,
                num_videos_per_prompt=num_videos_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
                dtype=dtype,
            )

        if do_classifier_free_guidance and negative_prompt_embeds is None:
            # Empty-string negative prompt is the default; broadcast a scalar to the batch.
            negative_prompt = negative_prompt or ""
            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt

            if prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )

            negative_prompt_embeds, negative_prompt_attention_mask = self._get_t5_prompt_embeds(
                prompt=negative_prompt,
                num_videos_per_prompt=num_videos_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
                dtype=dtype,
            )

        return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
340
+
341
+ def check_inputs(
342
+ self,
343
+ prompt,
344
+ height,
345
+ width,
346
+ callback_on_step_end_tensor_inputs=None,
347
+ prompt_embeds=None,
348
+ negative_prompt_embeds=None,
349
+ prompt_attention_mask=None,
350
+ negative_prompt_attention_mask=None,
351
+ ):
352
+ if height % 8 != 0 or width % 8 != 0:
353
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
354
+
355
+ if callback_on_step_end_tensor_inputs is not None and not all(
356
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
357
+ ):
358
+ raise ValueError(
359
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
360
+ )
361
+
362
+ if prompt is not None and prompt_embeds is not None:
363
+ raise ValueError(
364
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
365
+ " only forward one of the two."
366
+ )
367
+ elif prompt is None and prompt_embeds is None:
368
+ raise ValueError(
369
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
370
+ )
371
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
372
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
373
+
374
+ if prompt_embeds is not None and prompt_attention_mask is None:
375
+ raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
376
+
377
+ if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
378
+ raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
379
+
380
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
381
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
382
+ raise ValueError(
383
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
384
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
385
+ f" {negative_prompt_embeds.shape}."
386
+ )
387
+ if prompt_attention_mask.shape != negative_prompt_attention_mask.shape:
388
+ raise ValueError(
389
+ "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but"
390
+ f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`"
391
+ f" {negative_prompt_attention_mask.shape}."
392
+ )
393
+
394
    def enable_vae_slicing(self):
        r"""
        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.

        Delegates to `self.vae.enable_slicing()`.
        """
        self.vae.enable_slicing()
400
+
401
    def disable_vae_slicing(self):
        r"""
        Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
        computing decoding in one step.

        Delegates to `self.vae.disable_slicing()`.
        """
        self.vae.disable_slicing()
407
+
408
    def enable_vae_tiling(self):
        r"""
        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
        processing larger images.

        Delegates to `self.vae.enable_tiling()`.
        """
        self.vae.enable_tiling()
415
+
416
    def disable_vae_tiling(self):
        r"""
        Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
        computing decoding in one step.

        Delegates to `self.vae.disable_tiling()`.
        """
        self.vae.disable_tiling()
422
+
423
+ def prepare_latents(
424
+ self,
425
+ batch_size,
426
+ num_channels_latents,
427
+ height,
428
+ width,
429
+ num_frames,
430
+ dtype,
431
+ device,
432
+ generator,
433
+ latents=None,
434
+ ):
435
+ height = height // self.vae_spatial_scale_factor
436
+ width = width // self.vae_spatial_scale_factor
437
+ num_frames = (num_frames - 1) // self.vae_temporal_scale_factor + 1
438
+
439
+ shape = (batch_size, num_channels_latents, num_frames, height, width)
440
+
441
+ if latents is not None:
442
+ return latents.to(device=device, dtype=dtype)
443
+ if isinstance(generator, list) and len(generator) != batch_size:
444
+ raise ValueError(
445
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
446
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
447
+ )
448
+
449
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=torch.float32)
450
+ latents = latents.to(dtype)
451
+ return latents
452
+
453
    @property
    def guidance_scale(self):
        # Classifier-free guidance scale (presumably set by `__call__` — its body is
        # defined further down in this file).
        return self._guidance_scale

    @property
    def do_classifier_free_guidance(self):
        # CFG is active only for scales strictly greater than 1.
        return self._guidance_scale > 1.0

    @property
    def num_timesteps(self):
        # Number of denoising steps in the current/last run.
        return self._num_timesteps

    @property
    def attention_kwargs(self):
        # Extra kwargs forwarded to attention processors.
        return self._attention_kwargs

    @property
    def current_timestep(self):
        # Timestep currently being denoised.
        return self._current_timestep

    @property
    def interrupt(self):
        # Flag a callback can set to abort the denoising loop early.
        return self._interrupt
476
+
477
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    prompt: Union[str, List[str]] = None,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_frames: int = 19,
    num_inference_steps: int = 64,
    timesteps: List[int] = None,
    guidance_scale: float = 4.5,
    num_videos_per_prompt: Optional[int] = 1,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.Tensor] = None,
    prompt_embeds: Optional[torch.Tensor] = None,
    prompt_attention_mask: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_attention_mask: Optional[torch.Tensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    attention_kwargs: Optional[Dict[str, Any]] = None,
    callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
    callback_on_step_end_tensor_inputs: List[str] = ["latents"],
    max_sequence_length: int = 256,
):
    r"""
    Function invoked when calling the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
            instead.
        height (`int`, *optional*, defaults to `self.default_height`):
            The height in pixels of the generated image. This is set to 480 by default for the best results.
        width (`int`, *optional*, defaults to `self.default_width`):
            The width in pixels of the generated image. This is set to 848 by default for the best results.
        num_frames (`int`, defaults to `19`):
            The number of video frames to generate
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        timesteps (`List[int]`, *optional*):
            Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
            in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
            passed will be used. Must be in descending order.
        guidance_scale (`float`, defaults to `4.5`):
            Guidance scale as defined in [Classifier-Free Diffusion
            Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
            of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
            `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
            the text `prompt`, usually at the expense of lower image quality.
        num_videos_per_prompt (`int`, *optional*, defaults to 1):
            The number of videos to generate per prompt.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
            to make generation deterministic.
        latents (`torch.Tensor`, *optional*):
            Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
            generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
            tensor will be generated by sampling using the supplied random `generator`.
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        prompt_attention_mask (`torch.Tensor`, *optional*):
            Pre-generated attention mask for text embeddings.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not
            provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
        negative_prompt_attention_mask (`torch.FloatTensor`, *optional*):
            Pre-generated attention mask for negative text embeddings.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generate image. Choose between
            [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.mochi.MochiPipelineOutput`] instead of a plain tuple.
        attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
            `self.processor` in
            [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
        callback_on_step_end (`Callable`, *optional*):
            A function that calls at the end of each denoising steps during the inference. The function is called
            with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
            callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
            `callback_on_step_end_tensor_inputs`.
        callback_on_step_end_tensor_inputs (`List`, *optional*):
            The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
            will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
            `._callback_tensor_inputs` attribute of your pipeline class.
        max_sequence_length (`int` defaults to `256`):
            Maximum sequence length to use with the `prompt`.

    Examples:

    Returns:
        [`~pipelines.mochi.MochiPipelineOutput`] or `tuple`:
            If `return_dict` is `True`, [`~pipelines.mochi.MochiPipelineOutput`] is returned, otherwise a `tuple`
            is returned where the first element is a list with the generated images.
    """

    # Pipeline-style callbacks carry their own list of tensor inputs.
    if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
        callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

    height = height or self.default_height
    width = width or self.default_width

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt=prompt,
        height=height,
        width=width,
        callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        prompt_attention_mask=prompt_attention_mask,
        negative_prompt_attention_mask=negative_prompt_attention_mask,
    )

    self._guidance_scale = guidance_scale
    self._attention_kwargs = attention_kwargs
    self._current_timestep = None
    self._interrupt = False

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device

    # 3. Prepare text embeddings
    (
        prompt_embeds,
        prompt_attention_mask,
        negative_prompt_embeds,
        negative_prompt_attention_mask,
    ) = self.encode_prompt(
        prompt=prompt,
        negative_prompt=negative_prompt,
        do_classifier_free_guidance=self.do_classifier_free_guidance,
        num_videos_per_prompt=num_videos_per_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        prompt_attention_mask=prompt_attention_mask,
        negative_prompt_attention_mask=negative_prompt_attention_mask,
        max_sequence_length=max_sequence_length,
        device=device,
    )

    # 4. Prepare latent variables
    num_channels_latents = self.transformer.config.in_channels
    latents = self.prepare_latents(
        batch_size * num_videos_per_prompt,
        num_channels_latents,
        height,
        width,
        num_frames,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

    if self.do_classifier_free_guidance:
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
        prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)

    # 5. Prepare timestep
    # from https://github.com/genmoai/models/blob/075b6e36db58f1242921deff83a1066887b9c9e1/src/mochi_preview/infer.py#L77
    threshold_noise = 0.025
    sigmas = linear_quadratic_schedule(num_inference_steps, threshold_noise)
    sigmas = np.array(sigmas)

    timesteps, num_inference_steps = retrieve_timesteps(
        self.scheduler,
        num_inference_steps,
        device,
        timesteps,
        sigmas,
    )
    num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
    self._num_timesteps = len(timesteps)

    # 6. Denoising loop
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            if self.interrupt:
                continue

            # Note: Mochi uses reversed timesteps. To ensure compatibility with methods like FasterCache, we need
            # to make sure we're using the correct non-reversed timestep values.
            self._current_timestep = 1000 - t
            latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
            # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
            timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype)

            with self.transformer.cache_context("cond_uncond"):
                noise_pred = self.transformer(
                    hidden_states=latent_model_input,
                    encoder_hidden_states=prompt_embeds,
                    timestep=timestep,
                    encoder_attention_mask=prompt_attention_mask,
                    attention_kwargs=attention_kwargs,
                    return_dict=False,
                )[0]
            # Mochi CFG + Sampling runs in FP32
            noise_pred = noise_pred.to(torch.float32)

            if self.do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents_dtype = latents.dtype
            latents = self.scheduler.step(noise_pred, t, latents.to(torch.float32), return_dict=False)[0]
            latents = latents.to(latents_dtype)

            if latents.dtype != latents_dtype:
                if torch.backends.mps.is_available():
                    # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                    latents = latents.to(latents_dtype)

            if callback_on_step_end is not None:
                callback_kwargs = {}
                for k in callback_on_step_end_tensor_inputs:
                    callback_kwargs[k] = locals()[k]
                callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                latents = callback_outputs.pop("latents", latents)
                prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)

            # call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()

            if XLA_AVAILABLE:
                xm.mark_step()

    self._current_timestep = None

    if output_type == "latent":
        video = latents
    else:
        # unscale/denormalize the latents
        # denormalize with the mean and std if available and not None
        has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
        has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
        if has_latents_mean and has_latents_std:
            latents_mean = (
                torch.tensor(self.vae.config.latents_mean).view(1, 12, 1, 1, 1).to(latents.device, latents.dtype)
            )
            latents_std = (
                torch.tensor(self.vae.config.latents_std).view(1, 12, 1, 1, 1).to(latents.device, latents.dtype)
            )
            latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
        else:
            latents = latents / self.vae.config.scaling_factor

        video = self.vae.decode(latents, return_dict=False)[0]
        video = self.video_processor.postprocess_video(video, output_type=output_type)

    # Offload all models
    self.maybe_free_model_hooks()

    if not return_dict:
        return (video,)

    return MochiPipelineOutput(frames=video)
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/mochi/pipeline_output.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+
3
+ import torch
4
+
5
+ from diffusers.utils import BaseOutput
6
+
7
+
8
@dataclass
class MochiPipelineOutput(BaseOutput):
    r"""
    Output class for Mochi pipelines.

    Args:
        frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
            List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
            denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
            `(batch_size, num_frames, channels, height, width)`.
    """

    # Generated video frames; concrete type depends on the requested `output_type`.
    frames: torch.Tensor
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/musicldm/__init__.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from typing import TYPE_CHECKING

from ...utils import (
    DIFFUSERS_SLOW_IMPORT,
    OptionalDependencyNotAvailable,
    _LazyModule,
    get_objects_from_module,
    is_torch_available,
    is_transformers_available,
    is_transformers_version,
)


# Standard diffusers lazy-import boilerplate: populate the import structure only
# when the optional torch/transformers dependencies are satisfied, otherwise fall
# back to dummy placeholder objects.
_dummy_objects = {}
_import_structure = {}

try:
    if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.27.0")):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ...utils import dummy_torch_and_transformers_objects  # noqa F403

    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
    _import_structure["pipeline_musicldm"] = ["MusicLDMPipeline"]


if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    # Eager import path (type checkers and DIFFUSERS_SLOW_IMPORT=1).
    try:
        if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.27.0")):
            raise OptionalDependencyNotAvailable()

    except OptionalDependencyNotAvailable:
        from ...utils.dummy_torch_and_transformers_objects import *
    else:
        from .pipeline_musicldm import MusicLDMPipeline

else:
    # Lazy import path: replace this module with a _LazyModule proxy.
    import sys

    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )

    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/musicldm/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (1.1 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/musicldm/__pycache__/pipeline_musicldm.cpython-310.pyc ADDED
Binary file (18.8 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/musicldm/pipeline_musicldm.py ADDED
@@ -0,0 +1,653 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ from typing import Any, Callable, Dict, List, Optional, Union
17
+
18
+ import numpy as np
19
+ import torch
20
+ from transformers import (
21
+ ClapFeatureExtractor,
22
+ ClapModel,
23
+ ClapTextModelWithProjection,
24
+ RobertaTokenizer,
25
+ RobertaTokenizerFast,
26
+ SpeechT5HifiGan,
27
+ )
28
+
29
+ from ...models import AutoencoderKL, UNet2DConditionModel
30
+ from ...schedulers import KarrasDiffusionSchedulers
31
+ from ...utils import (
32
+ is_accelerate_available,
33
+ is_accelerate_version,
34
+ is_librosa_available,
35
+ logging,
36
+ replace_example_docstring,
37
+ )
38
+ from ...utils.torch_utils import empty_device_cache, get_device, randn_tensor
39
+ from ..pipeline_utils import AudioPipelineOutput, DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin
40
+
41
+
42
+ if is_librosa_available():
43
+ import librosa
44
+
45
+
46
+ from ...utils import is_torch_xla_available
47
+
48
+
49
+ if is_torch_xla_available():
50
+ import torch_xla.core.xla_model as xm
51
+
52
+ XLA_AVAILABLE = True
53
+ else:
54
+ XLA_AVAILABLE = False
55
+
56
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> from diffusers import MusicLDMPipeline
        >>> import torch
        >>> import scipy

        >>> repo_id = "ucsd-reach/musicldm"
        >>> pipe = MusicLDMPipeline.from_pretrained(repo_id, torch_dtype=torch.float16)
        >>> pipe = pipe.to("cuda")

        >>> prompt = "Techno music with a strong, upbeat tempo and high melodic riffs"
        >>> audio = pipe(prompt, num_inference_steps=10, audio_length_in_s=5.0).audios[0]

        >>> # save the audio sample as a .wav file
        >>> scipy.io.wavfile.write("techno.wav", rate=16000, data=audio)
        ```
"""
77
+
78
+
79
+ class MusicLDMPipeline(DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin):
80
+ _last_supported_version = "0.33.1"
81
+ r"""
82
+ Pipeline for text-to-audio generation using MusicLDM.
83
+
84
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
85
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
86
+
87
+ Args:
88
+ vae ([`AutoencoderKL`]):
89
+ Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
90
+ text_encoder ([`~transformers.ClapModel`]):
91
+ Frozen text-audio embedding model (`ClapTextModel`), specifically the
92
+ [laion/clap-htsat-unfused](https://huggingface.co/laion/clap-htsat-unfused) variant.
93
+ tokenizer ([`PreTrainedTokenizer`]):
94
+ A [`~transformers.RobertaTokenizer`] to tokenize text.
95
+ feature_extractor ([`~transformers.ClapFeatureExtractor`]):
96
+ Feature extractor to compute mel-spectrograms from audio waveforms.
97
+ unet ([`UNet2DConditionModel`]):
98
+ A `UNet2DConditionModel` to denoise the encoded audio latents.
99
+ scheduler ([`SchedulerMixin`]):
100
+ A scheduler to be used in combination with `unet` to denoise the encoded audio latents. Can be one of
101
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
102
+ vocoder ([`~transformers.SpeechT5HifiGan`]):
103
+ Vocoder of class `SpeechT5HifiGan`.
104
+ """
105
+
106
def __init__(
    self,
    vae: AutoencoderKL,
    text_encoder: Union[ClapTextModelWithProjection, ClapModel],
    tokenizer: Union[RobertaTokenizer, RobertaTokenizerFast],
    feature_extractor: Optional[ClapFeatureExtractor],
    unet: UNet2DConditionModel,
    scheduler: KarrasDiffusionSchedulers,
    vocoder: SpeechT5HifiGan,
):
    """Register all sub-models with the pipeline and derive the VAE scale factor."""
    super().__init__()

    self.register_modules(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        feature_extractor=feature_extractor,
        unet=unet,
        scheduler=scheduler,
        vocoder=vocoder,
    )
    # Spatial downsampling factor of the VAE; falls back to 8 when no VAE is set.
    self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
128
+
129
def _encode_prompt(
    self,
    prompt,
    device,
    num_waveforms_per_prompt,
    do_classifier_free_guidance,
    negative_prompt=None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
):
    r"""
    Encodes the prompt into text encoder hidden states.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            prompt to be encoded
        device (`torch.device`):
            torch device
        num_waveforms_per_prompt (`int`):
            number of waveforms that should be generated per prompt
        do_classifier_free_guidance (`bool`):
            whether to use classifier free guidance or not
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the audio generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
            less than `1`).
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
    """
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    if prompt_embeds is None:
        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        attention_mask = text_inputs.attention_mask
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

        # Warn the user when the prompt is longer than the CLAP context window.
        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
            text_input_ids, untruncated_ids
        ):
            removed_text = self.tokenizer.batch_decode(
                untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
            )
            logger.warning(
                "The following part of your input was truncated because CLAP can only handle sequences up to"
                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
            )

        prompt_embeds = self.text_encoder.get_text_features(
            text_input_ids.to(device),
            attention_mask=attention_mask.to(device),
        )

    prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.text_model.dtype, device=device)

    bs_embed, seq_len = prompt_embeds.shape
    # duplicate text embeddings for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_waveforms_per_prompt)
    prompt_embeds = prompt_embeds.view(bs_embed * num_waveforms_per_prompt, seq_len)

    # get unconditional embeddings for classifier free guidance
    if do_classifier_free_guidance and negative_prompt_embeds is None:
        uncond_tokens: List[str]
        if negative_prompt is None:
            uncond_tokens = [""] * batch_size
        elif type(prompt) is not type(negative_prompt):
            raise TypeError(
                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                f" {type(prompt)}."
            )
        elif isinstance(negative_prompt, str):
            uncond_tokens = [negative_prompt]
        elif batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )
        else:
            uncond_tokens = negative_prompt

        max_length = prompt_embeds.shape[1]
        uncond_input = self.tokenizer(
            uncond_tokens,
            padding="max_length",
            max_length=max_length,
            truncation=True,
            return_tensors="pt",
        )

        uncond_input_ids = uncond_input.input_ids.to(device)
        attention_mask = uncond_input.attention_mask.to(device)

        negative_prompt_embeds = self.text_encoder.get_text_features(
            uncond_input_ids,
            attention_mask=attention_mask,
        )

    if do_classifier_free_guidance:
        # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
        seq_len = negative_prompt_embeds.shape[1]

        negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.text_model.dtype, device=device)

        negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_waveforms_per_prompt)
        negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_waveforms_per_prompt, seq_len)

        # For classifier free guidance, we need to do two forward passes.
        # Here we concatenate the unconditional and text embeddings into a single batch
        # to avoid doing two forward passes
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

    return prompt_embeds
261
+
262
# Copied from diffusers.pipelines.audioldm.pipeline_audioldm.AudioLDMPipeline.mel_spectrogram_to_waveform
def mel_spectrogram_to_waveform(self, mel_spectrogram):
    """Run the vocoder on a mel spectrogram and return the waveform on CPU as float32."""
    # Drop a singleton channel dimension (B, 1, T, F) -> (B, T, F) before vocoding.
    if mel_spectrogram.dim() == 4:
        mel_spectrogram = mel_spectrogram.squeeze(1)

    waveform = self.vocoder(mel_spectrogram)
    # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
    waveform = waveform.cpu().float()
    return waveform
271
+
272
# Copied from diffusers.pipelines.audioldm2.pipeline_audioldm2.AudioLDM2Pipeline.score_waveforms
def score_waveforms(self, text, audio, num_waveforms_per_prompt, device, dtype):
    """Rank generated waveforms by CLAP text-audio similarity and keep the best per prompt."""
    if not is_librosa_available():
        logger.info(
            "Automatic scoring of the generated audio waveforms against the input prompt text requires the "
            "`librosa` package to resample the generated waveforms. Returning the audios in the order they were "
            "generated. To enable automatic scoring, install `librosa` with: `pip install librosa`."
        )
        return audio
    inputs = self.tokenizer(text, return_tensors="pt", padding=True)
    # Bring the vocoder-rate audio to the feature extractor's expected sampling rate.
    resampled_audio = librosa.resample(
        audio.numpy(), orig_sr=self.vocoder.config.sampling_rate, target_sr=self.feature_extractor.sampling_rate
    )
    inputs["input_features"] = self.feature_extractor(
        list(resampled_audio), return_tensors="pt", sampling_rate=self.feature_extractor.sampling_rate
    ).input_features.type(dtype)
    inputs = inputs.to(device)

    # compute the audio-text similarity score using the CLAP model
    logits_per_text = self.text_encoder(**inputs).logits_per_text
    # sort by the highest matching generations per prompt
    indices = torch.argsort(logits_per_text, dim=1, descending=True)[:, :num_waveforms_per_prompt]
    audio = torch.index_select(audio, 0, indices.reshape(-1).cpu())
    return audio
296
+
297
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
    # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
    # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
    # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
    # and should be between [0, 1]
    step_params = set(inspect.signature(self.scheduler.step).parameters.keys())
    extra_step_kwargs = {}
    if "eta" in step_params:
        extra_step_kwargs["eta"] = eta

    # check if the scheduler accepts generator
    if "generator" in step_params:
        extra_step_kwargs["generator"] = generator
    return extra_step_kwargs
314
+
315
# Copied from diffusers.pipelines.audioldm.pipeline_audioldm.AudioLDMPipeline.check_inputs
def check_inputs(
    self,
    prompt,
    audio_length_in_s,
    vocoder_upsample_factor,
    callback_steps,
    negative_prompt=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
):
    """Validate the user-facing arguments of `__call__`; raise ValueError on any inconsistency."""
    min_audio_length_in_s = vocoder_upsample_factor * self.vae_scale_factor
    if audio_length_in_s < min_audio_length_in_s:
        raise ValueError(
            f"`audio_length_in_s` has to be a positive value greater than or equal to {min_audio_length_in_s}, but "
            f"is {audio_length_in_s}."
        )

    if self.vocoder.config.model_in_dim % self.vae_scale_factor != 0:
        raise ValueError(
            f"The number of frequency bins in the vocoder's log-mel spectrogram has to be divisible by the "
            f"VAE scale factor, but got {self.vocoder.config.model_in_dim} bins and a scale factor of "
            f"{self.vae_scale_factor}."
        )

    if (callback_steps is None) or (
        callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
    ):
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )

    # Exactly one of `prompt` / `prompt_embeds` must be supplied.
    if prompt is not None and prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    elif prompt is None and prompt_embeds is None:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    if negative_prompt is not None and negative_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if prompt_embeds is not None and negative_prompt_embeds is not None:
        if prompt_embeds.shape != negative_prompt_embeds.shape:
            raise ValueError(
                "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                f" {negative_prompt_embeds.shape}."
            )
374
# Copied from diffusers.pipelines.audioldm.pipeline_audioldm.AudioLDMPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, dtype, device, generator, latents=None):
    """Sample (or reuse) initial latents shaped for the VAE and scale them for the scheduler."""
    shape = (
        batch_size,
        num_channels_latents,
        int(height) // self.vae_scale_factor,
        int(self.vocoder.config.model_in_dim) // self.vae_scale_factor,
    )
    if isinstance(generator, list) and len(generator) != batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    if latents is None:
        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
    else:
        latents = latents.to(device)

    # scale the initial noise by the standard deviation required by the scheduler
    latents = latents * self.scheduler.init_noise_sigma
    return latents
396
+
397
    def enable_model_cpu_offload(self, gpu_id=0):
        r"""
        Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
        to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the accelerator when its
        `forward` method is called, and the model remains in accelerator until the next model runs. Memory savings are
        lower than with `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution
        of the `unet`.

        Args:
            gpu_id (`int`, *optional*, defaults to 0):
                Index of the accelerator device the models are moved onto.
        """
        # cpu_offload_with_hook only exists in accelerate >= 0.17.0.
        if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
            from accelerate import cpu_offload_with_hook
        else:
            raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")

        device_type = get_device()
        device = torch.device(f"{device_type}:{gpu_id}")

        # Everything must start on CPU so the hooks control all device placement.
        if self.device.type != "cpu":
            self.to("cpu", silence_dtype_warnings=True)
            empty_device_cache()  # otherwise we don't see the memory savings (but they probably exist)

        # NOTE(review): the offload order presumably mirrors the order the models are
        # used during a generation pass; `text_encoder` appears twice (its submodules
        # first, then the wrapper) — confirm against the upstream implementation.
        model_sequence = [
            self.text_encoder.text_model,
            self.text_encoder.text_projection,
            self.unet,
            self.vae,
            self.vocoder,
            self.text_encoder,
        ]

        # Chain the hooks: each hook offloads the previous model when the next one runs.
        hook = None
        for cpu_offloaded_model in model_sequence:
            _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)

        # We'll offload the last model manually.
        self.final_offload_hook = hook
432
+
433
    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        audio_length_in_s: Optional[float] = None,
        num_inference_steps: int = 200,
        guidance_scale: float = 2.0,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_waveforms_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
        callback_steps: Optional[int] = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        output_type: Optional[str] = "np",
    ):
        r"""
        The call function to the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide audio generation. If not defined, you need to pass `prompt_embeds`.
            audio_length_in_s (`int`, *optional*, defaults to 10.24):
                The length of the generated audio sample in seconds.
            num_inference_steps (`int`, *optional*, defaults to 200):
                The number of denoising steps. More denoising steps usually lead to a higher quality audio at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 2.0):
                A higher guidance scale value encourages the model to generate audio that is closely linked to the text
                `prompt` at the expense of lower sound quality. Guidance scale is enabled when `guidance_scale > 1`.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide what to not include in audio generation. If not defined, you need to
                pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
            num_waveforms_per_prompt (`int`, *optional*, defaults to 1):
                The number of waveforms to generate per prompt. If `num_waveforms_per_prompt > 1`, the text encoding
                model is a joint text-audio model ([`~transformers.ClapModel`]), and the tokenizer is a
                `[~transformers.ClapProcessor]`, then automatic scoring will be performed between the generated outputs
                and the input text. This scoring ranks the generated waveforms based on their cosine similarity to text
                input in the joint text-audio embedding space.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only
                applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
            latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor is generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
                provided, text embeddings are generated from the `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
                not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.AudioPipelineOutput`] instead of a plain tuple.
            callback (`Callable`, *optional*):
                A function that calls every `callback_steps` steps during inference. The function is called with the
                following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function is called. If not specified, the callback is called at
                every step.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            output_type (`str`, *optional*, defaults to `"np"`):
                The output format of the generated audio. Choose between `"np"` to return a NumPy `np.ndarray` or
                `"pt"` to return a PyTorch `torch.Tensor` object. Set to `"latent"` to return the latent diffusion
                model (LDM) output.

        Examples:

        Returns:
            [`~pipelines.AudioPipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.AudioPipelineOutput`] is returned, otherwise a `tuple` is
                returned where the first element is a list with the generated audio.
        """
        # 0. Convert audio input length from seconds to spectrogram height
        # vocoder_upsample_factor = seconds of audio produced per spectrogram frame.
        vocoder_upsample_factor = np.prod(self.vocoder.config.upsample_rates) / self.vocoder.config.sampling_rate

        if audio_length_in_s is None:
            audio_length_in_s = self.unet.config.sample_size * self.vae_scale_factor * vocoder_upsample_factor

        height = int(audio_length_in_s / vocoder_upsample_factor)

        original_waveform_length = int(audio_length_in_s * self.vocoder.config.sampling_rate)
        # Round height up to a multiple of the VAE scale factor; the excess audio is
        # trimmed back to `original_waveform_length` after decoding.
        if height % self.vae_scale_factor != 0:
            height = int(np.ceil(height / self.vae_scale_factor)) * self.vae_scale_factor
            logger.info(
                f"Audio length in seconds {audio_length_in_s} is increased to {height * vocoder_upsample_factor} "
                f"so that it can be handled by the model. It will be cut to {audio_length_in_s} after the "
                f"denoising process."
            )

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            audio_length_in_s,
            vocoder_upsample_factor,
            callback_steps,
            negative_prompt,
            prompt_embeds,
            negative_prompt_embeds,
        )

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        prompt_embeds = self._encode_prompt(
            prompt,
            device,
            num_waveforms_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
        )

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 5. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_waveforms_per_prompt,
            num_channels_latents,
            height,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 6. Prepare extra step kwargs
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # predict the noise residual
                # NOTE: this pipeline conditions the UNet through `class_labels` (the
                # prompt embedding); `encoder_hidden_states` is deliberately None.
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=None,
                    class_labels=prompt_embeds,
                    cross_attention_kwargs=cross_attention_kwargs,
                    return_dict=False,
                )[0]

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        step_idx = i // getattr(self.scheduler, "order", 1)
                        callback(step_idx, t, latents)

                if XLA_AVAILABLE:
                    xm.mark_step()

        self.maybe_free_model_hooks()

        # 8. Post-processing
        if not output_type == "latent":
            latents = 1 / self.vae.config.scaling_factor * latents
            mel_spectrogram = self.vae.decode(latents).sample
        else:
            # Early return: caller asked for the raw latent diffusion output.
            return AudioPipelineOutput(audios=latents)

        audio = self.mel_spectrogram_to_waveform(mel_spectrogram)

        # Trim the padding introduced when `height` was rounded up in step 0.
        audio = audio[:, :original_waveform_length]

        # 9. Automatic scoring
        # Rank multiple candidate waveforms per prompt by CLAP text-audio similarity.
        if num_waveforms_per_prompt > 1 and prompt is not None:
            audio = self.score_waveforms(
                text=prompt,
                audio=audio,
                num_waveforms_per_prompt=num_waveforms_per_prompt,
                device=device,
                dtype=prompt_embeds.dtype,
            )

        if output_type == "np":
            audio = audio.numpy()

        if not return_dict:
            return (audio,)

        return AudioPipelineOutput(audios=audio)
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/omnigen/__init__.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Standard diffusers lazy-import layout: submodule imports are deferred via
# `_LazyModule` until first attribute access, unless type checking or
# `DIFFUSERS_SLOW_IMPORT` forces eager imports.
from typing import TYPE_CHECKING

from ...utils import (
    DIFFUSERS_SLOW_IMPORT,
    OptionalDependencyNotAvailable,
    _LazyModule,
    get_objects_from_module,
    is_torch_available,
    is_transformers_available,
)


# Populated below; `_dummy_objects` holds placeholder classes that raise a helpful
# error when the optional dependencies (torch + transformers) are missing.
_dummy_objects = {}
_import_structure = {}


try:
    if not (is_transformers_available() and is_torch_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ...utils import dummy_torch_and_transformers_objects  # noqa F403

    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
    # Dependencies present: expose the real pipeline through the lazy module map.
    _import_structure["pipeline_omnigen"] = ["OmniGenPipeline"]


if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    # Eager path: import directly so static type checkers see the real symbols.
    try:
        if not (is_transformers_available() and is_torch_available()):
            raise OptionalDependencyNotAvailable()

    except OptionalDependencyNotAvailable:
        from ...utils.dummy_torch_and_transformers_objects import *
    else:
        from .pipeline_omnigen import OmniGenPipeline


else:
    import sys

    # Lazy path: replace this module object with a _LazyModule proxy.
    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )

    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/omnigen/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (1.03 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/omnigen/__pycache__/pipeline_omnigen.cpython-310.pyc ADDED
Binary file (18.5 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/omnigen/__pycache__/processor_omnigen.cpython-310.pyc ADDED
Binary file (11.3 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/omnigen/pipeline_omnigen.py ADDED
@@ -0,0 +1,514 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 OmniGen team and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ from typing import Callable, Dict, List, Optional, Union
17
+
18
+ import numpy as np
19
+ import torch
20
+ from transformers import LlamaTokenizer
21
+
22
+ from ...image_processor import PipelineImageInput, VaeImageProcessor
23
+ from ...models.autoencoders import AutoencoderKL
24
+ from ...models.transformers import OmniGenTransformer2DModel
25
+ from ...schedulers import FlowMatchEulerDiscreteScheduler
26
+ from ...utils import is_torch_xla_available, is_torchvision_available, logging, replace_example_docstring
27
+ from ...utils.torch_utils import randn_tensor
28
+ from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
29
+
30
+
31
+ if is_torchvision_available():
32
+ from .processor_omnigen import OmniGenMultiModalProcessor
33
+
34
+ if is_torch_xla_available():
35
+ XLA_AVAILABLE = True
36
+ else:
37
+ XLA_AVAILABLE = False
38
+
39
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
40
+
41
+ EXAMPLE_DOC_STRING = """
42
+ Examples:
43
+ ```py
44
+ >>> import torch
45
+ >>> from diffusers import OmniGenPipeline
46
+
47
+ >>> pipe = OmniGenPipeline.from_pretrained("Shitao/OmniGen-v1-diffusers", torch_dtype=torch.bfloat16)
48
+ >>> pipe.to("cuda")
49
+
50
+ >>> prompt = "A cat holding a sign that says hello world"
51
+ >>> # Depending on the variant being used, the pipeline call will slightly vary.
52
+ >>> # Refer to the pipeline documentation for more details.
53
+ >>> image = pipe(prompt, num_inference_steps=50, guidance_scale=2.5).images[0]
54
+ >>> image.save("output.png")
55
+ ```
56
+ """
57
+
58
+
59
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
60
+ def retrieve_timesteps(
61
+ scheduler,
62
+ num_inference_steps: Optional[int] = None,
63
+ device: Optional[Union[str, torch.device]] = None,
64
+ timesteps: Optional[List[int]] = None,
65
+ sigmas: Optional[List[float]] = None,
66
+ **kwargs,
67
+ ):
68
+ r"""
69
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
70
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
71
+
72
+ Args:
73
+ scheduler (`SchedulerMixin`):
74
+ The scheduler to get timesteps from.
75
+ num_inference_steps (`int`):
76
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
77
+ must be `None`.
78
+ device (`str` or `torch.device`, *optional*):
79
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
80
+ timesteps (`List[int]`, *optional*):
81
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
82
+ `num_inference_steps` and `sigmas` must be `None`.
83
+ sigmas (`List[float]`, *optional*):
84
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
85
+ `num_inference_steps` and `timesteps` must be `None`.
86
+
87
+ Returns:
88
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
89
+ second element is the number of inference steps.
90
+ """
91
+ if timesteps is not None and sigmas is not None:
92
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
93
+ if timesteps is not None:
94
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
95
+ if not accepts_timesteps:
96
+ raise ValueError(
97
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
98
+ f" timestep schedules. Please check whether you are using the correct scheduler."
99
+ )
100
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
101
+ timesteps = scheduler.timesteps
102
+ num_inference_steps = len(timesteps)
103
+ elif sigmas is not None:
104
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
105
+ if not accept_sigmas:
106
+ raise ValueError(
107
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
108
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
109
+ )
110
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
111
+ timesteps = scheduler.timesteps
112
+ num_inference_steps = len(timesteps)
113
+ else:
114
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
115
+ timesteps = scheduler.timesteps
116
+ return timesteps, num_inference_steps
117
+
118
+
119
+ class OmniGenPipeline(
120
+ DiffusionPipeline,
121
+ ):
122
+ r"""
123
+ The OmniGen pipeline for multimodal-to-image generation.
124
+
125
+ Reference: https://huggingface.co/papers/2409.11340
126
+
127
+ Args:
128
+ transformer ([`OmniGenTransformer2DModel`]):
129
+ Autoregressive Transformer architecture for OmniGen.
130
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
131
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
132
+ vae ([`AutoencoderKL`]):
133
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
134
+ tokenizer (`LlamaTokenizer`):
135
+ Text tokenizer of class.
136
+ [LlamaTokenizer](https://huggingface.co/docs/transformers/main/model_doc/llama#transformers.LlamaTokenizer).
137
+ """
138
+
139
+ model_cpu_offload_seq = "transformer->vae"
140
+ _optional_components = []
141
+ _callback_tensor_inputs = ["latents"]
142
+
143
+ def __init__(
144
+ self,
145
+ transformer: OmniGenTransformer2DModel,
146
+ scheduler: FlowMatchEulerDiscreteScheduler,
147
+ vae: AutoencoderKL,
148
+ tokenizer: LlamaTokenizer,
149
+ ):
150
+ super().__init__()
151
+
152
+ self.register_modules(
153
+ vae=vae,
154
+ tokenizer=tokenizer,
155
+ transformer=transformer,
156
+ scheduler=scheduler,
157
+ )
158
+ self.vae_scale_factor = (
159
+ 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) is not None else 8
160
+ )
161
+ # OmniGen latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
162
+ # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
163
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
164
+
165
+ self.multimodal_processor = OmniGenMultiModalProcessor(tokenizer, max_image_size=1024)
166
+ self.tokenizer_max_length = (
167
+ self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 120000
168
+ )
169
+ self.default_sample_size = 128
170
+
171
+ def encode_input_images(
172
+ self,
173
+ input_pixel_values: List[torch.Tensor],
174
+ device: Optional[torch.device] = None,
175
+ dtype: Optional[torch.dtype] = None,
176
+ ):
177
+ """
178
+ get the continue embedding of input images by VAE
179
+
180
+ Args:
181
+ input_pixel_values: normalized pixel of input images
182
+ device:
183
+ Returns: torch.Tensor
184
+ """
185
+ device = device or self._execution_device
186
+ dtype = dtype or self.vae.dtype
187
+
188
+ input_img_latents = []
189
+ for img in input_pixel_values:
190
+ img = self.vae.encode(img.to(device, dtype)).latent_dist.sample().mul_(self.vae.config.scaling_factor)
191
+ input_img_latents.append(img)
192
+ return input_img_latents
193
+
194
+ def check_inputs(
195
+ self,
196
+ prompt,
197
+ input_images,
198
+ height,
199
+ width,
200
+ use_input_image_size_as_output,
201
+ callback_on_step_end_tensor_inputs=None,
202
+ ):
203
+ if input_images is not None:
204
+ if len(input_images) != len(prompt):
205
+ raise ValueError(
206
+ f"The number of prompts: {len(prompt)} does not match the number of input images: {len(input_images)}."
207
+ )
208
+ for i in range(len(input_images)):
209
+ if input_images[i] is not None:
210
+ if not all(f"<img><|image_{k + 1}|></img>" in prompt[i] for k in range(len(input_images[i]))):
211
+ raise ValueError(
212
+ f"prompt `{prompt[i]}` doesn't have enough placeholders for the input images `{input_images[i]}`"
213
+ )
214
+
215
+ if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
216
+ logger.warning(
217
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
218
+ )
219
+
220
+ if use_input_image_size_as_output:
221
+ if input_images is None or input_images[0] is None:
222
+ raise ValueError(
223
+ "`use_input_image_size_as_output` is set to True, but no input image was found. If you are performing a text-to-image task, please set it to False."
224
+ )
225
+
226
+ if callback_on_step_end_tensor_inputs is not None and not all(
227
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
228
+ ):
229
+ raise ValueError(
230
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
231
+ )
232
+
233
+ def enable_vae_slicing(self):
234
+ r"""
235
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
236
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
237
+ """
238
+ self.vae.enable_slicing()
239
+
240
+ def disable_vae_slicing(self):
241
+ r"""
242
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
243
+ computing decoding in one step.
244
+ """
245
+ self.vae.disable_slicing()
246
+
247
+ def enable_vae_tiling(self):
248
+ r"""
249
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
250
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
251
+ processing larger images.
252
+ """
253
+ self.vae.enable_tiling()
254
+
255
+ def disable_vae_tiling(self):
256
+ r"""
257
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
258
+ computing decoding in one step.
259
+ """
260
+ self.vae.disable_tiling()
261
+
262
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.prepare_latents
263
+ def prepare_latents(
264
+ self,
265
+ batch_size,
266
+ num_channels_latents,
267
+ height,
268
+ width,
269
+ dtype,
270
+ device,
271
+ generator,
272
+ latents=None,
273
+ ):
274
+ if latents is not None:
275
+ return latents.to(device=device, dtype=dtype)
276
+
277
+ shape = (
278
+ batch_size,
279
+ num_channels_latents,
280
+ int(height) // self.vae_scale_factor,
281
+ int(width) // self.vae_scale_factor,
282
+ )
283
+
284
+ if isinstance(generator, list) and len(generator) != batch_size:
285
+ raise ValueError(
286
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
287
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
288
+ )
289
+
290
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
291
+
292
+ return latents
293
+
294
+ @property
295
+ def guidance_scale(self):
296
+ return self._guidance_scale
297
+
298
+ @property
299
+ def num_timesteps(self):
300
+ return self._num_timesteps
301
+
302
+ @property
303
+ def interrupt(self):
304
+ return self._interrupt
305
+
306
+ @torch.no_grad()
307
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
308
+ def __call__(
309
+ self,
310
+ prompt: Union[str, List[str]],
311
+ input_images: Union[PipelineImageInput, List[PipelineImageInput]] = None,
312
+ height: Optional[int] = None,
313
+ width: Optional[int] = None,
314
+ num_inference_steps: int = 50,
315
+ max_input_image_size: int = 1024,
316
+ timesteps: List[int] = None,
317
+ guidance_scale: float = 2.5,
318
+ img_guidance_scale: float = 1.6,
319
+ use_input_image_size_as_output: bool = False,
320
+ num_images_per_prompt: Optional[int] = 1,
321
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
322
+ latents: Optional[torch.Tensor] = None,
323
+ output_type: Optional[str] = "pil",
324
+ return_dict: bool = True,
325
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
326
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
327
+ ):
328
+ r"""
329
+ Function invoked when calling the pipeline for generation.
330
+
331
+ Args:
332
+ prompt (`str` or `List[str]`, *optional*):
333
+ The prompt or prompts to guide the image generation. If the input includes images, need to add
334
+ placeholders `<img><|image_i|></img>` in the prompt to indicate the position of the i-th images.
335
+ input_images (`PipelineImageInput` or `List[PipelineImageInput]`, *optional*):
336
+ The list of input images. We will replace the "<|image_i|>" in prompt with the i-th image in list.
337
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
338
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
339
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
340
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
341
+ num_inference_steps (`int`, *optional*, defaults to 50):
342
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
343
+ expense of slower inference.
344
+ max_input_image_size (`int`, *optional*, defaults to 1024):
345
+ the maximum size of input image, which will be used to crop the input image to the maximum size
346
+ timesteps (`List[int]`, *optional*):
347
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
348
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
349
+ passed will be used. Must be in descending order.
350
+ guidance_scale (`float`, *optional*, defaults to 2.5):
351
+ Guidance scale as defined in [Classifier-Free Diffusion
352
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
353
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
354
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
355
+ the text `prompt`, usually at the expense of lower image quality.
356
+ img_guidance_scale (`float`, *optional*, defaults to 1.6):
357
+ Defined as equation 3 in [Instrucpix2pix](https://huggingface.co/papers/2211.09800).
358
+ use_input_image_size_as_output (bool, defaults to False):
359
+ whether to use the input image size as the output image size, which can be used for single-image input,
360
+ e.g., image editing task
361
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
362
+ The number of images to generate per prompt.
363
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
364
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
365
+ to make generation deterministic.
366
+ latents (`torch.Tensor`, *optional*):
367
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
368
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
369
+ tensor will be generated by sampling using the supplied random `generator`.
370
+ output_type (`str`, *optional*, defaults to `"pil"`):
371
+ The output format of the generate image. Choose between
372
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
373
+ return_dict (`bool`, *optional*, defaults to `True`):
374
+ Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
375
+ callback_on_step_end (`Callable`, *optional*):
376
+ A function that calls at the end of each denoising steps during the inference. The function is called
377
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
378
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
379
+ `callback_on_step_end_tensor_inputs`.
380
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
381
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
382
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
383
+ `._callback_tensor_inputs` attribute of your pipeline class.
384
+
385
+ Examples:
386
+
387
+ Returns: [`~pipelines.ImagePipelineOutput`] or `tuple`:
388
+ If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is returned
389
+ where the first element is a list with the generated images.
390
+ """
391
+
392
+ height = height or self.default_sample_size * self.vae_scale_factor
393
+ width = width or self.default_sample_size * self.vae_scale_factor
394
+ num_cfg = 2 if input_images is not None else 1
395
+ use_img_cfg = True if input_images is not None else False
396
+ if isinstance(prompt, str):
397
+ prompt = [prompt]
398
+ input_images = [input_images]
399
+
400
+ # 1. Check inputs. Raise error if not correct
401
+ self.check_inputs(
402
+ prompt,
403
+ input_images,
404
+ height,
405
+ width,
406
+ use_input_image_size_as_output,
407
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
408
+ )
409
+
410
+ self._guidance_scale = guidance_scale
411
+ self._interrupt = False
412
+
413
+ # 2. Define call parameters
414
+ batch_size = len(prompt)
415
+ device = self._execution_device
416
+
417
+ # 3. process multi-modal instructions
418
+ if max_input_image_size != self.multimodal_processor.max_image_size:
419
+ self.multimodal_processor.reset_max_image_size(max_image_size=max_input_image_size)
420
+ processed_data = self.multimodal_processor(
421
+ prompt,
422
+ input_images,
423
+ height=height,
424
+ width=width,
425
+ use_img_cfg=use_img_cfg,
426
+ use_input_image_size_as_output=use_input_image_size_as_output,
427
+ num_images_per_prompt=num_images_per_prompt,
428
+ )
429
+ processed_data["input_ids"] = processed_data["input_ids"].to(device)
430
+ processed_data["attention_mask"] = processed_data["attention_mask"].to(device)
431
+ processed_data["position_ids"] = processed_data["position_ids"].to(device)
432
+
433
+ # 4. Encode input images
434
+ input_img_latents = self.encode_input_images(processed_data["input_pixel_values"], device=device)
435
+
436
+ # 5. Prepare timesteps
437
+ sigmas = np.linspace(1, 0, num_inference_steps + 1)[:num_inference_steps]
438
+ timesteps, num_inference_steps = retrieve_timesteps(
439
+ self.scheduler, num_inference_steps, device, timesteps, sigmas=sigmas
440
+ )
441
+ self._num_timesteps = len(timesteps)
442
+
443
+ # 6. Prepare latents
444
+ transformer_dtype = self.transformer.dtype
445
+ if use_input_image_size_as_output:
446
+ height, width = processed_data["input_pixel_values"][0].shape[-2:]
447
+ latent_channels = self.transformer.config.in_channels
448
+ latents = self.prepare_latents(
449
+ batch_size * num_images_per_prompt,
450
+ latent_channels,
451
+ height,
452
+ width,
453
+ torch.float32,
454
+ device,
455
+ generator,
456
+ latents,
457
+ )
458
+
459
+ # 8. Denoising loop
460
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
461
+ for i, t in enumerate(timesteps):
462
+ # expand the latents if we are doing classifier free guidance
463
+ latent_model_input = torch.cat([latents] * (num_cfg + 1))
464
+ latent_model_input = latent_model_input.to(transformer_dtype)
465
+
466
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
467
+ timestep = t.expand(latent_model_input.shape[0])
468
+
469
+ noise_pred = self.transformer(
470
+ hidden_states=latent_model_input,
471
+ timestep=timestep,
472
+ input_ids=processed_data["input_ids"],
473
+ input_img_latents=input_img_latents,
474
+ input_image_sizes=processed_data["input_image_sizes"],
475
+ attention_mask=processed_data["attention_mask"],
476
+ position_ids=processed_data["position_ids"],
477
+ return_dict=False,
478
+ )[0]
479
+
480
+ if num_cfg == 2:
481
+ cond, uncond, img_cond = torch.split(noise_pred, len(noise_pred) // 3, dim=0)
482
+ noise_pred = uncond + img_guidance_scale * (img_cond - uncond) + guidance_scale * (cond - img_cond)
483
+ else:
484
+ cond, uncond = torch.split(noise_pred, len(noise_pred) // 2, dim=0)
485
+ noise_pred = uncond + guidance_scale * (cond - uncond)
486
+
487
+ # compute the previous noisy sample x_t -> x_t-1
488
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
489
+
490
+ if callback_on_step_end is not None:
491
+ callback_kwargs = {}
492
+ for k in callback_on_step_end_tensor_inputs:
493
+ callback_kwargs[k] = locals()[k]
494
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
495
+
496
+ latents = callback_outputs.pop("latents", latents)
497
+
498
+ progress_bar.update()
499
+
500
+ if not output_type == "latent":
501
+ latents = latents.to(self.vae.dtype)
502
+ latents = latents / self.vae.config.scaling_factor
503
+ image = self.vae.decode(latents, return_dict=False)[0]
504
+ image = self.image_processor.postprocess(image, output_type=output_type)
505
+ else:
506
+ image = latents
507
+
508
+ # Offload all models
509
+ self.maybe_free_model_hooks()
510
+
511
+ if not return_dict:
512
+ return (image,)
513
+
514
+ return ImagePipelineOutput(images=image)
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/omnigen/processor_omnigen.py ADDED
@@ -0,0 +1,332 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 OmniGen team and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import re
16
+ from typing import Dict, List
17
+
18
+ import numpy as np
19
+ import torch
20
+ from PIL import Image
21
+
22
+ from ...utils import is_torchvision_available
23
+
24
+
25
+ if is_torchvision_available():
26
+ from torchvision import transforms
27
+
28
+
29
def crop_image(pil_image, max_image_size):
    """
    Crop the image so that its height and width does not exceed `max_image_size`, while ensuring both the height and
    width are multiples of 16.
    """
    # Repeatedly halve (box filter) while both sides are at least twice the limit.
    while min(*pil_image.size) >= 2 * max_image_size:
        halved = tuple(side // 2 for side in pil_image.size)
        pil_image = pil_image.resize(halved, resample=Image.BOX)

    # Scale down so the longest side fits within `max_image_size`.
    longest = max(*pil_image.size)
    if longest > max_image_size:
        factor = max_image_size / longest
        scaled = tuple(round(side * factor) for side in pil_image.size)
        pil_image = pil_image.resize(scaled, resample=Image.BICUBIC)

    # Scale up so the shortest side is at least 16 pixels.
    shortest = min(*pil_image.size)
    if shortest < 16:
        factor = 16 / shortest
        scaled = tuple(round(side * factor) for side in pil_image.size)
        pil_image = pil_image.resize(scaled, resample=Image.BICUBIC)

    # Center-crop the remainder so both dimensions become multiples of 16.
    pixels = np.array(pil_image)
    height, width = pixels.shape[0], pixels.shape[1]
    top = (height % 16) // 2
    bottom = height % 16 - top
    left = (width % 16) // 2
    right = width % 16 - left
    pixels = pixels[top : height - bottom, left : width - right]
    return Image.fromarray(pixels)
54
+
55
+
56
class OmniGenMultiModalProcessor:
    """
    Turns OmniGen's interleaved text + image instructions into model-ready inputs.

    Tokenizes prompts that may contain `<|image_i|>` placeholder tags, transforms the referenced
    images (PIL images or file paths) into normalized tensors, and delegates padding/batching to
    `OmniGenCollator`.

    Args:
        text_tokenizer: Tokenizer used for the text portion of the prompt.
        max_image_size (`int`, defaults to `1024`):
            Longest side allowed for input images before crop/resize.
    """

    def __init__(self, text_tokenizer, max_image_size: int = 1024):
        self.text_tokenizer = text_tokenizer
        self.max_image_size = max_image_size

        # Crop to multiples of 16, convert to tensor, then normalize channels to [-1, 1].
        self.image_transform = transforms.Compose(
            [
                transforms.Lambda(lambda pil_image: crop_image(pil_image, max_image_size)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
            ]
        )

        self.collator = OmniGenCollator()

    def reset_max_image_size(self, max_image_size):
        """Change the image-size limit and rebuild the transform so it takes effect."""
        self.max_image_size = max_image_size
        self.image_transform = transforms.Compose(
            [
                transforms.Lambda(lambda pil_image: crop_image(pil_image, max_image_size)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
            ]
        )

    def process_image(self, image):
        """Load `image` from disk if it is a path string, then apply the crop/normalize transform."""
        if isinstance(image, str):
            image = Image.open(image).convert("RGB")
        return self.image_transform(image)

    def process_multi_modal_prompt(self, text, input_images):
        """
        Tokenize `text` and reserve placeholder token spans where each `<|image_i|>` tag appears.

        Returns a dict with `input_ids` (list of ids), `pixel_values` (image tensors ordered by tag
        occurrence, or None) and `image_sizes` ([start, end) id-span per image, or None).
        """
        text = self.add_prefix_instruction(text)
        if input_images is None or len(input_images) == 0:
            # Text-only prompt: nothing to splice in.
            model_inputs = self.text_tokenizer(text)
            return {"input_ids": model_inputs.input_ids, "pixel_values": None, "image_sizes": None}

        pattern = r"<\|image_\d+\|>"
        # Tokenize the text between image tags chunk by chunk.
        prompt_chunks = [self.text_tokenizer(chunk).input_ids for chunk in re.split(pattern, text)]

        # Drop a leading id of 1 from every non-first chunk — presumably the tokenizer's
        # BOS token re-added per chunk; TODO(review) confirm for the tokenizer in use.
        for i in range(1, len(prompt_chunks)):
            if prompt_chunks[i][0] == 1:
                prompt_chunks[i] = prompt_chunks[i][1:]

        image_tags = re.findall(pattern, text)
        # "<|image_3|>" -> 3
        image_ids = [int(s.split("|")[1].split("_")[-1]) for s in image_tags]

        unique_image_ids = sorted(set(image_ids))
        # NOTE(review): asserts are stripped under `python -O`; validation would normally raise.
        assert unique_image_ids == list(range(1, len(unique_image_ids) + 1)), (
            f"image_ids must start from 1, and must be continuous int, e.g. [1, 2, 3], cannot be {unique_image_ids}"
        )
        # total images must be the same as the number of image tags
        assert len(unique_image_ids) == len(input_images), (
            f"total images must be the same as the number of image tags, got {len(unique_image_ids)} image tags and {len(input_images)} images"
        )

        # Reorder (and possibly repeat) the images to follow tag order in the text.
        input_images = [input_images[x - 1] for x in image_ids]

        all_input_ids = []
        img_inx = []
        for i in range(len(prompt_chunks)):
            all_input_ids.extend(prompt_chunks[i])
            if i != len(prompt_chunks) - 1:
                # Reserve one placeholder id (0) per 16x16 patch of the image, and record its span.
                start_inx = len(all_input_ids)
                size = input_images[i].size(-2) * input_images[i].size(-1) // 16 // 16
                img_inx.append([start_inx, start_inx + size])
                all_input_ids.extend([0] * size)

        return {"input_ids": all_input_ids, "pixel_values": input_images, "image_sizes": img_inx}

    def add_prefix_instruction(self, prompt):
        """Wrap `prompt` in OmniGen's fixed chat template (user turn + diffusion assistant turn)."""
        user_prompt = "<|user|>\n"
        generation_prompt = "Generate an image according to the following instructions\n"
        assistant_prompt = "<|assistant|>\n<|diffusion|>"
        prompt_suffix = "<|end|>\n"
        prompt = f"{user_prompt}{generation_prompt}{prompt}{prompt_suffix}{assistant_prompt}"
        return prompt

    def __call__(
        self,
        instructions: List[str],
        input_images: List[List[str]] = None,
        height: int = 1024,
        width: int = 1024,
        negative_prompt: str = "low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers.",
        use_img_cfg: bool = True,
        separate_cfg_input: bool = False,
        use_input_image_size_as_output: bool = False,
        num_images_per_prompt: int = 1,
    ) -> Dict:
        """
        Build batched model inputs (positive, negative and optional image-CFG branches) for a list
        of instructions, repeated `num_images_per_prompt` times each, and collate them.
        """
        if isinstance(instructions, str):
            instructions = [instructions]
            input_images = [input_images]

        input_data = []
        for i in range(len(instructions)):
            cur_instruction = instructions[i]
            cur_input_images = None if input_images is None else input_images[i]
            if cur_input_images is not None and len(cur_input_images) > 0:
                cur_input_images = [self.process_image(x) for x in cur_input_images]
            else:
                cur_input_images = None
                # A text-only prompt must not reference an image tag it cannot resolve.
                assert "<img><|image_1|></img>" not in cur_instruction

            mllm_input = self.process_multi_modal_prompt(cur_instruction, cur_input_images)

            # Negative branch is always text-only; image-CFG branch conditions on the images alone.
            neg_mllm_input, img_cfg_mllm_input = None, None
            neg_mllm_input = self.process_multi_modal_prompt(negative_prompt, None)
            if use_img_cfg:
                if cur_input_images is not None and len(cur_input_images) >= 1:
                    img_cfg_prompt = [f"<img><|image_{i + 1}|></img>" for i in range(len(cur_input_images))]
                    img_cfg_mllm_input = self.process_multi_modal_prompt(" ".join(img_cfg_prompt), cur_input_images)
                else:
                    # No input images: image-CFG branch degenerates to the negative branch.
                    img_cfg_mllm_input = neg_mllm_input

            for _ in range(num_images_per_prompt):
                if use_input_image_size_as_output:
                    # Target size is taken from the first input image's (H, W).
                    input_data.append(
                        (
                            mllm_input,
                            neg_mllm_input,
                            img_cfg_mllm_input,
                            [mllm_input["pixel_values"][0].size(-2), mllm_input["pixel_values"][0].size(-1)],
                        )
                    )
                else:
                    input_data.append((mllm_input, neg_mllm_input, img_cfg_mllm_input, [height, width]))

        return self.collator(input_data)
184
+
185
+
186
class OmniGenCollator:
    """
    Collates per-sample OmniGen multi-modal inputs into padded, batched tensors.

    Left-pads input ids to a common length, builds position ids and the block-causal attention
    mask from the OmniGen paper (causal over the sequence, bidirectional within each image's
    token span), and gathers pixel values plus per-sample image spans.

    Args:
        pad_token_id (`int`, defaults to `2`): Token id used for left padding.
        hidden_size (`int`, defaults to `3072`): Hidden size of the padding-image placeholders.
    """

    def __init__(self, pad_token_id=2, hidden_size=3072):
        self.pad_token_id = pad_token_id
        self.hidden_size = hidden_size

    def create_position(self, attention_mask, num_tokens_for_output_images):
        """Build position ids: zeros over the left padding, then 0..(text+img+1) over real tokens."""
        position_ids = []
        text_length = attention_mask.size(-1)
        # All samples share the largest output-image token count.
        img_length = max(num_tokens_for_output_images)
        for mask in attention_mask:
            temp_l = torch.sum(mask)
            temp_position = [0] * (text_length - temp_l) + list(
                range(temp_l + img_length + 1)
            )  # we add a time embedding into the sequence, so add one more token
            position_ids.append(temp_position)
        return torch.LongTensor(position_ids)

    def create_mask(self, attention_mask, num_tokens_for_output_images):
        """
        OmniGen applies causal attention to each element in the sequence, but applies bidirectional attention within
        each image sequence References: [OmniGen](https://huggingface.co/papers/2409.11340)

        Returns the per-sample extended masks stacked into one tensor, plus a list of zero
        placeholder embeddings for samples whose image span is shorter than the batch maximum
        (None when no padding is needed).
        """
        extended_mask = []
        padding_images = []
        text_length = attention_mask.size(-1)
        img_length = max(num_tokens_for_output_images)
        seq_len = text_length + img_length + 1  # we add a time embedding into the sequence, so add one more token
        inx = 0
        for mask in attention_mask:
            temp_l = torch.sum(mask)
            pad_l = text_length - temp_l

            # Causal (lower-triangular) mask over real text tokens plus the time-embedding token.
            temp_mask = torch.tril(torch.ones(size=(temp_l + 1, temp_l + 1)))

            # Text/time tokens do not attend to the output-image tokens...
            image_mask = torch.zeros(size=(temp_l + 1, img_length))
            temp_mask = torch.cat([temp_mask, image_mask], dim=-1)

            # ...while output-image tokens attend to everything (bidirectional within the image).
            image_mask = torch.ones(size=(img_length, temp_l + img_length + 1))
            temp_mask = torch.cat([temp_mask, image_mask], dim=0)

            if pad_l > 0:
                # Left-pad columns: padded positions are never attended to as keys.
                pad_mask = torch.zeros(size=(temp_l + 1 + img_length, pad_l))
                temp_mask = torch.cat([pad_mask, temp_mask], dim=-1)

                # Left-pad rows for the padded query positions.
                pad_mask = torch.ones(size=(pad_l, seq_len))
                temp_mask = torch.cat([pad_mask, temp_mask], dim=0)

            # Mask out image-token columns beyond this sample's true output-image length and
            # create zero embeddings to fill those slots.
            true_img_length = num_tokens_for_output_images[inx]
            pad_img_length = img_length - true_img_length
            if pad_img_length > 0:
                temp_mask[:, -pad_img_length:] = 0
                temp_padding_imgs = torch.zeros(size=(1, pad_img_length, self.hidden_size))
            else:
                temp_padding_imgs = None

            extended_mask.append(temp_mask.unsqueeze(0))
            padding_images.append(temp_padding_imgs)
            inx += 1
        return torch.cat(extended_mask, dim=0), padding_images

    def adjust_attention_for_input_images(self, attention_mask, image_sizes):
        """Make attention bidirectional within each *input* image's token span."""
        for b_inx in image_sizes.keys():
            for start_inx, end_inx in image_sizes[b_inx]:
                attention_mask[b_inx][start_inx:end_inx, start_inx:end_inx] = 1

        return attention_mask

    def pad_input_ids(self, input_ids, image_sizes):
        """
        Left-pad ids to the batch's max length and build the 0/1 attention mask.

        Also shifts each recorded image span by the pad amount so spans keep pointing at the
        same tokens after padding. Note: `image_sizes` is mutated in place.
        """
        max_l = max([len(x) for x in input_ids])
        padded_ids = []
        attention_mask = []

        for i in range(len(input_ids)):
            temp_ids = input_ids[i]
            temp_l = len(temp_ids)
            pad_l = max_l - temp_l
            if pad_l == 0:
                attention_mask.append([1] * max_l)
                padded_ids.append(temp_ids)
            else:
                attention_mask.append([0] * pad_l + [1] * temp_l)
                padded_ids.append([self.pad_token_id] * pad_l + temp_ids)

            if i in image_sizes:
                new_inx = []
                for old_inx in image_sizes[i]:
                    new_inx.append([x + pad_l for x in old_inx])
                image_sizes[i] = new_inx

        return torch.LongTensor(padded_ids), torch.LongTensor(attention_mask), image_sizes

    def process_mllm_input(self, mllm_inputs, target_img_size):
        """Convert per-sample input dicts into batched ids, positions, masks and pixel values."""
        num_tokens_for_output_images = []
        for img_size in target_img_size:
            # One sequence token per 16x16 patch of the target image.
            num_tokens_for_output_images.append(img_size[0] * img_size[1] // 16 // 16)

        # Flatten pixel values across the batch; remember which spans belong to which sample.
        pixel_values, image_sizes = [], {}
        b_inx = 0
        for x in mllm_inputs:
            if x["pixel_values"] is not None:
                pixel_values.extend(x["pixel_values"])
                for size in x["image_sizes"]:
                    if b_inx not in image_sizes:
                        image_sizes[b_inx] = [size]
                    else:
                        image_sizes[b_inx].append(size)
            b_inx += 1
        pixel_values = [x.unsqueeze(0) for x in pixel_values]

        input_ids = [x["input_ids"] for x in mllm_inputs]
        padded_input_ids, attention_mask, image_sizes = self.pad_input_ids(input_ids, image_sizes)
        position_ids = self.create_position(attention_mask, num_tokens_for_output_images)
        attention_mask, padding_images = self.create_mask(attention_mask, num_tokens_for_output_images)
        attention_mask = self.adjust_attention_for_input_images(attention_mask, image_sizes)

        return padded_input_ids, position_ids, attention_mask, padding_images, pixel_values, image_sizes

    def __call__(self, features):
        """
        Collate a list of `(mllm_input, neg_input, img_cfg_input, target_size)` tuples into the
        final model-input dict, stacking all classifier-free-guidance branches into one batch.
        """
        mllm_inputs = [f[0] for f in features]
        cfg_mllm_inputs = [f[1] for f in features]
        img_cfg_mllm_input = [f[2] for f in features]
        target_img_size = [f[3] for f in features]

        # Conditional + unconditional (+ image-conditional when present) run as one forward pass.
        if img_cfg_mllm_input[0] is not None:
            mllm_inputs = mllm_inputs + cfg_mllm_inputs + img_cfg_mllm_input
            target_img_size = target_img_size + target_img_size + target_img_size
        else:
            mllm_inputs = mllm_inputs + cfg_mllm_inputs
            target_img_size = target_img_size + target_img_size

        (
            all_padded_input_ids,
            all_position_ids,
            all_attention_mask,
            all_padding_images,
            all_pixel_values,
            all_image_sizes,
        ) = self.process_mllm_input(mllm_inputs, target_img_size)

        data = {
            "input_ids": all_padded_input_ids,
            "attention_mask": all_attention_mask,
            "position_ids": all_position_ids,
            "input_pixel_values": all_pixel_values,
            "input_image_sizes": all_image_sizes,
        }
        return data
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/pag/pag_utils.py ADDED
@@ -0,0 +1,243 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import re
16
+ from typing import Dict, List, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+
21
+ from ...models.attention_processor import (
22
+ Attention,
23
+ AttentionProcessor,
24
+ PAGCFGIdentitySelfAttnProcessor2_0,
25
+ PAGIdentitySelfAttnProcessor2_0,
26
+ )
27
+ from ...utils import logging
28
+
29
+
30
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
31
+
32
+
33
class PAGMixin:
    r"""Mixin class for [Perturbed Attention Guidance](https://huggingface.co/papers/2403.17377v1)."""

    def _set_pag_attn_processor(self, pag_applied_layers, do_classifier_free_guidance):
        r"""
        Set the attention processor for the PAG layers.

        Finds every self-attention module of the denoiser whose name matches one of the
        `pag_applied_layers` patterns and installs the appropriate PAG processor on it. Raises
        `ValueError` if processors were never configured or a pattern matches no layer.
        """
        pag_attn_processors = self._pag_attn_processors
        if pag_attn_processors is None:
            raise ValueError(
                "No PAG attention processors have been set. Set the attention processors by calling `set_pag_applied_layers` and passing the relevant parameters."
            )

        # First processor in the tuple is for the CFG-enabled path, second for CFG-disabled.
        pag_attn_proc = pag_attn_processors[0] if do_classifier_free_guidance else pag_attn_processors[1]

        # Pipelines expose their denoiser as either `unet` or `transformer`.
        if hasattr(self, "unet"):
            model: nn.Module = self.unet
        else:
            model: nn.Module = self.transformer

        def is_self_attn(module: nn.Module) -> bool:
            r"""
            Check if the module is self-attention module based on its name.
            """
            return isinstance(module, Attention) and not module.is_cross_attention

        def is_fake_integral_match(layer_id, name):
            # Compare only the trailing dotted segments; used below to disambiguate
            # numerically-suffixed layer ids (e.g. "blocks.1" vs "blocks.10").
            layer_id = layer_id.split(".")[-1]
            name = name.split(".")[-1]
            return layer_id.isnumeric() and name.isnumeric() and layer_id == name

        for layer_id in pag_applied_layers:
            # for each PAG layer input, we find corresponding self-attention layers in the unet model
            target_modules = []

            for name, module in model.named_modules():
                # Identify the following simple cases:
                # (1) Self Attention layer existing
                # (2) Whether the module name matches pag layer id even partially
                # (3) Make sure it's not a fake integral match if the layer_id ends with a number
                # For example, blocks.1, blocks.10 should be differentiable if layer_id="blocks.1"
                if (
                    is_self_attn(module)
                    and re.search(layer_id, name) is not None
                    and not is_fake_integral_match(layer_id, name)
                ):
                    logger.debug(f"Applying PAG to layer: {name}")
                    target_modules.append(module)

            if len(target_modules) == 0:
                raise ValueError(f"Cannot find PAG layer to set attention processor for: {layer_id}")

            for module in target_modules:
                module.processor = pag_attn_proc

    def _get_pag_scale(self, t):
        r"""
        Get the scale factor for the perturbed attention guidance at timestep `t`.

        With adaptive scaling enabled, the scale decays linearly as `t` moves away from 1000
        and is clamped at zero.
        """

        if self.do_pag_adaptive_scaling:
            signal_scale = self.pag_scale - self.pag_adaptive_scale * (1000 - t)
            if signal_scale < 0:
                signal_scale = 0
            return signal_scale
        else:
            return self.pag_scale

    def _apply_perturbed_attention_guidance(
        self, noise_pred, do_classifier_free_guidance, guidance_scale, t, return_pred_text=False
    ):
        r"""
        Apply perturbed attention guidance to the noise prediction.

        Args:
            noise_pred (torch.Tensor): The noise prediction tensor.
            do_classifier_free_guidance (bool): Whether to apply classifier-free guidance.
            guidance_scale (float): The scale factor for the guidance term.
            t (int): The current time step.
            return_pred_text (bool): Whether to return the text noise prediction.

        Returns:
            Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: The updated noise prediction tensor after applying
            perturbed attention guidance and the text noise prediction.
        """
        pag_scale = self._get_pag_scale(t)
        if do_classifier_free_guidance:
            # Batch is [uncond, text, perturbed]: combine the CFG and PAG correction terms.
            noise_pred_uncond, noise_pred_text, noise_pred_perturb = noise_pred.chunk(3)
            noise_pred = (
                noise_pred_uncond
                + guidance_scale * (noise_pred_text - noise_pred_uncond)
                + pag_scale * (noise_pred_text - noise_pred_perturb)
            )
        else:
            # Batch is [text, perturbed]: only the PAG correction applies.
            noise_pred_text, noise_pred_perturb = noise_pred.chunk(2)
            noise_pred = noise_pred_text + pag_scale * (noise_pred_text - noise_pred_perturb)
        if return_pred_text:
            return noise_pred, noise_pred_text
        return noise_pred

    def _prepare_perturbed_attention_guidance(self, cond, uncond, do_classifier_free_guidance):
        """
        Prepares the perturbed attention guidance for the PAG model.

        Duplicates `cond` for the perturbed branch and, with CFG, prepends `uncond` so the batch
        layout matches `_apply_perturbed_attention_guidance`.

        Args:
            cond (torch.Tensor): The conditional input tensor.
            uncond (torch.Tensor): The unconditional input tensor.
            do_classifier_free_guidance (bool): Flag indicating whether to perform classifier-free guidance.

        Returns:
            torch.Tensor: The prepared perturbed attention guidance tensor.
        """

        cond = torch.cat([cond] * 2, dim=0)

        if do_classifier_free_guidance:
            cond = torch.cat([uncond, cond], dim=0)
        return cond

    def set_pag_applied_layers(
        self,
        pag_applied_layers: Union[str, List[str]],
        # NOTE(review): the default processor instances are created once at class-definition
        # time and shared by every call that relies on the default — confirm this sharing is
        # intended before adding per-instance processor state.
        pag_attn_processors: Tuple[AttentionProcessor, AttentionProcessor] = (
            PAGCFGIdentitySelfAttnProcessor2_0(),
            PAGIdentitySelfAttnProcessor2_0(),
        ),
    ):
        r"""
        Set the self-attention layers to apply PAG. Raise ValueError if the input is invalid.

        Args:
            pag_applied_layers (`str` or `List[str]`):
                One or more strings identifying the layer names, or a simple regex for matching multiple layers, where
                PAG is to be applied. A few ways of expected usage are as follows:
                - Single layers specified as - "blocks.{layer_index}"
                - Multiple layers as a list - ["blocks.{layers_index_1}", "blocks.{layer_index_2}", ...]
                - Multiple layers as a block name - "mid"
                - Multiple layers as regex - "blocks.({layer_index_1}|{layer_index_2})"
            pag_attn_processors:
                (`Tuple[AttentionProcessor, AttentionProcessor]`, defaults to `(PAGCFGIdentitySelfAttnProcessor2_0(),
                PAGIdentitySelfAttnProcessor2_0())`): A tuple of two attention processors. The first attention
                processor is for PAG with Classifier-free guidance enabled (conditional and unconditional). The second
                attention processor is for PAG with CFG disabled (unconditional only).
        """

        if not hasattr(self, "_pag_attn_processors"):
            self._pag_attn_processors = None

        if not isinstance(pag_applied_layers, list):
            pag_applied_layers = [pag_applied_layers]
        if pag_attn_processors is not None:
            if not isinstance(pag_attn_processors, tuple) or len(pag_attn_processors) != 2:
                raise ValueError("Expected a tuple of two attention processors")

        for i in range(len(pag_applied_layers)):
            if not isinstance(pag_applied_layers[i], str):
                raise ValueError(
                    f"Expected either a string or a list of string but got type {type(pag_applied_layers[i])}"
                )

        self.pag_applied_layers = pag_applied_layers
        self._pag_attn_processors = pag_attn_processors

    @property
    def pag_scale(self) -> float:
        r"""Get the scale factor for the perturbed attention guidance."""
        return self._pag_scale

    @property
    def pag_adaptive_scale(self) -> float:
        r"""Get the adaptive scale factor for the perturbed attention guidance."""
        return self._pag_adaptive_scale

    @property
    def do_pag_adaptive_scaling(self) -> bool:
        r"""Check if the adaptive scaling is enabled for the perturbed attention guidance."""
        return self._pag_adaptive_scale > 0 and self._pag_scale > 0 and len(self.pag_applied_layers) > 0

    @property
    def do_perturbed_attention_guidance(self) -> bool:
        r"""Check if the perturbed attention guidance is enabled."""
        return self._pag_scale > 0 and len(self.pag_applied_layers) > 0

    @property
    def pag_attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of PAG attention processors: A dictionary contains all PAG attention processors used in the model
            with the key as the name of the layer.
        """

        if self._pag_attn_processors is None:
            return {}

        # Match by processor class rather than identity, so reconfigured instances still count.
        valid_attn_processors = {x.__class__ for x in self._pag_attn_processors}

        processors = {}
        # We could have iterated through the self.components.items() and checked if a component is
        # `ModelMixin` subclassed but that can include a VAE too.
        if hasattr(self, "unet"):
            denoiser_module = self.unet
        elif hasattr(self, "transformer"):
            denoiser_module = self.transformer
        else:
            raise ValueError("No denoiser module found.")

        for name, proc in denoiser_module.attn_processors.items():
            if proc.__class__ in valid_attn_processors:
                processors[name] = proc

        return processors
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py ADDED
@@ -0,0 +1,1343 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ import inspect
17
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
18
+
19
+ import numpy as np
20
+ import PIL.Image
21
+ import torch
22
+ import torch.nn.functional as F
23
+ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
24
+
25
+ from ...callbacks import MultiPipelineCallbacks, PipelineCallback
26
+ from ...image_processor import PipelineImageInput, VaeImageProcessor
27
+ from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
28
+ from ...models import AutoencoderKL, ControlNetModel, ImageProjection, MultiControlNetModel, UNet2DConditionModel
29
+ from ...models.lora import adjust_lora_scale_text_encoder
30
+ from ...schedulers import KarrasDiffusionSchedulers
31
+ from ...utils import (
32
+ USE_PEFT_BACKEND,
33
+ is_torch_xla_available,
34
+ logging,
35
+ replace_example_docstring,
36
+ scale_lora_layers,
37
+ unscale_lora_layers,
38
+ )
39
+ from ...utils.torch_utils import empty_device_cache, is_compiled_module, is_torch_version, randn_tensor
40
+ from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
41
+ from ..stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
42
+ from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
43
+ from .pag_utils import PAGMixin
44
+
45
+
46
+ if is_torch_xla_available():
47
+ import torch_xla.core.xla_model as xm
48
+
49
+ XLA_AVAILABLE = True
50
+ else:
51
+ XLA_AVAILABLE = False
52
+
53
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
54
+
55
+
56
+ EXAMPLE_DOC_STRING = """
57
+ Examples:
58
+ ```py
59
+ >>> # !pip install opencv-python transformers accelerate
60
+ >>> from diffusers import AutoPipelineForText2Image, ControlNetModel, UniPCMultistepScheduler
61
+ >>> from diffusers.utils import load_image
62
+ >>> import numpy as np
63
+ >>> import torch
64
+
65
+ >>> import cv2
66
+ >>> from PIL import Image
67
+
68
+ >>> # download an image
69
+ >>> image = load_image(
70
+ ... "https://hf.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png"
71
+ ... )
72
+ >>> image = np.array(image)
73
+
74
+ >>> # get canny image
75
+ >>> image = cv2.Canny(image, 100, 200)
76
+ >>> image = image[:, :, None]
77
+ >>> image = np.concatenate([image, image, image], axis=2)
78
+ >>> canny_image = Image.fromarray(image)
79
+
80
+ >>> # load control net and stable diffusion v1-5
81
+ >>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
82
+ >>> pipe = AutoPipelineForText2Image.from_pretrained(
83
+ ... "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16, enable_pag=True
84
+ ... )
85
+
86
+ >>> # speed up diffusion process with faster scheduler and memory optimization
87
+ >>> # remove following line if xformers is not installed
88
+ >>> pipe.enable_xformers_memory_efficient_attention()
89
+
90
+ >>> pipe.enable_model_cpu_offload()
91
+
92
+ >>> # generate image
93
+ >>> generator = torch.manual_seed(0)
94
+ >>> image = pipe(
95
+ ... "aerial view, a futuristic research complex in a bright foggy jungle, hard lighting",
96
+ ... guidance_scale=7.5,
97
+ ... generator=generator,
98
+ ... image=canny_image,
99
+ ... pag_scale=10,
100
+ ... ).images[0]
101
+ ```
102
+ """
103
+
104
+
105
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
106
+ def retrieve_timesteps(
107
+ scheduler,
108
+ num_inference_steps: Optional[int] = None,
109
+ device: Optional[Union[str, torch.device]] = None,
110
+ timesteps: Optional[List[int]] = None,
111
+ sigmas: Optional[List[float]] = None,
112
+ **kwargs,
113
+ ):
114
+ r"""
115
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
116
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
117
+
118
+ Args:
119
+ scheduler (`SchedulerMixin`):
120
+ The scheduler to get timesteps from.
121
+ num_inference_steps (`int`):
122
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
123
+ must be `None`.
124
+ device (`str` or `torch.device`, *optional*):
125
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
126
+ timesteps (`List[int]`, *optional*):
127
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
128
+ `num_inference_steps` and `sigmas` must be `None`.
129
+ sigmas (`List[float]`, *optional*):
130
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
131
+ `num_inference_steps` and `timesteps` must be `None`.
132
+
133
+ Returns:
134
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
135
+ second element is the number of inference steps.
136
+ """
137
+ if timesteps is not None and sigmas is not None:
138
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
139
+ if timesteps is not None:
140
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
141
+ if not accepts_timesteps:
142
+ raise ValueError(
143
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
144
+ f" timestep schedules. Please check whether you are using the correct scheduler."
145
+ )
146
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
147
+ timesteps = scheduler.timesteps
148
+ num_inference_steps = len(timesteps)
149
+ elif sigmas is not None:
150
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
151
+ if not accept_sigmas:
152
+ raise ValueError(
153
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
154
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
155
+ )
156
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
157
+ timesteps = scheduler.timesteps
158
+ num_inference_steps = len(timesteps)
159
+ else:
160
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
161
+ timesteps = scheduler.timesteps
162
+ return timesteps, num_inference_steps
163
+
164
+
165
+ class StableDiffusionControlNetPAGPipeline(
166
+ DiffusionPipeline,
167
+ StableDiffusionMixin,
168
+ TextualInversionLoaderMixin,
169
+ StableDiffusionLoraLoaderMixin,
170
+ IPAdapterMixin,
171
+ FromSingleFileMixin,
172
+ PAGMixin,
173
+ ):
174
+ r"""
175
+ Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance.
176
+
177
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
178
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
179
+
180
+ The pipeline also inherits the following loading methods:
181
+ - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
182
+ - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
183
+ - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
184
+ - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
185
+ - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
186
+
187
+ Args:
188
+ vae ([`AutoencoderKL`]):
189
+ Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
190
+ text_encoder ([`~transformers.CLIPTextModel`]):
191
+ Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
192
+ tokenizer ([`~transformers.CLIPTokenizer`]):
193
+ A `CLIPTokenizer` to tokenize text.
194
+ unet ([`UNet2DConditionModel`]):
195
+ A `UNet2DConditionModel` to denoise the encoded image latents.
196
+ controlnet ([`ControlNetModel`] or `List[ControlNetModel]`):
197
+ Provides additional conditioning to the `unet` during the denoising process. If you set multiple
198
+ ControlNets as a list, the outputs from each ControlNet are added together to create one combined
199
+ additional conditioning.
200
+ scheduler ([`SchedulerMixin`]):
201
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
202
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
203
+ safety_checker ([`StableDiffusionSafetyChecker`]):
204
+ Classification module that estimates whether generated images could be considered offensive or harmful.
205
+ Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
206
+ about a model's potential harms.
207
+ feature_extractor ([`~transformers.CLIPImageProcessor`]):
208
+ A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
209
+ """
210
+
211
+ model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
212
+ _optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
213
+ _exclude_from_cpu_offload = ["safety_checker"]
214
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
215
+
216
+ def __init__(
217
+ self,
218
+ vae: AutoencoderKL,
219
+ text_encoder: CLIPTextModel,
220
+ tokenizer: CLIPTokenizer,
221
+ unet: UNet2DConditionModel,
222
+ controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
223
+ scheduler: KarrasDiffusionSchedulers,
224
+ safety_checker: StableDiffusionSafetyChecker,
225
+ feature_extractor: CLIPImageProcessor,
226
+ image_encoder: CLIPVisionModelWithProjection = None,
227
+ requires_safety_checker: bool = True,
228
+ pag_applied_layers: Union[str, List[str]] = "mid",
229
+ ):
230
+ super().__init__()
231
+
232
+ if safety_checker is None and requires_safety_checker:
233
+ logger.warning(
234
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
235
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
236
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
237
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
238
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
239
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
240
+ )
241
+
242
+ if safety_checker is not None and feature_extractor is None:
243
+ raise ValueError(
244
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
245
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
246
+ )
247
+
248
+ if isinstance(controlnet, (list, tuple)):
249
+ controlnet = MultiControlNetModel(controlnet)
250
+
251
+ self.register_modules(
252
+ vae=vae,
253
+ text_encoder=text_encoder,
254
+ tokenizer=tokenizer,
255
+ unet=unet,
256
+ controlnet=controlnet,
257
+ scheduler=scheduler,
258
+ safety_checker=safety_checker,
259
+ feature_extractor=feature_extractor,
260
+ image_encoder=image_encoder,
261
+ )
262
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
263
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
264
+ self.control_image_processor = VaeImageProcessor(
265
+ vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
266
+ )
267
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
268
+
269
+ self.set_pag_applied_layers(pag_applied_layers)
270
+
271
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
272
+ def encode_prompt(
273
+ self,
274
+ prompt,
275
+ device,
276
+ num_images_per_prompt,
277
+ do_classifier_free_guidance,
278
+ negative_prompt=None,
279
+ prompt_embeds: Optional[torch.Tensor] = None,
280
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
281
+ lora_scale: Optional[float] = None,
282
+ clip_skip: Optional[int] = None,
283
+ ):
284
+ r"""
285
+ Encodes the prompt into text encoder hidden states.
286
+
287
+ Args:
288
+ prompt (`str` or `List[str]`, *optional*):
289
+ prompt to be encoded
290
+ device: (`torch.device`):
291
+ torch device
292
+ num_images_per_prompt (`int`):
293
+ number of images that should be generated per prompt
294
+ do_classifier_free_guidance (`bool`):
295
+ whether to use classifier free guidance or not
296
+ negative_prompt (`str` or `List[str]`, *optional*):
297
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
298
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
299
+ less than `1`).
300
+ prompt_embeds (`torch.Tensor`, *optional*):
301
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
302
+ provided, text embeddings will be generated from `prompt` input argument.
303
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
304
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
305
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
306
+ argument.
307
+ lora_scale (`float`, *optional*):
308
+ A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
309
+ clip_skip (`int`, *optional*):
310
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
311
+ the output of the pre-final layer will be used for computing the prompt embeddings.
312
+ """
313
+ # set lora scale so that monkey patched LoRA
314
+ # function of text encoder can correctly access it
315
+ if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):
316
+ self._lora_scale = lora_scale
317
+
318
+ # dynamically adjust the LoRA scale
319
+ if not USE_PEFT_BACKEND:
320
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
321
+ else:
322
+ scale_lora_layers(self.text_encoder, lora_scale)
323
+
324
+ if prompt is not None and isinstance(prompt, str):
325
+ batch_size = 1
326
+ elif prompt is not None and isinstance(prompt, list):
327
+ batch_size = len(prompt)
328
+ else:
329
+ batch_size = prompt_embeds.shape[0]
330
+
331
+ if prompt_embeds is None:
332
+ # textual inversion: process multi-vector tokens if necessary
333
+ if isinstance(self, TextualInversionLoaderMixin):
334
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
335
+
336
+ text_inputs = self.tokenizer(
337
+ prompt,
338
+ padding="max_length",
339
+ max_length=self.tokenizer.model_max_length,
340
+ truncation=True,
341
+ return_tensors="pt",
342
+ )
343
+ text_input_ids = text_inputs.input_ids
344
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
345
+
346
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
347
+ text_input_ids, untruncated_ids
348
+ ):
349
+ removed_text = self.tokenizer.batch_decode(
350
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
351
+ )
352
+ logger.warning(
353
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
354
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
355
+ )
356
+
357
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
358
+ attention_mask = text_inputs.attention_mask.to(device)
359
+ else:
360
+ attention_mask = None
361
+
362
+ if clip_skip is None:
363
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
364
+ prompt_embeds = prompt_embeds[0]
365
+ else:
366
+ prompt_embeds = self.text_encoder(
367
+ text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
368
+ )
369
+ # Access the `hidden_states` first, that contains a tuple of
370
+ # all the hidden states from the encoder layers. Then index into
371
+ # the tuple to access the hidden states from the desired layer.
372
+ prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
373
+ # We also need to apply the final LayerNorm here to not mess with the
374
+ # representations. The `last_hidden_states` that we typically use for
375
+ # obtaining the final prompt representations passes through the LayerNorm
376
+ # layer.
377
+ prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
378
+
379
+ if self.text_encoder is not None:
380
+ prompt_embeds_dtype = self.text_encoder.dtype
381
+ elif self.unet is not None:
382
+ prompt_embeds_dtype = self.unet.dtype
383
+ else:
384
+ prompt_embeds_dtype = prompt_embeds.dtype
385
+
386
+ prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
387
+
388
+ bs_embed, seq_len, _ = prompt_embeds.shape
389
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
390
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
391
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
392
+
393
+ # get unconditional embeddings for classifier free guidance
394
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
395
+ uncond_tokens: List[str]
396
+ if negative_prompt is None:
397
+ uncond_tokens = [""] * batch_size
398
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
399
+ raise TypeError(
400
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
401
+ f" {type(prompt)}."
402
+ )
403
+ elif isinstance(negative_prompt, str):
404
+ uncond_tokens = [negative_prompt]
405
+ elif batch_size != len(negative_prompt):
406
+ raise ValueError(
407
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
408
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
409
+ " the batch size of `prompt`."
410
+ )
411
+ else:
412
+ uncond_tokens = negative_prompt
413
+
414
+ # textual inversion: process multi-vector tokens if necessary
415
+ if isinstance(self, TextualInversionLoaderMixin):
416
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
417
+
418
+ max_length = prompt_embeds.shape[1]
419
+ uncond_input = self.tokenizer(
420
+ uncond_tokens,
421
+ padding="max_length",
422
+ max_length=max_length,
423
+ truncation=True,
424
+ return_tensors="pt",
425
+ )
426
+
427
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
428
+ attention_mask = uncond_input.attention_mask.to(device)
429
+ else:
430
+ attention_mask = None
431
+
432
+ negative_prompt_embeds = self.text_encoder(
433
+ uncond_input.input_ids.to(device),
434
+ attention_mask=attention_mask,
435
+ )
436
+ negative_prompt_embeds = negative_prompt_embeds[0]
437
+
438
+ if do_classifier_free_guidance:
439
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
440
+ seq_len = negative_prompt_embeds.shape[1]
441
+
442
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
443
+
444
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
445
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
446
+
447
+ if self.text_encoder is not None:
448
+ if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
449
+ # Retrieve the original scale by scaling back the LoRA layers
450
+ unscale_lora_layers(self.text_encoder, lora_scale)
451
+
452
+ return prompt_embeds, negative_prompt_embeds
453
+
454
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
455
+ def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
456
+ dtype = next(self.image_encoder.parameters()).dtype
457
+
458
+ if not isinstance(image, torch.Tensor):
459
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
460
+
461
+ image = image.to(device=device, dtype=dtype)
462
+ if output_hidden_states:
463
+ image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
464
+ image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
465
+ uncond_image_enc_hidden_states = self.image_encoder(
466
+ torch.zeros_like(image), output_hidden_states=True
467
+ ).hidden_states[-2]
468
+ uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
469
+ num_images_per_prompt, dim=0
470
+ )
471
+ return image_enc_hidden_states, uncond_image_enc_hidden_states
472
+ else:
473
+ image_embeds = self.image_encoder(image).image_embeds
474
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
475
+ uncond_image_embeds = torch.zeros_like(image_embeds)
476
+
477
+ return image_embeds, uncond_image_embeds
478
+
479
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
480
+ def prepare_ip_adapter_image_embeds(
481
+ self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
482
+ ):
483
+ image_embeds = []
484
+ if do_classifier_free_guidance:
485
+ negative_image_embeds = []
486
+ if ip_adapter_image_embeds is None:
487
+ if not isinstance(ip_adapter_image, list):
488
+ ip_adapter_image = [ip_adapter_image]
489
+
490
+ if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
491
+ raise ValueError(
492
+ f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
493
+ )
494
+
495
+ for single_ip_adapter_image, image_proj_layer in zip(
496
+ ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
497
+ ):
498
+ output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
499
+ single_image_embeds, single_negative_image_embeds = self.encode_image(
500
+ single_ip_adapter_image, device, 1, output_hidden_state
501
+ )
502
+
503
+ image_embeds.append(single_image_embeds[None, :])
504
+ if do_classifier_free_guidance:
505
+ negative_image_embeds.append(single_negative_image_embeds[None, :])
506
+ else:
507
+ for single_image_embeds in ip_adapter_image_embeds:
508
+ if do_classifier_free_guidance:
509
+ single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
510
+ negative_image_embeds.append(single_negative_image_embeds)
511
+ image_embeds.append(single_image_embeds)
512
+
513
+ ip_adapter_image_embeds = []
514
+ for i, single_image_embeds in enumerate(image_embeds):
515
+ single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
516
+ if do_classifier_free_guidance:
517
+ single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
518
+ single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)
519
+
520
+ single_image_embeds = single_image_embeds.to(device=device)
521
+ ip_adapter_image_embeds.append(single_image_embeds)
522
+
523
+ return ip_adapter_image_embeds
524
+
525
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
526
+ def run_safety_checker(self, image, device, dtype):
527
+ if self.safety_checker is None:
528
+ has_nsfw_concept = None
529
+ else:
530
+ if torch.is_tensor(image):
531
+ feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
532
+ else:
533
+ feature_extractor_input = self.image_processor.numpy_to_pil(image)
534
+ safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
535
+ image, has_nsfw_concept = self.safety_checker(
536
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
537
+ )
538
+ return image, has_nsfw_concept
539
+
540
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
541
+ def prepare_extra_step_kwargs(self, generator, eta):
542
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
543
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
544
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
545
+ # and should be between [0, 1]
546
+
547
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
548
+ extra_step_kwargs = {}
549
+ if accepts_eta:
550
+ extra_step_kwargs["eta"] = eta
551
+
552
+ # check if the scheduler accepts generator
553
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
554
+ if accepts_generator:
555
+ extra_step_kwargs["generator"] = generator
556
+ return extra_step_kwargs
557
+
558
+ def check_inputs(
559
+ self,
560
+ prompt,
561
+ image,
562
+ negative_prompt=None,
563
+ prompt_embeds=None,
564
+ negative_prompt_embeds=None,
565
+ ip_adapter_image=None,
566
+ ip_adapter_image_embeds=None,
567
+ controlnet_conditioning_scale=1.0,
568
+ control_guidance_start=0.0,
569
+ control_guidance_end=1.0,
570
+ callback_on_step_end_tensor_inputs=None,
571
+ ):
572
+ if callback_on_step_end_tensor_inputs is not None and not all(
573
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
574
+ ):
575
+ raise ValueError(
576
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
577
+ )
578
+
579
+ if prompt is not None and prompt_embeds is not None:
580
+ raise ValueError(
581
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
582
+ " only forward one of the two."
583
+ )
584
+ elif prompt is None and prompt_embeds is None:
585
+ raise ValueError(
586
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
587
+ )
588
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
589
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
590
+
591
+ if negative_prompt is not None and negative_prompt_embeds is not None:
592
+ raise ValueError(
593
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
594
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
595
+ )
596
+
597
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
598
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
599
+ raise ValueError(
600
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
601
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
602
+ f" {negative_prompt_embeds.shape}."
603
+ )
604
+
605
+ # Check `image`
606
+ is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
607
+ self.controlnet, torch._dynamo.eval_frame.OptimizedModule
608
+ )
609
+ if (
610
+ isinstance(self.controlnet, ControlNetModel)
611
+ or is_compiled
612
+ and isinstance(self.controlnet._orig_mod, ControlNetModel)
613
+ ):
614
+ self.check_image(image, prompt, prompt_embeds)
615
+ elif (
616
+ isinstance(self.controlnet, MultiControlNetModel)
617
+ or is_compiled
618
+ and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
619
+ ):
620
+ if not isinstance(image, list):
621
+ raise TypeError("For multiple controlnets: `image` must be type `list`")
622
+
623
+ # When `image` is a nested list:
624
+ # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])
625
+ elif any(isinstance(i, list) for i in image):
626
+ transposed_image = [list(t) for t in zip(*image)]
627
+ if len(transposed_image) != len(self.controlnet.nets):
628
+ raise ValueError(
629
+ f"For multiple controlnets: if you pass`image` as a list of list, each sublist must have the same length as the number of controlnets, but the sublists in `image` got {len(transposed_image)} images and {len(self.controlnet.nets)} ControlNets."
630
+ )
631
+ for image_ in transposed_image:
632
+ self.check_image(image_, prompt, prompt_embeds)
633
+ elif len(image) != len(self.controlnet.nets):
634
+ raise ValueError(
635
+ f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets."
636
+ )
637
+ else:
638
+ for image_ in image:
639
+ self.check_image(image_, prompt, prompt_embeds)
640
+ else:
641
+ assert False
642
+
643
+ # Check `controlnet_conditioning_scale`
644
+ if (
645
+ isinstance(self.controlnet, ControlNetModel)
646
+ or is_compiled
647
+ and isinstance(self.controlnet._orig_mod, ControlNetModel)
648
+ ):
649
+ if not isinstance(controlnet_conditioning_scale, float):
650
+ raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
651
+ elif (
652
+ isinstance(self.controlnet, MultiControlNetModel)
653
+ or is_compiled
654
+ and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
655
+ ):
656
+ if isinstance(controlnet_conditioning_scale, list):
657
+ if any(isinstance(i, list) for i in controlnet_conditioning_scale):
658
+ raise ValueError(
659
+ "A single batch of varying conditioning scale settings (e.g. [[1.0, 0.5], [0.2, 0.8]]) is not supported at the moment. "
660
+ "The conditioning scale must be fixed across the batch."
661
+ )
662
+ elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(
663
+ self.controlnet.nets
664
+ ):
665
+ raise ValueError(
666
+ "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
667
+ " the same length as the number of controlnets"
668
+ )
669
+ else:
670
+ assert False
671
+
672
+ if not isinstance(control_guidance_start, (tuple, list)):
673
+ control_guidance_start = [control_guidance_start]
674
+
675
+ if not isinstance(control_guidance_end, (tuple, list)):
676
+ control_guidance_end = [control_guidance_end]
677
+
678
+ if len(control_guidance_start) != len(control_guidance_end):
679
+ raise ValueError(
680
+ f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
681
+ )
682
+
683
+ if isinstance(self.controlnet, MultiControlNetModel):
684
+ if len(control_guidance_start) != len(self.controlnet.nets):
685
+ raise ValueError(
686
+ f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}."
687
+ )
688
+
689
+ for start, end in zip(control_guidance_start, control_guidance_end):
690
+ if start >= end:
691
+ raise ValueError(
692
+ f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}."
693
+ )
694
+ if start < 0.0:
695
+ raise ValueError(f"control guidance start: {start} can't be smaller than 0.")
696
+ if end > 1.0:
697
+ raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")
698
+
699
+ if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
700
+ raise ValueError(
701
+ "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
702
+ )
703
+
704
+ if ip_adapter_image_embeds is not None:
705
+ if not isinstance(ip_adapter_image_embeds, list):
706
+ raise ValueError(
707
+ f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
708
+ )
709
+ elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
710
+ raise ValueError(
711
+ f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
712
+ )
713
+
714
+ # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image
715
+ def check_image(self, image, prompt, prompt_embeds):
716
+ image_is_pil = isinstance(image, PIL.Image.Image)
717
+ image_is_tensor = isinstance(image, torch.Tensor)
718
+ image_is_np = isinstance(image, np.ndarray)
719
+ image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
720
+ image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)
721
+ image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray)
722
+
723
+ if (
724
+ not image_is_pil
725
+ and not image_is_tensor
726
+ and not image_is_np
727
+ and not image_is_pil_list
728
+ and not image_is_tensor_list
729
+ and not image_is_np_list
730
+ ):
731
+ raise TypeError(
732
+ f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}"
733
+ )
734
+
735
+ if image_is_pil:
736
+ image_batch_size = 1
737
+ else:
738
+ image_batch_size = len(image)
739
+
740
+ if prompt is not None and isinstance(prompt, str):
741
+ prompt_batch_size = 1
742
+ elif prompt is not None and isinstance(prompt, list):
743
+ prompt_batch_size = len(prompt)
744
+ elif prompt_embeds is not None:
745
+ prompt_batch_size = prompt_embeds.shape[0]
746
+
747
+ if image_batch_size != 1 and image_batch_size != prompt_batch_size:
748
+ raise ValueError(
749
+ f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
750
+ )
751
+
752
+ # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image
753
+ def prepare_image(
754
+ self,
755
+ image,
756
+ width,
757
+ height,
758
+ batch_size,
759
+ num_images_per_prompt,
760
+ device,
761
+ dtype,
762
+ do_classifier_free_guidance=False,
763
+ guess_mode=False,
764
+ ):
765
+ image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
766
+ image_batch_size = image.shape[0]
767
+
768
+ if image_batch_size == 1:
769
+ repeat_by = batch_size
770
+ else:
771
+ # image batch size is the same as prompt batch size
772
+ repeat_by = num_images_per_prompt
773
+
774
+ image = image.repeat_interleave(repeat_by, dim=0)
775
+
776
+ image = image.to(device=device, dtype=dtype)
777
+
778
+ if do_classifier_free_guidance and not guess_mode:
779
+ image = torch.cat([image] * 2)
780
+
781
+ return image
782
+
783
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
784
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
785
+ shape = (
786
+ batch_size,
787
+ num_channels_latents,
788
+ int(height) // self.vae_scale_factor,
789
+ int(width) // self.vae_scale_factor,
790
+ )
791
+ if isinstance(generator, list) and len(generator) != batch_size:
792
+ raise ValueError(
793
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
794
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
795
+ )
796
+
797
+ if latents is None:
798
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
799
+ else:
800
+ latents = latents.to(device)
801
+
802
+ # scale the initial noise by the standard deviation required by the scheduler
803
+ latents = latents * self.scheduler.init_noise_sigma
804
+ return latents
805
+
806
+ # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
807
+ def get_guidance_scale_embedding(
808
+ self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32
809
+ ) -> torch.Tensor:
810
+ """
811
+ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
812
+
813
+ Args:
814
+ w (`torch.Tensor`):
815
+ Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings.
816
+ embedding_dim (`int`, *optional*, defaults to 512):
817
+ Dimension of the embeddings to generate.
818
+ dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
819
+ Data type of the generated embeddings.
820
+
821
+ Returns:
822
+ `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`.
823
+ """
824
+ assert len(w.shape) == 1
825
+ w = w * 1000.0
826
+
827
+ half_dim = embedding_dim // 2
828
+ emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
829
+ emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
830
+ emb = w.to(dtype)[:, None] * emb[None, :]
831
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
832
+ if embedding_dim % 2 == 1: # zero pad
833
+ emb = torch.nn.functional.pad(emb, (0, 1))
834
+ assert emb.shape == (w.shape[0], embedding_dim)
835
+ return emb
836
+
837
    @property
    def guidance_scale(self):
        # Classifier-free guidance weight stored by `__call__`.
        return self._guidance_scale

    @property
    def clip_skip(self):
        # Number of final CLIP layers skipped when encoding the prompt.
        return self._clip_skip

    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    @property
    def do_classifier_free_guidance(self):
        # CFG is disabled when the UNet carries a guidance-scale conditioning
        # projection (`time_cond_proj_dim` set) — guidance is then fed through
        # `timestep_cond` instead of a second forward pass.
        return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None

    @property
    def cross_attention_kwargs(self):
        # Extra kwargs forwarded to the attention processors during denoising.
        return self._cross_attention_kwargs

    @property
    def num_timesteps(self):
        # Number of scheduler timesteps used by the most recent `__call__`.
        return self._num_timesteps
860
    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        image: PipelineImageInput = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        timesteps: List[int] = None,
        sigmas: List[float] = None,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        ip_adapter_image: Optional[PipelineImageInput] = None,
        ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
        guess_mode: bool = False,
        control_guidance_start: Union[float, List[float]] = 0.0,
        control_guidance_end: Union[float, List[float]] = 1.0,
        clip_skip: Optional[int] = None,
        callback_on_step_end: Optional[
            Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
        ] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        pag_scale: float = 3.0,
        pag_adaptive_scale: float = 0.0,
    ):
        r"""
        The call function to the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
            image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
                `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
                The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
                specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted
                as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or
                width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`,
                images must be passed as a list such that each element of the list can be correctly batched for input
                to a single ControlNet. When `prompt` is a list, and if a list of images is passed for a single
                ControlNet, each will be paired with each prompt in the `prompt` list. This also applies to multiple
                ControlNets, where a list of image lists can be passed to batch for each prompt and each ControlNet.
            height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            timesteps (`List[int]`, *optional*):
                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
                passed will be used. Must be in descending order.
            sigmas (`List[float]`, *optional*):
                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
                their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
                will be used.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                A higher guidance scale value encourages the model to generate images closely linked to the text
                `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide what to not include in image generation. If not defined, you need to
                pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only
                applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
            latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor is generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
                provided, text embeddings are generated from the `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
                not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
            ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
            ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
                Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
                IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
                contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
                provided, embeddings are computed from the `ip_adapter_image` input argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between `PIL.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
                The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
                to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set
                the corresponding scale as a list.
            guess_mode (`bool`, *optional*, defaults to `False`):
                The ControlNet encoder tries to recognize the content of the input image even if you remove all
                prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended.
            control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
                The percentage of total steps at which the ControlNet starts applying.
            control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
                The percentage of total steps at which the ControlNet stops applying.
            clip_skip (`int`, *optional*):
                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                the output of the pre-final layer will be used for computing the prompt embeddings.
            callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
                A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
                each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
                DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
                list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
            callback_on_step_end_tensor_inputs (`List`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                `._callback_tensor_inputs` attribute of your pipeline class.
            pag_scale (`float`, *optional*, defaults to 3.0):
                The scale factor for the perturbed attention guidance. If it is set to 0.0, the perturbed attention
                guidance will not be used.
            pag_adaptive_scale (`float`, *optional*, defaults to 0.0):
                The adaptive scale factor for the perturbed attention guidance. If it is set to 0.0, `pag_scale` is
                used.

        Examples:

        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
                otherwise a `tuple` is returned where the first element is a list with the generated images and the
                second element is a list of `bool`s indicating whether the corresponding generated image contains
                "not-safe-for-work" (nsfw) content.
        """

        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

        # Unwrap a torch.compile wrapper so the isinstance checks below see the
        # real ControlNet module.
        controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet

        # align format for control guidance: broadcast scalars so that both
        # lists end up with one entry per ControlNet
        if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
            control_guidance_start = len(control_guidance_end) * [control_guidance_start]
        elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
            control_guidance_end = len(control_guidance_start) * [control_guidance_end]
        elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
            mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
            control_guidance_start, control_guidance_end = (
                mult * [control_guidance_start],
                mult * [control_guidance_end],
            )

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            image,
            negative_prompt,
            prompt_embeds,
            negative_prompt_embeds,
            ip_adapter_image,
            ip_adapter_image_embeds,
            controlnet_conditioning_scale,
            control_guidance_start,
            control_guidance_end,
            callback_on_step_end_tensor_inputs,
        )

        self._guidance_scale = guidance_scale
        self._clip_skip = clip_skip
        self._cross_attention_kwargs = cross_attention_kwargs
        self._pag_scale = pag_scale
        self._pag_adaptive_scale = pag_adaptive_scale

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
            controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)

        global_pool_conditions = (
            controlnet.config.global_pool_conditions
            if isinstance(controlnet, ControlNetModel)
            else controlnet.nets[0].config.global_pool_conditions
        )
        guess_mode = guess_mode or global_pool_conditions

        # 3. Encode input prompt
        text_encoder_lora_scale = (
            self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
        )
        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            self.do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            lora_scale=text_encoder_lora_scale,
            clip_skip=self.clip_skip,
        )
        # For classifier free guidance, we need to do two forward passes.
        # Here we concatenate the unconditional and text embeddings into a single batch
        # to avoid doing two forward passes
        # With PAG enabled the embeddings are arranged for the extra perturbed
        # branch as well (see `_prepare_perturbed_attention_guidance`).
        if self.do_perturbed_attention_guidance:
            prompt_embeds = self._prepare_perturbed_attention_guidance(
                prompt_embeds, negative_prompt_embeds, self.do_classifier_free_guidance
            )
        elif self.do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

        # 4. Prepare image
        if isinstance(controlnet, ControlNetModel):
            image = self.prepare_image(
                image=image,
                width=width,
                height=height,
                batch_size=batch_size * num_images_per_prompt,
                num_images_per_prompt=num_images_per_prompt,
                device=device,
                dtype=controlnet.dtype,
                do_classifier_free_guidance=self.do_classifier_free_guidance,
                guess_mode=guess_mode,
            )
            height, width = image.shape[-2:]
        elif isinstance(controlnet, MultiControlNetModel):
            images = []

            # Nested lists as ControlNet condition
            if isinstance(image[0], list):
                # Transpose the nested image list
                image = [list(t) for t in zip(*image)]

            for image_ in image:
                image_ = self.prepare_image(
                    image=image_,
                    width=width,
                    height=height,
                    batch_size=batch_size * num_images_per_prompt,
                    num_images_per_prompt=num_images_per_prompt,
                    device=device,
                    dtype=controlnet.dtype,
                    do_classifier_free_guidance=self.do_classifier_free_guidance,
                    guess_mode=guess_mode,
                )

                images.append(image_)

            image = images
            height, width = image[0].shape[-2:]
        else:
            assert False

        # 5. Prepare timesteps
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler, num_inference_steps, device, timesteps, sigmas
        )
        self._num_timesteps = len(timesteps)

        # 6. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 6.5 Optionally get Guidance Scale Embedding
        # (only for guidance-distilled UNets with `time_cond_proj_dim` set)
        timestep_cond = None
        if self.unet.config.time_cond_proj_dim is not None:
            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
            timestep_cond = self.get_guidance_scale_embedding(
                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
            ).to(device=device, dtype=latents.dtype)

        # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7.1 Add image embeds for IP-Adapter
        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
            ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds(
                ip_adapter_image,
                ip_adapter_image_embeds,
                device,
                batch_size * num_images_per_prompt,
                self.do_classifier_free_guidance,
            )
            # Re-arrange each IP-Adapter embedding to match the prompt-embeds
            # batch layout (negative / conditional / PAG-perturbed branches).
            for i, image_embeds in enumerate(ip_adapter_image_embeds):
                negative_image_embeds = None
                if self.do_classifier_free_guidance:
                    negative_image_embeds, image_embeds = image_embeds.chunk(2)

                if self.do_perturbed_attention_guidance:
                    image_embeds = self._prepare_perturbed_attention_guidance(
                        image_embeds, negative_image_embeds, self.do_classifier_free_guidance
                    )
                elif self.do_classifier_free_guidance:
                    image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0)
                image_embeds = image_embeds.to(device)
                ip_adapter_image_embeds[i] = image_embeds

        added_cond_kwargs = (
            {"image_embeds": ip_adapter_image_embeds}
            if ip_adapter_image is not None or ip_adapter_image_embeds is not None
            else None
        )

        controlnet_prompt_embeds = prompt_embeds

        # 7.2 Create tensor stating which controlnets to keep
        # keeps[j] is 1.0 while step i lies inside [start_j, end_j), else 0.0
        controlnet_keep = []
        for i in range(len(timesteps)):
            keeps = [
                1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
                for s, e in zip(control_guidance_start, control_guidance_end)
            ]
            controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)

        # Re-duplicate the conditioning image(s): drop the CFG copy added by
        # `prepare_image`, then expand to the prompt-embeds batch layout
        # (which may include the extra PAG branch).
        images = image if isinstance(image, list) else [image]
        for i, single_image in enumerate(images):
            if self.do_classifier_free_guidance:
                single_image = single_image.chunk(2)[0]

            if self.do_perturbed_attention_guidance:
                single_image = self._prepare_perturbed_attention_guidance(
                    single_image, single_image, self.do_classifier_free_guidance
                )
            elif self.do_classifier_free_guidance:
                single_image = torch.cat([single_image] * 2)
            single_image = single_image.to(device)
            images[i] = single_image

        image = images if isinstance(image, list) else images[0]

        # 8. Denoising loop
        if self.do_perturbed_attention_guidance:
            # Swap in PAG attention processors; restored after the loop.
            original_attn_proc = self.unet.attn_processors
            self._set_pag_attn_processor(
                pag_applied_layers=self.pag_applied_layers,
                do_classifier_free_guidance=self.do_classifier_free_guidance,
            )
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        is_unet_compiled = is_compiled_module(self.unet)
        is_controlnet_compiled = is_compiled_module(self.controlnet)
        is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # Relevant thread:
                # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
                if (
                    torch.cuda.is_available()
                    and (is_unet_compiled and is_controlnet_compiled)
                    and is_torch_higher_equal_2_1
                ):
                    torch._inductor.cudagraph_mark_step_begin()
                # expand the latents if we are doing classifier free guidance
                # (the ratio covers both the CFG and PAG branches)
                latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0]))
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # controlnet(s) inference
                control_model_input = latent_model_input

                if isinstance(controlnet_keep[i], list):
                    cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
                else:
                    controlnet_cond_scale = controlnet_conditioning_scale
                    if isinstance(controlnet_cond_scale, list):
                        controlnet_cond_scale = controlnet_cond_scale[0]
                    cond_scale = controlnet_cond_scale * controlnet_keep[i]

                down_block_res_samples, mid_block_res_sample = self.controlnet(
                    control_model_input,
                    t,
                    encoder_hidden_states=controlnet_prompt_embeds,
                    controlnet_cond=image,
                    conditioning_scale=cond_scale,
                    guess_mode=guess_mode,
                    return_dict=False,
                )

                if guess_mode and self.do_classifier_free_guidance:
                    # Inferred ControlNet only for the conditional batch.
                    # To apply the output of ControlNet to both the unconditional and conditional batches,
                    # add 0 to the unconditional batch to keep it unchanged.
                    down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
                    mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])

                # predict the noise residual
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    timestep_cond=timestep_cond,
                    cross_attention_kwargs=self.cross_attention_kwargs,
                    down_block_additional_residuals=down_block_res_samples,
                    mid_block_additional_residual=mid_block_res_sample,
                    added_cond_kwargs=added_cond_kwargs,
                    return_dict=False,
                )[0]

                # perform guidance
                if self.do_perturbed_attention_guidance:
                    noise_pred = self._apply_perturbed_attention_guidance(
                        noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t
                    )
                elif self.do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

                if XLA_AVAILABLE:
                    xm.mark_step()

        # If we do sequential model offloading, let's offload unet and controlnet
        # manually for max memory savings
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.unet.to("cpu")
            self.controlnet.to("cpu")
            empty_device_cache()

        if not output_type == "latent":
            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
                0
            ]
            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
        else:
            image = latents
            has_nsfw_concept = None

        if has_nsfw_concept is None:
            do_denormalize = [True] * image.shape[0]
        else:
            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

        # Offload all models
        self.maybe_free_model_hooks()

        # Restore the original attention processors replaced for PAG.
        if self.do_perturbed_attention_guidance:
            self.unet.set_attn_processor(original_attn_proc)

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py ADDED
@@ -0,0 +1,1554 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # This model implementation is heavily inspired by https://github.com/haofanwang/ControlNet-for-Diffusers/
16
+
17
+ import inspect
18
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
19
+
20
+ import numpy as np
21
+ import PIL.Image
22
+ import torch
23
+ import torch.nn.functional as F
24
+ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
25
+
26
+ from ...callbacks import MultiPipelineCallbacks, PipelineCallback
27
+ from ...image_processor import PipelineImageInput, VaeImageProcessor
28
+ from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
29
+ from ...models import AutoencoderKL, ControlNetModel, ImageProjection, MultiControlNetModel, UNet2DConditionModel
30
+ from ...models.lora import adjust_lora_scale_text_encoder
31
+ from ...schedulers import KarrasDiffusionSchedulers
32
+ from ...utils import (
33
+ USE_PEFT_BACKEND,
34
+ is_torch_xla_available,
35
+ logging,
36
+ replace_example_docstring,
37
+ scale_lora_layers,
38
+ unscale_lora_layers,
39
+ )
40
+ from ...utils.torch_utils import empty_device_cache, is_compiled_module, randn_tensor
41
+ from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
42
+ from ..stable_diffusion import StableDiffusionPipelineOutput
43
+ from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
44
+ from .pag_utils import PAGMixin
45
+
46
+
47
# TPU support is optional: guard the torch_xla import so the pipeline still
# imports cleanly on installations without XLA.
if is_torch_xla_available():
    import torch_xla.core.xla_model as xm  # used to flush pending XLA ops during denoising

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False

# Module-level logger following the diffusers convention.
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
55
+
56
+
57
+ EXAMPLE_DOC_STRING = """
58
+ Examples:
59
+ ```py
60
+ >>> # !pip install transformers accelerate
61
+ >>> import cv2
62
+ >>> from diffusers import AutoPipelineForInpainting, ControlNetModel, DDIMScheduler
63
+ >>> from diffusers.utils import load_image
64
+ >>> import numpy as np
65
+ >>> from PIL import Image
66
+ >>> import torch
67
+
68
+ >>> init_image = load_image(
69
+ ... "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_inpaint/boy.png"
70
+ ... )
71
+ >>> init_image = init_image.resize((512, 512))
72
+
73
+ >>> generator = torch.Generator(device="cpu").manual_seed(1)
74
+
75
+ >>> mask_image = load_image(
76
+ ... "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_inpaint/boy_mask.png"
77
+ ... )
78
+ >>> mask_image = mask_image.resize((512, 512))
79
+
80
+
81
+ >>> def make_canny_condition(image):
82
+ ... image = np.array(image)
83
+ ... image = cv2.Canny(image, 100, 200)
84
+ ... image = image[:, :, None]
85
+ ... image = np.concatenate([image, image, image], axis=2)
86
+ ... image = Image.fromarray(image)
87
+ ... return image
88
+
89
+
90
+ >>> control_image = make_canny_condition(init_image)
91
+
92
+ >>> controlnet = ControlNetModel.from_pretrained(
93
+ ... "lllyasviel/control_v11p_sd15_inpaint", torch_dtype=torch.float16
94
+ ... )
95
+ >>> pipe = AutoPipelineForInpainting.from_pretrained(
96
+ ... "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16, enable_pag=True
97
+ ... )
98
+
99
+ >>> pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
100
+ >>> pipe.enable_model_cpu_offload()
101
+
102
+ >>> # generate image
103
+ >>> image = pipe(
104
+ ... "a handsome man with ray-ban sunglasses",
105
+ ... num_inference_steps=20,
106
+ ... generator=generator,
107
+ ... eta=1.0,
108
+ ... image=init_image,
109
+ ... mask_image=mask_image,
110
+ ... control_image=control_image,
111
+ ... pag_scale=0.3,
112
+ ... ).images[0]
113
+ ```
114
+ """
115
+
116
+
117
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
118
def retrieve_latents(
    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
    """Pull a latent tensor out of a VAE encoder output.

    Handles outputs exposing a ``latent_dist`` distribution (sampled with
    ``generator`` or collapsed via ``mode()`` when ``sample_mode="argmax"``)
    as well as outputs that carry precomputed ``latents`` directly.
    """
    has_dist = hasattr(encoder_output, "latent_dist")
    if has_dist and sample_mode == "sample":
        return encoder_output.latent_dist.sample(generator)
    if has_dist and sample_mode == "argmax":
        return encoder_output.latent_dist.mode()
    if hasattr(encoder_output, "latents"):
        return encoder_output.latents
    raise AttributeError("Could not access latents of provided encoder_output")
129
+
130
+
131
+ class StableDiffusionControlNetPAGInpaintPipeline(
132
+ DiffusionPipeline,
133
+ StableDiffusionMixin,
134
+ TextualInversionLoaderMixin,
135
+ StableDiffusionLoraLoaderMixin,
136
+ IPAdapterMixin,
137
+ FromSingleFileMixin,
138
+ PAGMixin,
139
+ ):
140
+ r"""
141
+ Pipeline for image inpainting using Stable Diffusion with ControlNet guidance.
142
+
143
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
144
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
145
+
146
+ The pipeline also inherits the following loading methods:
147
+ - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
148
+ - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
149
+ - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
150
+ - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
151
+ - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
152
+
153
+ <Tip>
154
+
155
+ This pipeline can be used with checkpoints that have been specifically fine-tuned for inpainting
156
+ ([runwayml/stable-diffusion-inpainting](https://huggingface.co/runwayml/stable-diffusion-inpainting)) as well as
157
+ default text-to-image Stable Diffusion checkpoints
158
+ ([runwayml/stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5)). Default text-to-image
159
+ Stable Diffusion checkpoints might be preferable for ControlNets that have been fine-tuned on those, such as
160
+ [lllyasviel/control_v11p_sd15_inpaint](https://huggingface.co/lllyasviel/control_v11p_sd15_inpaint).
161
+
162
+ </Tip>
163
+
164
+ Args:
165
+ vae ([`AutoencoderKL`]):
166
+ Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
167
+ text_encoder ([`~transformers.CLIPTextModel`]):
168
+ Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
169
+ tokenizer ([`~transformers.CLIPTokenizer`]):
170
+ A `CLIPTokenizer` to tokenize text.
171
+ unet ([`UNet2DConditionModel`]):
172
+ A `UNet2DConditionModel` to denoise the encoded image latents.
173
+ controlnet ([`ControlNetModel`] or `List[ControlNetModel]`):
174
+ Provides additional conditioning to the `unet` during the denoising process. If you set multiple
175
+ ControlNets as a list, the outputs from each ControlNet are added together to create one combined
176
+ additional conditioning.
177
+ scheduler ([`SchedulerMixin`]):
178
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
179
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
180
+ safety_checker ([`StableDiffusionSafetyChecker`]):
181
+ Classification module that estimates whether generated images could be considered offensive or harmful.
182
+ Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
183
+ about a model's potential harms.
184
+ feature_extractor ([`~transformers.CLIPImageProcessor`]):
185
+ A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
186
+ """
187
+
188
+ model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
189
+ _optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
190
+ _exclude_from_cpu_offload = ["safety_checker"]
191
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
192
+
193
+ def __init__(
194
+ self,
195
+ vae: AutoencoderKL,
196
+ text_encoder: CLIPTextModel,
197
+ tokenizer: CLIPTokenizer,
198
+ unet: UNet2DConditionModel,
199
+ controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
200
+ scheduler: KarrasDiffusionSchedulers,
201
+ safety_checker: StableDiffusionSafetyChecker,
202
+ feature_extractor: CLIPImageProcessor,
203
+ image_encoder: CLIPVisionModelWithProjection = None,
204
+ requires_safety_checker: bool = True,
205
+ pag_applied_layers: Union[str, List[str]] = "mid",
206
+ ):
207
+ super().__init__()
208
+
209
+ if safety_checker is None and requires_safety_checker:
210
+ logger.warning(
211
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
212
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
213
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
214
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
215
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
216
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
217
+ )
218
+
219
+ if safety_checker is not None and feature_extractor is None:
220
+ raise ValueError(
221
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
222
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
223
+ )
224
+
225
+ if isinstance(controlnet, (list, tuple)):
226
+ controlnet = MultiControlNetModel(controlnet)
227
+
228
+ self.register_modules(
229
+ vae=vae,
230
+ text_encoder=text_encoder,
231
+ tokenizer=tokenizer,
232
+ unet=unet,
233
+ controlnet=controlnet,
234
+ scheduler=scheduler,
235
+ safety_checker=safety_checker,
236
+ feature_extractor=feature_extractor,
237
+ image_encoder=image_encoder,
238
+ )
239
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
240
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
241
+ self.mask_processor = VaeImageProcessor(
242
+ vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
243
+ )
244
+ self.control_image_processor = VaeImageProcessor(
245
+ vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
246
+ )
247
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
248
+ self.set_pag_applied_layers(pag_applied_layers)
249
+
250
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
    def encode_prompt(
        self,
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt=None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        lora_scale: Optional[float] = None,
        clip_skip: Optional[int] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            device: (`torch.device`):
                torch device
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            lora_scale (`float`, *optional*):
                A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
            clip_skip (`int`, *optional*):
                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                the output of the pre-final layer will be used for computing the prompt embeddings.
        """
        # set lora scale so that monkey patched LoRA
        # function of text encoder can correctly access it
        if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):
            self._lora_scale = lora_scale

            # dynamically adjust the LoRA scale
            if not USE_PEFT_BACKEND:
                adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
            else:
                scale_lora_layers(self.text_encoder, lora_scale)

        # Infer the batch size from whichever prompt input was supplied.
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            # Re-tokenize without truncation so we can warn about any dropped text.
            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
                removed_text = self.tokenizer.batch_decode(
                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
                )
                logger.warning(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
                )

            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = text_inputs.attention_mask.to(device)
            else:
                attention_mask = None

            if clip_skip is None:
                prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
                prompt_embeds = prompt_embeds[0]
            else:
                prompt_embeds = self.text_encoder(
                    text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
                )
                # Access the `hidden_states` first, that contains a tuple of
                # all the hidden states from the encoder layers. Then index into
                # the tuple to access the hidden states from the desired layer.
                prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
                # We also need to apply the final LayerNorm here to not mess with the
                # representations. The `last_hidden_states` that we typically use for
                # obtaining the final prompt representations passes through the LayerNorm
                # layer.
                prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)

        # Pick a target dtype: prefer the text encoder's, then the UNet's,
        # falling back to whatever the provided embeddings already use.
        if self.text_encoder is not None:
            prompt_embeds_dtype = self.text_encoder.dtype
        elif self.unet is not None:
            prompt_embeds_dtype = self.unet.dtype
        else:
            prompt_embeds_dtype = prompt_embeds.dtype

        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

        bs_embed, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

            # Pad the negative prompt to the same sequence length as the positive one.
            max_length = prompt_embeds.shape[1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )

            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = uncond_input.attention_mask.to(device)
            else:
                attention_mask = None

            negative_prompt_embeds = self.text_encoder(
                uncond_input.input_ids.to(device),
                attention_mask=attention_mask,
            )
            negative_prompt_embeds = negative_prompt_embeds[0]

        if do_classifier_free_guidance:
            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = negative_prompt_embeds.shape[1]

            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

        if self.text_encoder is not None:
            if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
                # Retrieve the original scale by scaling back the LoRA layers
                unscale_lora_layers(self.text_encoder, lora_scale)

        return prompt_embeds, negative_prompt_embeds
432
+
433
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
434
+ def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
435
+ dtype = next(self.image_encoder.parameters()).dtype
436
+
437
+ if not isinstance(image, torch.Tensor):
438
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
439
+
440
+ image = image.to(device=device, dtype=dtype)
441
+ if output_hidden_states:
442
+ image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
443
+ image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
444
+ uncond_image_enc_hidden_states = self.image_encoder(
445
+ torch.zeros_like(image), output_hidden_states=True
446
+ ).hidden_states[-2]
447
+ uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
448
+ num_images_per_prompt, dim=0
449
+ )
450
+ return image_enc_hidden_states, uncond_image_enc_hidden_states
451
+ else:
452
+ image_embeds = self.image_encoder(image).image_embeds
453
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
454
+ uncond_image_embeds = torch.zeros_like(image_embeds)
455
+
456
+ return image_embeds, uncond_image_embeds
457
+
458
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
    def prepare_ip_adapter_image_embeds(
        self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
    ):
        """Build per-adapter image embeddings, optionally with CFG negatives.

        Either encodes raw ``ip_adapter_image`` inputs (one per loaded IP Adapter)
        or reuses precomputed ``ip_adapter_image_embeds`` (each expected to hold
        negative and positive halves stacked along dim 0 when CFG is on). Each
        result is tiled ``num_images_per_prompt`` times and, under CFG,
        concatenated as [negative, positive] along the batch dimension.
        """
        image_embeds = []
        if do_classifier_free_guidance:
            # Only allocated under CFG; every append below is CFG-guarded too.
            negative_image_embeds = []
        if ip_adapter_image_embeds is None:
            if not isinstance(ip_adapter_image, list):
                ip_adapter_image = [ip_adapter_image]

            if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
                raise ValueError(
                    f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
                )

            for single_ip_adapter_image, image_proj_layer in zip(
                ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
            ):
                # Plain `ImageProjection` layers consume pooled embeds; all other
                # projection types consume hidden states.
                output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
                single_image_embeds, single_negative_image_embeds = self.encode_image(
                    single_ip_adapter_image, device, 1, output_hidden_state
                )

                image_embeds.append(single_image_embeds[None, :])
                if do_classifier_free_guidance:
                    negative_image_embeds.append(single_negative_image_embeds[None, :])
        else:
            for single_image_embeds in ip_adapter_image_embeds:
                if do_classifier_free_guidance:
                    # Precomputed embeds carry [negative, positive] halves; split them.
                    single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
                    negative_image_embeds.append(single_negative_image_embeds)
                image_embeds.append(single_image_embeds)

        ip_adapter_image_embeds = []
        for i, single_image_embeds in enumerate(image_embeds):
            # Tile per generated image, then stack negatives ahead of positives for CFG.
            single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
            if do_classifier_free_guidance:
                single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
                single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)

            single_image_embeds = single_image_embeds.to(device=device)
            ip_adapter_image_embeds.append(single_image_embeds)

        return ip_adapter_image_embeds
503
+
504
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
505
+ def run_safety_checker(self, image, device, dtype):
506
+ if self.safety_checker is None:
507
+ has_nsfw_concept = None
508
+ else:
509
+ if torch.is_tensor(image):
510
+ feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
511
+ else:
512
+ feature_extractor_input = self.image_processor.numpy_to_pil(image)
513
+ safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
514
+ image, has_nsfw_concept = self.safety_checker(
515
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
516
+ )
517
+ return image, has_nsfw_concept
518
+
519
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
520
+ def prepare_extra_step_kwargs(self, generator, eta):
521
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
522
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
523
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
524
+ # and should be between [0, 1]
525
+
526
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
527
+ extra_step_kwargs = {}
528
+ if accepts_eta:
529
+ extra_step_kwargs["eta"] = eta
530
+
531
+ # check if the scheduler accepts generator
532
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
533
+ if accepts_generator:
534
+ extra_step_kwargs["generator"] = generator
535
+ return extra_step_kwargs
536
+
537
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
538
+ def get_timesteps(self, num_inference_steps, strength, device):
539
+ # get the original timestep using init_timestep
540
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
541
+
542
+ t_start = max(num_inference_steps - init_timestep, 0)
543
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
544
+ if hasattr(self.scheduler, "set_begin_index"):
545
+ self.scheduler.set_begin_index(t_start * self.scheduler.order)
546
+
547
+ return timesteps, num_inference_steps - t_start
548
+
549
+ def check_inputs(
550
+ self,
551
+ prompt,
552
+ image,
553
+ mask_image,
554
+ height,
555
+ width,
556
+ output_type,
557
+ negative_prompt=None,
558
+ prompt_embeds=None,
559
+ negative_prompt_embeds=None,
560
+ ip_adapter_image=None,
561
+ ip_adapter_image_embeds=None,
562
+ controlnet_conditioning_scale=1.0,
563
+ control_guidance_start=0.0,
564
+ control_guidance_end=1.0,
565
+ callback_on_step_end_tensor_inputs=None,
566
+ padding_mask_crop=None,
567
+ ):
568
+ if height is not None and height % 8 != 0 or width is not None and width % 8 != 0:
569
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
570
+
571
+ if callback_on_step_end_tensor_inputs is not None and not all(
572
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
573
+ ):
574
+ raise ValueError(
575
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
576
+ )
577
+
578
+ if prompt is not None and prompt_embeds is not None:
579
+ raise ValueError(
580
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
581
+ " only forward one of the two."
582
+ )
583
+ elif prompt is None and prompt_embeds is None:
584
+ raise ValueError(
585
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
586
+ )
587
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
588
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
589
+
590
+ if negative_prompt is not None and negative_prompt_embeds is not None:
591
+ raise ValueError(
592
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
593
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
594
+ )
595
+
596
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
597
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
598
+ raise ValueError(
599
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
600
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
601
+ f" {negative_prompt_embeds.shape}."
602
+ )
603
+
604
+ if padding_mask_crop is not None:
605
+ if not isinstance(image, PIL.Image.Image):
606
+ raise ValueError(
607
+ f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
608
+ )
609
+ if not isinstance(mask_image, PIL.Image.Image):
610
+ raise ValueError(
611
+ f"The mask image should be a PIL image when inpainting mask crop, but is of type"
612
+ f" {type(mask_image)}."
613
+ )
614
+ if output_type != "pil":
615
+ raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.")
616
+
617
+ # `prompt` needs more sophisticated handling when there are multiple
618
+ # conditionings.
619
+ if isinstance(self.controlnet, MultiControlNetModel):
620
+ if isinstance(prompt, list):
621
+ logger.warning(
622
+ f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}"
623
+ " prompts. The conditionings will be fixed across the prompts."
624
+ )
625
+
626
+ # Check `image`
627
+ is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
628
+ self.controlnet, torch._dynamo.eval_frame.OptimizedModule
629
+ )
630
+ if (
631
+ isinstance(self.controlnet, ControlNetModel)
632
+ or is_compiled
633
+ and isinstance(self.controlnet._orig_mod, ControlNetModel)
634
+ ):
635
+ self.check_image(image, prompt, prompt_embeds)
636
+ elif (
637
+ isinstance(self.controlnet, MultiControlNetModel)
638
+ or is_compiled
639
+ and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
640
+ ):
641
+ if not isinstance(image, list):
642
+ raise TypeError("For multiple controlnets: `image` must be type `list`")
643
+
644
+ # When `image` is a nested list:
645
+ # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])
646
+ elif any(isinstance(i, list) for i in image):
647
+ raise ValueError("A single batch of multiple conditionings are supported at the moment.")
648
+ elif len(image) != len(self.controlnet.nets):
649
+ raise ValueError(
650
+ f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets."
651
+ )
652
+
653
+ for image_ in image:
654
+ self.check_image(image_, prompt, prompt_embeds)
655
+ else:
656
+ assert False
657
+
658
+ # Check `controlnet_conditioning_scale`
659
+ if (
660
+ isinstance(self.controlnet, ControlNetModel)
661
+ or is_compiled
662
+ and isinstance(self.controlnet._orig_mod, ControlNetModel)
663
+ ):
664
+ if not isinstance(controlnet_conditioning_scale, float):
665
+ raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
666
+ elif (
667
+ isinstance(self.controlnet, MultiControlNetModel)
668
+ or is_compiled
669
+ and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
670
+ ):
671
+ if isinstance(controlnet_conditioning_scale, list):
672
+ if any(isinstance(i, list) for i in controlnet_conditioning_scale):
673
+ raise ValueError("A single batch of multiple conditionings are supported at the moment.")
674
+ elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(
675
+ self.controlnet.nets
676
+ ):
677
+ raise ValueError(
678
+ "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
679
+ " the same length as the number of controlnets"
680
+ )
681
+ else:
682
+ assert False
683
+
684
+ if len(control_guidance_start) != len(control_guidance_end):
685
+ raise ValueError(
686
+ f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
687
+ )
688
+
689
+ if isinstance(self.controlnet, MultiControlNetModel):
690
+ if len(control_guidance_start) != len(self.controlnet.nets):
691
+ raise ValueError(
692
+ f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}."
693
+ )
694
+
695
+ for start, end in zip(control_guidance_start, control_guidance_end):
696
+ if start >= end:
697
+ raise ValueError(
698
+ f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}."
699
+ )
700
+ if start < 0.0:
701
+ raise ValueError(f"control guidance start: {start} can't be smaller than 0.")
702
+ if end > 1.0:
703
+ raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")
704
+
705
+ if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
706
+ raise ValueError(
707
+ "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
708
+ )
709
+
710
+ if ip_adapter_image_embeds is not None:
711
+ if not isinstance(ip_adapter_image_embeds, list):
712
+ raise ValueError(
713
+ f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
714
+ )
715
+ elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
716
+ raise ValueError(
717
+ f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
718
+ )
719
+
720
+ # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image
721
+ def check_image(self, image, prompt, prompt_embeds):
722
+ image_is_pil = isinstance(image, PIL.Image.Image)
723
+ image_is_tensor = isinstance(image, torch.Tensor)
724
+ image_is_np = isinstance(image, np.ndarray)
725
+ image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
726
+ image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)
727
+ image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray)
728
+
729
+ if (
730
+ not image_is_pil
731
+ and not image_is_tensor
732
+ and not image_is_np
733
+ and not image_is_pil_list
734
+ and not image_is_tensor_list
735
+ and not image_is_np_list
736
+ ):
737
+ raise TypeError(
738
+ f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}"
739
+ )
740
+
741
+ if image_is_pil:
742
+ image_batch_size = 1
743
+ else:
744
+ image_batch_size = len(image)
745
+
746
+ if prompt is not None and isinstance(prompt, str):
747
+ prompt_batch_size = 1
748
+ elif prompt is not None and isinstance(prompt, list):
749
+ prompt_batch_size = len(prompt)
750
+ elif prompt_embeds is not None:
751
+ prompt_batch_size = prompt_embeds.shape[0]
752
+
753
+ if image_batch_size != 1 and image_batch_size != prompt_batch_size:
754
+ raise ValueError(
755
+ f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
756
+ )
757
+
758
+ # Copied from diffusers.pipelines.controlnet.pipeline_controlnet_inpaint.StableDiffusionControlNetInpaintPipeline.prepare_control_image
759
+ def prepare_control_image(
760
+ self,
761
+ image,
762
+ width,
763
+ height,
764
+ batch_size,
765
+ num_images_per_prompt,
766
+ device,
767
+ dtype,
768
+ crops_coords,
769
+ resize_mode,
770
+ do_classifier_free_guidance=False,
771
+ guess_mode=False,
772
+ ):
773
+ image = self.control_image_processor.preprocess(
774
+ image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode
775
+ ).to(dtype=torch.float32)
776
+ image_batch_size = image.shape[0]
777
+
778
+ if image_batch_size == 1:
779
+ repeat_by = batch_size
780
+ else:
781
+ # image batch size is the same as prompt batch size
782
+ repeat_by = num_images_per_prompt
783
+
784
+ image = image.repeat_interleave(repeat_by, dim=0)
785
+
786
+ image = image.to(device=device, dtype=dtype)
787
+
788
+ if do_classifier_free_guidance and not guess_mode:
789
+ image = torch.cat([image] * 2)
790
+
791
+ return image
792
+
793
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline.prepare_latents
794
+ def prepare_latents(
795
+ self,
796
+ batch_size,
797
+ num_channels_latents,
798
+ height,
799
+ width,
800
+ dtype,
801
+ device,
802
+ generator,
803
+ latents=None,
804
+ image=None,
805
+ timestep=None,
806
+ is_strength_max=True,
807
+ return_noise=False,
808
+ return_image_latents=False,
809
+ ):
810
+ shape = (
811
+ batch_size,
812
+ num_channels_latents,
813
+ int(height) // self.vae_scale_factor,
814
+ int(width) // self.vae_scale_factor,
815
+ )
816
+ if isinstance(generator, list) and len(generator) != batch_size:
817
+ raise ValueError(
818
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
819
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
820
+ )
821
+
822
+ if (image is None or timestep is None) and not is_strength_max:
823
+ raise ValueError(
824
+ "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise."
825
+ "However, either the image or the noise timestep has not been provided."
826
+ )
827
+
828
+ if return_image_latents or (latents is None and not is_strength_max):
829
+ image = image.to(device=device, dtype=dtype)
830
+
831
+ if image.shape[1] == 4:
832
+ image_latents = image
833
+ else:
834
+ image_latents = self._encode_vae_image(image=image, generator=generator)
835
+ image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
836
+
837
+ if latents is None:
838
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
839
+ # if strength is 1. then initialise the latents to noise, else initial to image + noise
840
+ latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep)
841
+ # if pure noise then scale the initial latents by the Scheduler's init sigma
842
+ latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents
843
+ else:
844
+ noise = latents.to(device)
845
+ latents = noise * self.scheduler.init_noise_sigma
846
+
847
+ outputs = (latents,)
848
+
849
+ if return_noise:
850
+ outputs += (noise,)
851
+
852
+ if return_image_latents:
853
+ outputs += (image_latents,)
854
+
855
+ return outputs
856
+
857
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline.prepare_mask_latents
858
+ def prepare_mask_latents(
859
+ self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
860
+ ):
861
+ # resize the mask to latents shape as we concatenate the mask to the latents
862
+ # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
863
+ # and half precision
864
+ mask = torch.nn.functional.interpolate(
865
+ mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor)
866
+ )
867
+ mask = mask.to(device=device, dtype=dtype)
868
+
869
+ masked_image = masked_image.to(device=device, dtype=dtype)
870
+
871
+ if masked_image.shape[1] == 4:
872
+ masked_image_latents = masked_image
873
+ else:
874
+ masked_image_latents = self._encode_vae_image(masked_image, generator=generator)
875
+
876
+ # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
877
+ if mask.shape[0] < batch_size:
878
+ if not batch_size % mask.shape[0] == 0:
879
+ raise ValueError(
880
+ "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
881
+ f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
882
+ " of masks that you pass is divisible by the total requested batch size."
883
+ )
884
+ mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
885
+ if masked_image_latents.shape[0] < batch_size:
886
+ if not batch_size % masked_image_latents.shape[0] == 0:
887
+ raise ValueError(
888
+ "The passed images and the required batch size don't match. Images are supposed to be duplicated"
889
+ f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
890
+ " Make sure the number of images that you pass is divisible by the total requested batch size."
891
+ )
892
+ masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1)
893
+
894
+ mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
895
+ masked_image_latents = (
896
+ torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
897
+ )
898
+
899
+ # aligning device to prevent device errors when concating it with the latent model input
900
+ masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
901
+ return mask, masked_image_latents
902
+
903
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline._encode_vae_image
904
+ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
905
+ if isinstance(generator, list):
906
+ image_latents = [
907
+ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
908
+ for i in range(image.shape[0])
909
+ ]
910
+ image_latents = torch.cat(image_latents, dim=0)
911
+ else:
912
+ image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
913
+
914
+ image_latents = self.vae.config.scaling_factor * image_latents
915
+
916
+ return image_latents
917
+
918
+ # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
919
+ def get_guidance_scale_embedding(
920
+ self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32
921
+ ) -> torch.Tensor:
922
+ """
923
+ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
924
+
925
+ Args:
926
+ w (`torch.Tensor`):
927
+ Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings.
928
+ embedding_dim (`int`, *optional*, defaults to 512):
929
+ Dimension of the embeddings to generate.
930
+ dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
931
+ Data type of the generated embeddings.
932
+
933
+ Returns:
934
+ `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`.
935
+ """
936
+ assert len(w.shape) == 1
937
+ w = w * 1000.0
938
+
939
+ half_dim = embedding_dim // 2
940
+ emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
941
+ emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
942
+ emb = w.to(dtype)[:, None] * emb[None, :]
943
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
944
+ if embedding_dim % 2 == 1: # zero pad
945
+ emb = torch.nn.functional.pad(emb, (0, 1))
946
+ assert emb.shape == (w.shape[0], embedding_dim)
947
+ return emb
948
+
949
    @property
    def guidance_scale(self):
        # Classifier-free guidance weight; set from the `guidance_scale` argument in `__call__`.
        return self._guidance_scale
952
+
953
    @property
    def clip_skip(self):
        # Number of CLIP layers to skip for prompt embeddings; set from `clip_skip` in `__call__`.
        return self._clip_skip
956
+
957
    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    @property
    def do_classifier_free_guidance(self):
        # CFG is active whenever the stored guidance scale exceeds 1.
        return self._guidance_scale > 1
963
+
964
    @property
    def cross_attention_kwargs(self):
        # Extra kwargs forwarded to the attention processors; set in `__call__`.
        return self._cross_attention_kwargs
967
+
968
    @property
    def num_timesteps(self):
        # Number of denoising timesteps of the current run; set in `__call__` after scheduling.
        return self._num_timesteps
971
+
972
+ @torch.no_grad()
973
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
974
+ def __call__(
975
+ self,
976
+ prompt: Union[str, List[str]] = None,
977
+ image: PipelineImageInput = None,
978
+ mask_image: PipelineImageInput = None,
979
+ control_image: PipelineImageInput = None,
980
+ height: Optional[int] = None,
981
+ width: Optional[int] = None,
982
+ padding_mask_crop: Optional[int] = None,
983
+ strength: float = 1.0,
984
+ num_inference_steps: int = 50,
985
+ guidance_scale: float = 7.5,
986
+ negative_prompt: Optional[Union[str, List[str]]] = None,
987
+ num_images_per_prompt: Optional[int] = 1,
988
+ eta: float = 0.0,
989
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
990
+ latents: Optional[torch.Tensor] = None,
991
+ prompt_embeds: Optional[torch.Tensor] = None,
992
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
993
+ ip_adapter_image: Optional[PipelineImageInput] = None,
994
+ ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
995
+ output_type: Optional[str] = "pil",
996
+ return_dict: bool = True,
997
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
998
+ controlnet_conditioning_scale: Union[float, List[float]] = 0.5,
999
+ control_guidance_start: Union[float, List[float]] = 0.0,
1000
+ control_guidance_end: Union[float, List[float]] = 1.0,
1001
+ clip_skip: Optional[int] = None,
1002
+ callback_on_step_end: Optional[
1003
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
1004
+ ] = None,
1005
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
1006
+ pag_scale: float = 3.0,
1007
+ pag_adaptive_scale: float = 0.0,
1008
+ ):
1009
+ r"""
1010
+ The call function to the pipeline for generation.
1011
+
1012
+ Args:
1013
+ prompt (`str` or `List[str]`, *optional*):
1014
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
1015
+ image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`,
1016
+ `List[PIL.Image.Image]`, or `List[np.ndarray]`):
1017
+ `Image`, NumPy array or tensor representing an image batch to be used as the starting point. For both
1018
+ NumPy array and PyTorch tensor, the expected value range is between `[0, 1]`. If it's a tensor or a
1019
+ list or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a NumPy array or
1020
+ a list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`. It can also accept image
1021
+ latents as `image`, but if passing latents directly it is not encoded again.
1022
+ mask_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`,
1023
+ `List[PIL.Image.Image]`, or `List[np.ndarray]`):
1024
+ `Image`, NumPy array or tensor representing an image batch to mask `image`. White pixels in the mask
1025
+ are repainted while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a
1026
+ single channel (luminance) before use. If it's a NumPy array or PyTorch tensor, it should contain one
1027
+ color channel (L) instead of 3, so the expected shape for PyTorch tensor would be `(B, 1, H, W)`, `(B,
1028
+ H, W)`, `(1, H, W)`, `(H, W)`. And for NumPy array, it would be for `(B, H, W, 1)`, `(B, H, W)`, `(H,
1029
+ W, 1)`, or `(H, W)`.
1030
+ control_image (`torch.Tensor`, `PIL.Image.Image`, `List[torch.Tensor]`, `List[PIL.Image.Image]`,
1031
+ `List[List[torch.Tensor]]`, or `List[List[PIL.Image.Image]]`):
1032
+ The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
1033
+ specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted
1034
+ as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or
1035
+ width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`,
1036
+ images must be passed as a list such that each element of the list can be correctly batched for input
1037
+ to a single ControlNet.
1038
+ height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
1039
+ The height in pixels of the generated image.
1040
+ width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
1041
+ The width in pixels of the generated image.
1042
+ padding_mask_crop (`int`, *optional*, defaults to `None`):
1043
+ The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to
1044
+ image and mask_image. If `padding_mask_crop` is not `None`, it will first find a rectangular region
1045
+ with the same aspect ration of the image and contains all masked area, and then expand that area based
1046
+ on `padding_mask_crop`. The image and mask_image will then be cropped based on the expanded area before
1047
+ resizing to the original image size for inpainting. This is useful when the masked area is small while
1048
+ the image is large and contain information irrelevant for inpainting, such as background.
1049
+ strength (`float`, *optional*, defaults to 1.0):
1050
+ Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
1051
+ starting point and more noise is added the higher the `strength`. The number of denoising steps depends
1052
+ on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
1053
+ process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
1054
+ essentially ignores `image`.
1055
+ num_inference_steps (`int`, *optional*, defaults to 50):
1056
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
1057
+ expense of slower inference.
1058
+ guidance_scale (`float`, *optional*, defaults to 7.5):
1059
+ A higher guidance scale value encourages the model to generate images closely linked to the text
1060
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
1061
+ negative_prompt (`str` or `List[str]`, *optional*):
1062
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
1063
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
1064
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
1065
+ The number of images to generate per prompt.
1066
+ eta (`float`, *optional*, defaults to 0.0):
1067
+ Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only
1068
+ applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
1069
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
1070
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
1071
+ generation deterministic.
1072
+ latents (`torch.Tensor`, *optional*):
1073
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
1074
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
1075
+ tensor is generated by sampling using the supplied random `generator`.
1076
+ prompt_embeds (`torch.Tensor`, *optional*):
1077
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
1078
+ provided, text embeddings are generated from the `prompt` input argument.
1079
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
1080
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
1081
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
1082
+ ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
1083
+ ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
1084
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
1085
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
1086
+ contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
1087
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
1088
+ output_type (`str`, *optional*, defaults to `"pil"`):
1089
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
1090
+ return_dict (`bool`, *optional*, defaults to `True`):
1091
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
1092
+ plain tuple.
1093
+ cross_attention_kwargs (`dict`, *optional*):
1094
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
1095
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
1096
+ controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 0.5):
1097
+ The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
1098
+ to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set
1099
+ the corresponding scale as a list.
1100
+ control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
1101
+ The percentage of total steps at which the ControlNet starts applying.
1102
+ control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
1103
+ The percentage of total steps at which the ControlNet stops applying.
1104
+ clip_skip (`int`, *optional*):
1105
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
1106
+ the output of the pre-final layer will be used for computing the prompt embeddings.
1107
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
1108
+ A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
1109
+ each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
1110
+ DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
1111
+ list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
1112
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
1113
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
1114
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
1115
+ `._callback_tensor_inputs` attribute of your pipeline class.
1116
+ pag_scale (`float`, *optional*, defaults to 3.0):
1117
+ The scale factor for the perturbed attention guidance. If it is set to 0.0, the perturbed attention
1118
+ guidance will not be used.
1119
+ pag_adaptive_scale (`float`, *optional*, defaults to 0.0):
1120
+ The adaptive scale factor for the perturbed attention guidance. If it is set to 0.0, `pag_scale` is
1121
+ used.
1122
+
1123
+ Examples:
1124
+
1125
+ Returns:
1126
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
1127
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
1128
+ otherwise a `tuple` is returned where the first element is a list with the generated images and the
1129
+ second element is a list of `bool`s indicating whether the corresponding generated image contains
1130
+ "not-safe-for-work" (nsfw) content.
1131
+ """
1132
+
1133
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
1134
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
1135
+
1136
+ controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
1137
+
1138
+ # align format for control guidance
1139
+ if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
1140
+ control_guidance_start = len(control_guidance_end) * [control_guidance_start]
1141
+ elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
1142
+ control_guidance_end = len(control_guidance_start) * [control_guidance_end]
1143
+ elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
1144
+ mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
1145
+ control_guidance_start, control_guidance_end = (
1146
+ mult * [control_guidance_start],
1147
+ mult * [control_guidance_end],
1148
+ )
1149
+
1150
+ # 1. Check inputs. Raise error if not correct
1151
+ self.check_inputs(
1152
+ prompt,
1153
+ control_image,
1154
+ mask_image,
1155
+ height,
1156
+ width,
1157
+ output_type,
1158
+ negative_prompt,
1159
+ prompt_embeds,
1160
+ negative_prompt_embeds,
1161
+ ip_adapter_image,
1162
+ ip_adapter_image_embeds,
1163
+ controlnet_conditioning_scale,
1164
+ control_guidance_start,
1165
+ control_guidance_end,
1166
+ callback_on_step_end_tensor_inputs,
1167
+ padding_mask_crop,
1168
+ )
1169
+
1170
+ self._guidance_scale = guidance_scale
1171
+ self._clip_skip = clip_skip
1172
+ self._cross_attention_kwargs = cross_attention_kwargs
1173
+ self._pag_scale = pag_scale
1174
+ self._pag_adaptive_scale = pag_adaptive_scale
1175
+
1176
+ # 2. Define call parameters
1177
+ if prompt is not None and isinstance(prompt, str):
1178
+ batch_size = 1
1179
+ elif prompt is not None and isinstance(prompt, list):
1180
+ batch_size = len(prompt)
1181
+ else:
1182
+ batch_size = prompt_embeds.shape[0]
1183
+
1184
+ if padding_mask_crop is not None:
1185
+ height, width = self.image_processor.get_default_height_width(image, height, width)
1186
+ crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop)
1187
+ resize_mode = "fill"
1188
+ else:
1189
+ crops_coords = None
1190
+ resize_mode = "default"
1191
+
1192
+ device = self._execution_device
1193
+
1194
+ if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
1195
+ controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
1196
+
1197
+ # 3. Encode input prompt
1198
+ text_encoder_lora_scale = (
1199
+ self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
1200
+ )
1201
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
1202
+ prompt,
1203
+ device,
1204
+ num_images_per_prompt,
1205
+ self.do_classifier_free_guidance,
1206
+ negative_prompt,
1207
+ prompt_embeds=prompt_embeds,
1208
+ negative_prompt_embeds=negative_prompt_embeds,
1209
+ lora_scale=text_encoder_lora_scale,
1210
+ clip_skip=self.clip_skip,
1211
+ )
1212
+ # For classifier free guidance, we need to do two forward passes.
1213
+ # Here we concatenate the unconditional and text embeddings into a single batch
1214
+ # to avoid doing two forward passes
1215
+ if self.do_perturbed_attention_guidance:
1216
+ prompt_embeds = self._prepare_perturbed_attention_guidance(
1217
+ prompt_embeds, negative_prompt_embeds, self.do_classifier_free_guidance
1218
+ )
1219
+ elif self.do_classifier_free_guidance:
1220
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
1221
+
1222
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
1223
+ ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds(
1224
+ ip_adapter_image,
1225
+ ip_adapter_image_embeds,
1226
+ device,
1227
+ batch_size * num_images_per_prompt,
1228
+ self.do_classifier_free_guidance,
1229
+ )
1230
+
1231
+ # 4. Prepare control image
1232
+ if isinstance(controlnet, ControlNetModel):
1233
+ control_image = self.prepare_control_image(
1234
+ image=control_image,
1235
+ width=width,
1236
+ height=height,
1237
+ batch_size=batch_size * num_images_per_prompt,
1238
+ num_images_per_prompt=num_images_per_prompt,
1239
+ device=device,
1240
+ dtype=controlnet.dtype,
1241
+ crops_coords=crops_coords,
1242
+ resize_mode=resize_mode,
1243
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
1244
+ guess_mode=False,
1245
+ )
1246
+ elif isinstance(controlnet, MultiControlNetModel):
1247
+ control_images = []
1248
+
1249
+ for control_image_ in control_image:
1250
+ control_image_ = self.prepare_control_image(
1251
+ image=control_image_,
1252
+ width=width,
1253
+ height=height,
1254
+ batch_size=batch_size * num_images_per_prompt,
1255
+ num_images_per_prompt=num_images_per_prompt,
1256
+ device=device,
1257
+ dtype=controlnet.dtype,
1258
+ crops_coords=crops_coords,
1259
+ resize_mode=resize_mode,
1260
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
1261
+ guess_mode=False,
1262
+ )
1263
+
1264
+ control_images.append(control_image_)
1265
+
1266
+ control_image = control_images
1267
+ else:
1268
+ assert False
1269
+
1270
+ # 4.1 Preprocess mask and image - resizes image and mask w.r.t height and width
1271
+ original_image = image
1272
+ init_image = self.image_processor.preprocess(
1273
+ image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode
1274
+ )
1275
+ init_image = init_image.to(dtype=torch.float32)
1276
+
1277
+ mask = self.mask_processor.preprocess(
1278
+ mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
1279
+ )
1280
+
1281
+ masked_image = init_image * (mask < 0.5)
1282
+ _, _, height, width = init_image.shape
1283
+
1284
+ # 5. Prepare timesteps
1285
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
1286
+ timesteps, num_inference_steps = self.get_timesteps(
1287
+ num_inference_steps=num_inference_steps, strength=strength, device=device
1288
+ )
1289
+ # at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
1290
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
1291
+ # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
1292
+ is_strength_max = strength == 1.0
1293
+ self._num_timesteps = len(timesteps)
1294
+
1295
+ # 6. Prepare latent variables
1296
+ num_channels_latents = self.vae.config.latent_channels
1297
+ num_channels_unet = self.unet.config.in_channels
1298
+ return_image_latents = num_channels_unet == 4
1299
+ latents_outputs = self.prepare_latents(
1300
+ batch_size * num_images_per_prompt,
1301
+ num_channels_latents,
1302
+ height,
1303
+ width,
1304
+ prompt_embeds.dtype,
1305
+ device,
1306
+ generator,
1307
+ latents,
1308
+ image=init_image,
1309
+ timestep=latent_timestep,
1310
+ is_strength_max=is_strength_max,
1311
+ return_noise=True,
1312
+ return_image_latents=return_image_latents,
1313
+ )
1314
+
1315
+ if return_image_latents:
1316
+ latents, noise, image_latents = latents_outputs
1317
+ else:
1318
+ latents, noise = latents_outputs
1319
+
1320
+ # 7. Prepare mask latent variables
1321
+ mask, masked_image_latents = self.prepare_mask_latents(
1322
+ mask,
1323
+ masked_image,
1324
+ batch_size * num_images_per_prompt,
1325
+ height,
1326
+ width,
1327
+ prompt_embeds.dtype,
1328
+ device,
1329
+ generator,
1330
+ self.do_classifier_free_guidance,
1331
+ )
1332
+
1333
+ # 7.1 Check that sizes of mask, masked image and latents match
1334
+ if num_channels_unet == 9:
1335
+ # default case for runwayml/stable-diffusion-inpainting
1336
+ num_channels_mask = mask.shape[1]
1337
+ num_channels_masked_image = masked_image_latents.shape[1]
1338
+ if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:
1339
+ raise ValueError(
1340
+ f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
1341
+ f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
1342
+ f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
1343
+ f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
1344
+ " `pipeline.unet` or your `mask_image` or `image` input."
1345
+ )
1346
+ elif num_channels_unet != 4:
1347
+ raise ValueError(
1348
+ f"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}."
1349
+ )
1350
+
1351
+ # 7.2 Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
1352
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
1353
+
1354
+ # 7.3 Prepare embeddings
1355
+ # ip-adapter
1356
+ if ip_adapter_image_embeds is not None:
1357
+ for i, image_embeds in enumerate(ip_adapter_image_embeds):
1358
+ negative_image_embeds = None
1359
+ if self.do_classifier_free_guidance:
1360
+ negative_image_embeds, image_embeds = image_embeds.chunk(2)
1361
+
1362
+ if self.do_perturbed_attention_guidance:
1363
+ image_embeds = self._prepare_perturbed_attention_guidance(
1364
+ image_embeds, negative_image_embeds, self.do_classifier_free_guidance
1365
+ )
1366
+ elif self.do_classifier_free_guidance:
1367
+ image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0)
1368
+ image_embeds = image_embeds.to(device)
1369
+ ip_adapter_image_embeds[i] = image_embeds
1370
+
1371
+ added_cond_kwargs = (
1372
+ {"image_embeds": ip_adapter_image_embeds}
1373
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None
1374
+ else None
1375
+ )
1376
+
1377
+ # control image
1378
+ control_images = control_image if isinstance(control_image, list) else [control_image]
1379
+ for i, single_control_image in enumerate(control_images):
1380
+ if self.do_classifier_free_guidance:
1381
+ single_control_image = single_control_image.chunk(2)[0]
1382
+
1383
+ if self.do_perturbed_attention_guidance:
1384
+ single_control_image = self._prepare_perturbed_attention_guidance(
1385
+ single_control_image, single_control_image, self.do_classifier_free_guidance
1386
+ )
1387
+ elif self.do_classifier_free_guidance:
1388
+ single_control_image = torch.cat([single_control_image] * 2)
1389
+ single_control_image = single_control_image.to(device)
1390
+ control_images[i] = single_control_image
1391
+
1392
+ control_image = control_images if isinstance(control_image, list) else control_images[0]
1393
+ controlnet_prompt_embeds = prompt_embeds
1394
+
1395
+ # 7.4 Create tensor stating which controlnets to keep
1396
+ controlnet_keep = []
1397
+ for i in range(len(timesteps)):
1398
+ keeps = [
1399
+ 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
1400
+ for s, e in zip(control_guidance_start, control_guidance_end)
1401
+ ]
1402
+ controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)
1403
+
1404
+ # 7.5 Optionally get Guidance Scale Embedding
1405
+ timestep_cond = None
1406
+ if self.unet.config.time_cond_proj_dim is not None:
1407
+ guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
1408
+ timestep_cond = self.get_guidance_scale_embedding(
1409
+ guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
1410
+ ).to(device=device, dtype=latents.dtype)
1411
+
1412
+ # 8. Denoising loop
1413
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
1414
+ if self.do_perturbed_attention_guidance:
1415
+ original_attn_proc = self.unet.attn_processors
1416
+ self._set_pag_attn_processor(
1417
+ pag_applied_layers=self.pag_applied_layers,
1418
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
1419
+ )
1420
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
1421
+ for i, t in enumerate(timesteps):
1422
+ # expand the latents if we are doing classifier free guidance
1423
+ latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0]))
1424
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
1425
+
1426
+ # controlnet(s) inference
1427
+ control_model_input = latent_model_input
1428
+
1429
+ if isinstance(controlnet_keep[i], list):
1430
+ cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
1431
+ else:
1432
+ controlnet_cond_scale = controlnet_conditioning_scale
1433
+ if isinstance(controlnet_cond_scale, list):
1434
+ controlnet_cond_scale = controlnet_cond_scale[0]
1435
+ cond_scale = controlnet_cond_scale * controlnet_keep[i]
1436
+
1437
+ down_block_res_samples, mid_block_res_sample = self.controlnet(
1438
+ control_model_input,
1439
+ t,
1440
+ encoder_hidden_states=controlnet_prompt_embeds,
1441
+ controlnet_cond=control_image,
1442
+ conditioning_scale=cond_scale,
1443
+ guess_mode=False,
1444
+ return_dict=False,
1445
+ )
1446
+
1447
+ # concat latents, mask, masked_image_latents in the channel dimension
1448
+ if num_channels_unet == 9:
1449
+ first_dim_size = latent_model_input.shape[0]
1450
+ # Ensure mask and masked_image_latents have the right dimensions
1451
+ if mask.shape[0] < first_dim_size:
1452
+ repeat_factor = (first_dim_size + mask.shape[0] - 1) // mask.shape[0]
1453
+ mask = mask.repeat(repeat_factor, 1, 1, 1)[:first_dim_size]
1454
+ if masked_image_latents.shape[0] < first_dim_size:
1455
+ repeat_factor = (
1456
+ first_dim_size + masked_image_latents.shape[0] - 1
1457
+ ) // masked_image_latents.shape[0]
1458
+ masked_image_latents = masked_image_latents.repeat(repeat_factor, 1, 1, 1)[:first_dim_size]
1459
+ # Perform the concatenation
1460
+ latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
1461
+
1462
+ # Predict noise residual
1463
+ noise_pred = self.unet(
1464
+ latent_model_input,
1465
+ t,
1466
+ encoder_hidden_states=prompt_embeds,
1467
+ timestep_cond=timestep_cond,
1468
+ cross_attention_kwargs=self.cross_attention_kwargs,
1469
+ down_block_additional_residuals=down_block_res_samples,
1470
+ mid_block_additional_residual=mid_block_res_sample,
1471
+ added_cond_kwargs=added_cond_kwargs,
1472
+ return_dict=False,
1473
+ )[0]
1474
+
1475
+ # perform guidance
1476
+ if self.do_perturbed_attention_guidance:
1477
+ noise_pred = self._apply_perturbed_attention_guidance(
1478
+ noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t
1479
+ )
1480
+ elif self.do_classifier_free_guidance:
1481
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1482
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
1483
+
1484
+ # compute the previous noisy sample x_t -> x_t-1
1485
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
1486
+
1487
+ if num_channels_unet == 4:
1488
+ init_latents_proper = image_latents
1489
+ if self.do_classifier_free_guidance:
1490
+ init_mask, _ = mask.chunk(2)
1491
+ else:
1492
+ init_mask = mask
1493
+
1494
+ if i < len(timesteps) - 1:
1495
+ noise_timestep = timesteps[i + 1]
1496
+ init_latents_proper = self.scheduler.add_noise(
1497
+ init_latents_proper, noise, torch.tensor([noise_timestep])
1498
+ )
1499
+
1500
+ latents = (1 - init_mask) * init_latents_proper + init_mask * latents
1501
+
1502
+ if callback_on_step_end is not None:
1503
+ callback_kwargs = {}
1504
+ for k in callback_on_step_end_tensor_inputs:
1505
+ callback_kwargs[k] = locals()[k]
1506
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
1507
+
1508
+ latents = callback_outputs.pop("latents", latents)
1509
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1510
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
1511
+
1512
+ # call the callback, if provided
1513
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1514
+ progress_bar.update()
1515
+
1516
+ if XLA_AVAILABLE:
1517
+ xm.mark_step()
1518
+
1519
+ # If we do sequential model offloading, let's offload unet and controlnet
1520
+ # manually for max memory savings
1521
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
1522
+ self.unet.to("cpu")
1523
+ self.controlnet.to("cpu")
1524
+ empty_device_cache()
1525
+
1526
+ if not output_type == "latent":
1527
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
1528
+ 0
1529
+ ]
1530
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
1531
+ else:
1532
+ image = latents
1533
+ has_nsfw_concept = None
1534
+
1535
+ if has_nsfw_concept is None:
1536
+ do_denormalize = [True] * image.shape[0]
1537
+ else:
1538
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
1539
+
1540
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
1541
+
1542
+ if padding_mask_crop is not None:
1543
+ image = [self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in image]
1544
+
1545
+ # Offload all models
1546
+ self.maybe_free_model_hooks()
1547
+
1548
+ if self.do_perturbed_attention_guidance:
1549
+ self.unet.set_attn_processor(original_attn_proc)
1550
+
1551
+ if not return_dict:
1552
+ return (image, has_nsfw_concept)
1553
+
1554
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py ADDED
@@ -0,0 +1,1631 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ import inspect
17
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
18
+
19
+ import numpy as np
20
+ import PIL.Image
21
+ import torch
22
+ import torch.nn.functional as F
23
+ from transformers import (
24
+ CLIPImageProcessor,
25
+ CLIPTextModel,
26
+ CLIPTextModelWithProjection,
27
+ CLIPTokenizer,
28
+ CLIPVisionModelWithProjection,
29
+ )
30
+
31
+ from diffusers.utils.import_utils import is_invisible_watermark_available
32
+
33
+ from ...callbacks import MultiPipelineCallbacks, PipelineCallback
34
+ from ...image_processor import PipelineImageInput, VaeImageProcessor
35
+ from ...loaders import (
36
+ FromSingleFileMixin,
37
+ IPAdapterMixin,
38
+ StableDiffusionXLLoraLoaderMixin,
39
+ TextualInversionLoaderMixin,
40
+ )
41
+ from ...models import AutoencoderKL, ControlNetModel, ImageProjection, MultiControlNetModel, UNet2DConditionModel
42
+ from ...models.attention_processor import (
43
+ AttnProcessor2_0,
44
+ XFormersAttnProcessor,
45
+ )
46
+ from ...models.lora import adjust_lora_scale_text_encoder
47
+ from ...schedulers import KarrasDiffusionSchedulers
48
+ from ...utils import (
49
+ USE_PEFT_BACKEND,
50
+ logging,
51
+ replace_example_docstring,
52
+ scale_lora_layers,
53
+ unscale_lora_layers,
54
+ )
55
+ from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
56
+ from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
57
+ from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
58
+ from .pag_utils import PAGMixin
59
+
60
+
61
# The invisible-watermark package is optional: only pull in the SDXL
# watermarker implementation when the dependency is actually installed.
if is_invisible_watermark_available():
    from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker


from ...utils import is_torch_xla_available


# Detect torch_xla (TPU) support once at import time; when available, the
# denoising loop calls `xm.mark_step()` to flush the XLA graph each step.
if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
76
+
77
+
78
+ EXAMPLE_DOC_STRING = """
79
+ Examples:
80
+ ```py
81
+ >>> # !pip install opencv-python transformers accelerate
82
+ >>> from diffusers import AutoPipelineForText2Image, ControlNetModel, AutoencoderKL
83
+ >>> from diffusers.utils import load_image
84
+ >>> import numpy as np
85
+ >>> import torch
86
+
87
+ >>> import cv2
88
+ >>> from PIL import Image
89
+
90
+ >>> prompt = "aerial view, a futuristic research complex in a bright foggy jungle, hard lighting"
91
+ >>> negative_prompt = "low quality, bad quality, sketches"
92
+
93
+ >>> # download an image
94
+ >>> image = load_image(
95
+ ... "https://hf.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png"
96
+ ... )
97
+
98
+ >>> # initialize the models and pipeline
99
+ >>> controlnet_conditioning_scale = 0.5 # recommended for good generalization
100
+ >>> controlnet = ControlNetModel.from_pretrained(
101
+ ... "diffusers/controlnet-canny-sdxl-1.0", torch_dtype=torch.float16
102
+ ... )
103
+ >>> vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
104
+ >>> pipe = AutoPipelineForText2Image.from_pretrained(
105
+ ... "stabilityai/stable-diffusion-xl-base-1.0",
106
+ ... controlnet=controlnet,
107
+ ... vae=vae,
108
+ ... torch_dtype=torch.float16,
109
+ ... enable_pag=True,
110
+ ... )
111
+ >>> pipe.enable_model_cpu_offload()
112
+
113
+ >>> # get canny image
114
+ >>> image = np.array(image)
115
+ >>> image = cv2.Canny(image, 100, 200)
116
+ >>> image = image[:, :, None]
117
+ >>> image = np.concatenate([image, image, image], axis=2)
118
+ >>> canny_image = Image.fromarray(image)
119
+
120
+ >>> # generate image
121
+ >>> image = pipe(
122
+ ... prompt, controlnet_conditioning_scale=controlnet_conditioning_scale, image=canny_image, pag_scale=0.3
123
+ ... ).images[0]
124
+ ```
125
+ """
126
+
127
+
128
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
129
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    r"""
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
    # At most one of the two custom schedules may be supplied.
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    if timesteps is not None:
        # Custom timestep schedule: the scheduler must explicitly accept it.
        if "timesteps" not in set(inspect.signature(scheduler.set_timesteps).parameters.keys()):
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        schedule = scheduler.timesteps
        num_inference_steps = len(schedule)
    elif sigmas is not None:
        # Custom sigma schedule: likewise gated on the scheduler's signature.
        if "sigmas" not in set(inspect.signature(scheduler.set_timesteps).parameters.keys()):
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        schedule = scheduler.timesteps
        num_inference_steps = len(schedule)
    else:
        # Default path: let the scheduler derive the schedule from the count.
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        schedule = scheduler.timesteps

    return schedule, num_inference_steps
186
+
187
+
188
+ class StableDiffusionXLControlNetPAGPipeline(
189
+ DiffusionPipeline,
190
+ StableDiffusionMixin,
191
+ TextualInversionLoaderMixin,
192
+ StableDiffusionXLLoraLoaderMixin,
193
+ IPAdapterMixin,
194
+ FromSingleFileMixin,
195
+ PAGMixin,
196
+ ):
197
+ r"""
198
+ Pipeline for text-to-image generation using Stable Diffusion XL with ControlNet guidance.
199
+
200
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
201
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
202
+
203
+ The pipeline also inherits the following loading methods:
204
+ - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
205
+ - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
206
+ - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
207
+ - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
208
+ - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
209
+
210
+ Args:
211
+ vae ([`AutoencoderKL`]):
212
+ Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
213
+ text_encoder ([`~transformers.CLIPTextModel`]):
214
+ Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
215
+ text_encoder_2 ([`~transformers.CLIPTextModelWithProjection`]):
216
+ Second frozen text-encoder
217
+ ([laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)).
218
+ tokenizer ([`~transformers.CLIPTokenizer`]):
219
+ A `CLIPTokenizer` to tokenize text.
220
+ tokenizer_2 ([`~transformers.CLIPTokenizer`]):
221
+ A `CLIPTokenizer` to tokenize text.
222
+ unet ([`UNet2DConditionModel`]):
223
+ A `UNet2DConditionModel` to denoise the encoded image latents.
224
+ controlnet ([`ControlNetModel`] or `List[ControlNetModel]`):
225
+ Provides additional conditioning to the `unet` during the denoising process. If you set multiple
226
+ ControlNets as a list, the outputs from each ControlNet are added together to create one combined
227
+ additional conditioning.
228
+ scheduler ([`SchedulerMixin`]):
229
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
230
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
231
+ force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`):
232
+ Whether the negative prompt embeddings should always be set to 0. Also see the config of
233
+ `stabilityai/stable-diffusion-xl-base-1-0`.
234
+ add_watermarker (`bool`, *optional*):
235
+ Whether to use the [invisible_watermark](https://github.com/ShieldMnt/invisible-watermark/) library to
236
+ watermark output images. If not defined, it defaults to `True` if the package is installed; otherwise no
237
+ watermarker is used.
238
+ """
239
+
240
    # leave controlnet out on purpose because it iterates with unet
    model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae"
    # Components that may legitimately be None when the pipeline is built
    # (e.g. no IP-Adapter image encoder loaded).
    _optional_components = [
        "tokenizer",
        "tokenizer_2",
        "text_encoder",
        "text_encoder_2",
        "feature_extractor",
        "image_encoder",
    ]
    # Tensor names a `callback_on_step_end` callback is allowed to read and
    # replace between denoising steps.
    _callback_tensor_inputs = [
        "latents",
        "prompt_embeds",
        "negative_prompt_embeds",
        "add_text_embeds",
        "add_time_ids",
        "negative_pooled_prompt_embeds",
        "negative_add_time_ids",
    ]
259
+
260
+ def __init__(
261
+ self,
262
+ vae: AutoencoderKL,
263
+ text_encoder: CLIPTextModel,
264
+ text_encoder_2: CLIPTextModelWithProjection,
265
+ tokenizer: CLIPTokenizer,
266
+ tokenizer_2: CLIPTokenizer,
267
+ unet: UNet2DConditionModel,
268
+ controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
269
+ scheduler: KarrasDiffusionSchedulers,
270
+ force_zeros_for_empty_prompt: bool = True,
271
+ add_watermarker: Optional[bool] = None,
272
+ feature_extractor: CLIPImageProcessor = None,
273
+ image_encoder: CLIPVisionModelWithProjection = None,
274
+ pag_applied_layers: Union[str, List[str]] = "mid", # ["down.block_2", "up.block_1.attentions_0"], "mid"
275
+ ):
276
+ super().__init__()
277
+
278
+ if isinstance(controlnet, (list, tuple)):
279
+ controlnet = MultiControlNetModel(controlnet)
280
+
281
+ self.register_modules(
282
+ vae=vae,
283
+ text_encoder=text_encoder,
284
+ text_encoder_2=text_encoder_2,
285
+ tokenizer=tokenizer,
286
+ tokenizer_2=tokenizer_2,
287
+ unet=unet,
288
+ controlnet=controlnet,
289
+ scheduler=scheduler,
290
+ feature_extractor=feature_extractor,
291
+ image_encoder=image_encoder,
292
+ )
293
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
294
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
295
+ self.control_image_processor = VaeImageProcessor(
296
+ vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
297
+ )
298
+ add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
299
+
300
+ if add_watermarker:
301
+ self.watermark = StableDiffusionXLWatermarker()
302
+ else:
303
+ self.watermark = None
304
+
305
+ self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
306
+ self.set_pag_applied_layers(pag_applied_layers)
307
+
308
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
309
+ def encode_prompt(
310
+ self,
311
+ prompt: str,
312
+ prompt_2: Optional[str] = None,
313
+ device: Optional[torch.device] = None,
314
+ num_images_per_prompt: int = 1,
315
+ do_classifier_free_guidance: bool = True,
316
+ negative_prompt: Optional[str] = None,
317
+ negative_prompt_2: Optional[str] = None,
318
+ prompt_embeds: Optional[torch.Tensor] = None,
319
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
320
+ pooled_prompt_embeds: Optional[torch.Tensor] = None,
321
+ negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
322
+ lora_scale: Optional[float] = None,
323
+ clip_skip: Optional[int] = None,
324
+ ):
325
+ r"""
326
+ Encodes the prompt into text encoder hidden states.
327
+
328
+ Args:
329
+ prompt (`str` or `List[str]`, *optional*):
330
+ prompt to be encoded
331
+ prompt_2 (`str` or `List[str]`, *optional*):
332
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
333
+ used in both text-encoders
334
+ device: (`torch.device`):
335
+ torch device
336
+ num_images_per_prompt (`int`):
337
+ number of images that should be generated per prompt
338
+ do_classifier_free_guidance (`bool`):
339
+ whether to use classifier free guidance or not
340
+ negative_prompt (`str` or `List[str]`, *optional*):
341
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
342
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
343
+ less than `1`).
344
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
345
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
346
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
347
+ prompt_embeds (`torch.Tensor`, *optional*):
348
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
349
+ provided, text embeddings will be generated from `prompt` input argument.
350
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
351
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
352
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
353
+ argument.
354
+ pooled_prompt_embeds (`torch.Tensor`, *optional*):
355
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
356
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
357
+ negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
358
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
359
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
360
+ input argument.
361
+ lora_scale (`float`, *optional*):
362
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
363
+ clip_skip (`int`, *optional*):
364
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
365
+ the output of the pre-final layer will be used for computing the prompt embeddings.
366
+ """
367
+ device = device or self._execution_device
368
+
369
+ # set lora scale so that monkey patched LoRA
370
+ # function of text encoder can correctly access it
371
+ if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
372
+ self._lora_scale = lora_scale
373
+
374
+ # dynamically adjust the LoRA scale
375
+ if self.text_encoder is not None:
376
+ if not USE_PEFT_BACKEND:
377
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
378
+ else:
379
+ scale_lora_layers(self.text_encoder, lora_scale)
380
+
381
+ if self.text_encoder_2 is not None:
382
+ if not USE_PEFT_BACKEND:
383
+ adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
384
+ else:
385
+ scale_lora_layers(self.text_encoder_2, lora_scale)
386
+
387
+ prompt = [prompt] if isinstance(prompt, str) else prompt
388
+
389
+ if prompt is not None:
390
+ batch_size = len(prompt)
391
+ else:
392
+ batch_size = prompt_embeds.shape[0]
393
+
394
+ # Define tokenizers and text encoders
395
+ tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
396
+ text_encoders = (
397
+ [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
398
+ )
399
+
400
+ if prompt_embeds is None:
401
+ prompt_2 = prompt_2 or prompt
402
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
403
+
404
+ # textual inversion: process multi-vector tokens if necessary
405
+ prompt_embeds_list = []
406
+ prompts = [prompt, prompt_2]
407
+ for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
408
+ if isinstance(self, TextualInversionLoaderMixin):
409
+ prompt = self.maybe_convert_prompt(prompt, tokenizer)
410
+
411
+ text_inputs = tokenizer(
412
+ prompt,
413
+ padding="max_length",
414
+ max_length=tokenizer.model_max_length,
415
+ truncation=True,
416
+ return_tensors="pt",
417
+ )
418
+
419
+ text_input_ids = text_inputs.input_ids
420
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
421
+
422
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
423
+ text_input_ids, untruncated_ids
424
+ ):
425
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
426
+ logger.warning(
427
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
428
+ f" {tokenizer.model_max_length} tokens: {removed_text}"
429
+ )
430
+
431
+ prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
432
+
433
+ # We are only ALWAYS interested in the pooled output of the final text encoder
434
+ if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
435
+ pooled_prompt_embeds = prompt_embeds[0]
436
+
437
+ if clip_skip is None:
438
+ prompt_embeds = prompt_embeds.hidden_states[-2]
439
+ else:
440
+ # "2" because SDXL always indexes from the penultimate layer.
441
+ prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
442
+
443
+ prompt_embeds_list.append(prompt_embeds)
444
+
445
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
446
+
447
+ # get unconditional embeddings for classifier free guidance
448
+ zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
449
+ if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
450
+ negative_prompt_embeds = torch.zeros_like(prompt_embeds)
451
+ negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
452
+ elif do_classifier_free_guidance and negative_prompt_embeds is None:
453
+ negative_prompt = negative_prompt or ""
454
+ negative_prompt_2 = negative_prompt_2 or negative_prompt
455
+
456
+ # normalize str to list
457
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
458
+ negative_prompt_2 = (
459
+ batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
460
+ )
461
+
462
+ uncond_tokens: List[str]
463
+ if prompt is not None and type(prompt) is not type(negative_prompt):
464
+ raise TypeError(
465
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
466
+ f" {type(prompt)}."
467
+ )
468
+ elif batch_size != len(negative_prompt):
469
+ raise ValueError(
470
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
471
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
472
+ " the batch size of `prompt`."
473
+ )
474
+ else:
475
+ uncond_tokens = [negative_prompt, negative_prompt_2]
476
+
477
+ negative_prompt_embeds_list = []
478
+ for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
479
+ if isinstance(self, TextualInversionLoaderMixin):
480
+ negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)
481
+
482
+ max_length = prompt_embeds.shape[1]
483
+ uncond_input = tokenizer(
484
+ negative_prompt,
485
+ padding="max_length",
486
+ max_length=max_length,
487
+ truncation=True,
488
+ return_tensors="pt",
489
+ )
490
+
491
+ negative_prompt_embeds = text_encoder(
492
+ uncond_input.input_ids.to(device),
493
+ output_hidden_states=True,
494
+ )
495
+
496
+ # We are only ALWAYS interested in the pooled output of the final text encoder
497
+ if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
498
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
499
+ negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
500
+
501
+ negative_prompt_embeds_list.append(negative_prompt_embeds)
502
+
503
+ negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
504
+
505
+ if self.text_encoder_2 is not None:
506
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
507
+ else:
508
+ prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)
509
+
510
+ bs_embed, seq_len, _ = prompt_embeds.shape
511
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
512
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
513
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
514
+
515
+ if do_classifier_free_guidance:
516
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
517
+ seq_len = negative_prompt_embeds.shape[1]
518
+
519
+ if self.text_encoder_2 is not None:
520
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
521
+ else:
522
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)
523
+
524
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
525
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
526
+
527
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
528
+ bs_embed * num_images_per_prompt, -1
529
+ )
530
+ if do_classifier_free_guidance:
531
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
532
+ bs_embed * num_images_per_prompt, -1
533
+ )
534
+
535
+ if self.text_encoder is not None:
536
+ if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
537
+ # Retrieve the original scale by scaling back the LoRA layers
538
+ unscale_lora_layers(self.text_encoder, lora_scale)
539
+
540
+ if self.text_encoder_2 is not None:
541
+ if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
542
+ # Retrieve the original scale by scaling back the LoRA layers
543
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
544
+
545
+ return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
546
+
547
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
548
+ def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
549
+ dtype = next(self.image_encoder.parameters()).dtype
550
+
551
+ if not isinstance(image, torch.Tensor):
552
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
553
+
554
+ image = image.to(device=device, dtype=dtype)
555
+ if output_hidden_states:
556
+ image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
557
+ image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
558
+ uncond_image_enc_hidden_states = self.image_encoder(
559
+ torch.zeros_like(image), output_hidden_states=True
560
+ ).hidden_states[-2]
561
+ uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
562
+ num_images_per_prompt, dim=0
563
+ )
564
+ return image_enc_hidden_states, uncond_image_enc_hidden_states
565
+ else:
566
+ image_embeds = self.image_encoder(image).image_embeds
567
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
568
+ uncond_image_embeds = torch.zeros_like(image_embeds)
569
+
570
+ return image_embeds, uncond_image_embeds
571
+
572
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
    def prepare_ip_adapter_image_embeds(
        self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
    ):
        """Build one image-embedding tensor per loaded IP Adapter.

        Either encodes `ip_adapter_image` (one image per adapter) via
        `self.encode_image`, or unpacks precomputed `ip_adapter_image_embeds`.
        Each result is tiled `num_images_per_prompt` times and, under
        classifier-free guidance, concatenated as [negative, positive] along
        the batch dimension.
        """
        image_embeds = []
        if do_classifier_free_guidance:
            negative_image_embeds = []
        if ip_adapter_image_embeds is None:
            if not isinstance(ip_adapter_image, list):
                ip_adapter_image = [ip_adapter_image]

            # One conditioning image is required per IP-Adapter projection layer.
            if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
                raise ValueError(
                    f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
                )

            for single_ip_adapter_image, image_proj_layer in zip(
                ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
            ):
                # `ImageProjection` layers consume pooled embeds; all other
                # projection layers consume penultimate hidden states instead.
                output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
                single_image_embeds, single_negative_image_embeds = self.encode_image(
                    single_ip_adapter_image, device, 1, output_hidden_state
                )

                image_embeds.append(single_image_embeds[None, :])
                if do_classifier_free_guidance:
                    negative_image_embeds.append(single_negative_image_embeds[None, :])
        else:
            for single_image_embeds in ip_adapter_image_embeds:
                if do_classifier_free_guidance:
                    # Precomputed embeds are stored as cat([negative, positive], dim=0).
                    single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
                    negative_image_embeds.append(single_negative_image_embeds)
                image_embeds.append(single_image_embeds)

        ip_adapter_image_embeds = []
        for i, single_image_embeds in enumerate(image_embeds):
            # Tile per requested image, then prepend the negative half for CFG.
            single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
            if do_classifier_free_guidance:
                single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
                single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)

            single_image_embeds = single_image_embeds.to(device=device)
            ip_adapter_image_embeds.append(single_image_embeds)

        return ip_adapter_image_embeds
617
+
618
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
619
+ def prepare_extra_step_kwargs(self, generator, eta):
620
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
621
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
622
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
623
+ # and should be between [0, 1]
624
+
625
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
626
+ extra_step_kwargs = {}
627
+ if accepts_eta:
628
+ extra_step_kwargs["eta"] = eta
629
+
630
+ # check if the scheduler accepts generator
631
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
632
+ if accepts_generator:
633
+ extra_step_kwargs["generator"] = generator
634
+ return extra_step_kwargs
635
+
636
    # Copied from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.check_inputs
    def check_inputs(
        self,
        prompt,
        prompt_2,
        image,
        callback_steps,
        negative_prompt=None,
        negative_prompt_2=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
        pooled_prompt_embeds=None,
        ip_adapter_image=None,
        ip_adapter_image_embeds=None,
        negative_pooled_prompt_embeds=None,
        controlnet_conditioning_scale=1.0,
        control_guidance_start=0.0,
        control_guidance_end=1.0,
        callback_on_step_end_tensor_inputs=None,
    ):
        """Validate user-facing `__call__` arguments before any heavy work.

        Raises `ValueError`/`TypeError` on the first inconsistency found:
        mutually exclusive prompt/embedding arguments, embeddings missing their
        pooled counterparts, conditioning images that do not match the
        ControlNet configuration, malformed control-guidance ranges, or
        ill-typed IP-Adapter embeddings. Returns nothing on success.
        """
        if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

        if callback_on_step_end_tensor_inputs is not None and not all(
            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
        ):
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
            )

        # Exactly one of (prompt, prompt_embeds) must be provided; same for the
        # secondary-encoder prompt.
        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt_2 is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
        elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
            raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")

        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )
        elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if prompt_embeds.shape != negative_prompt_embeds.shape:
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}."
                )

        # SDXL always needs the pooled embeddings alongside the token embeddings.
        if prompt_embeds is not None and pooled_prompt_embeds is None:
            raise ValueError(
                "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
            )

        if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
            raise ValueError(
                "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
            )

        # `prompt` needs more sophisticated handling when there are multiple
        # conditionings.
        if isinstance(self.controlnet, MultiControlNetModel):
            if isinstance(prompt, list):
                logger.warning(
                    f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}"
                    " prompts. The conditionings will be fixed across the prompts."
                )

        # Check `image`
        # A torch.compile()'d ControlNet hides its real class behind
        # `_orig_mod`, hence the `is_compiled` indirection below.
        is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
            self.controlnet, torch._dynamo.eval_frame.OptimizedModule
        )
        if (
            isinstance(self.controlnet, ControlNetModel)
            or is_compiled
            and isinstance(self.controlnet._orig_mod, ControlNetModel)
        ):
            self.check_image(image, prompt, prompt_embeds)
        elif (
            isinstance(self.controlnet, MultiControlNetModel)
            or is_compiled
            and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
        ):
            if not isinstance(image, list):
                raise TypeError("For multiple controlnets: `image` must be type `list`")

            # When `image` is a nested list:
            # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])
            elif any(isinstance(i, list) for i in image):
                raise ValueError("A single batch of multiple conditionings are supported at the moment.")
            elif len(image) != len(self.controlnet.nets):
                raise ValueError(
                    f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets."
                )

            for image_ in image:
                self.check_image(image_, prompt, prompt_embeds)
        else:
            # NOTE(review): unreachable for supported controlnet types; `assert`
            # is stripped under `python -O`, so this is a best-effort guard only.
            assert False

        # Check `controlnet_conditioning_scale`
        if (
            isinstance(self.controlnet, ControlNetModel)
            or is_compiled
            and isinstance(self.controlnet._orig_mod, ControlNetModel)
        ):
            if not isinstance(controlnet_conditioning_scale, float):
                raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
        elif (
            isinstance(self.controlnet, MultiControlNetModel)
            or is_compiled
            and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
        ):
            if isinstance(controlnet_conditioning_scale, list):
                if any(isinstance(i, list) for i in controlnet_conditioning_scale):
                    raise ValueError("A single batch of multiple conditionings are supported at the moment.")
            elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(
                self.controlnet.nets
            ):
                raise ValueError(
                    "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
                    " the same length as the number of controlnets"
                )
        else:
            assert False

        # Normalize control-guidance bounds to lists, then validate each
        # (start, end) pair lies in [0, 1] with start < end.
        if not isinstance(control_guidance_start, (tuple, list)):
            control_guidance_start = [control_guidance_start]

        if not isinstance(control_guidance_end, (tuple, list)):
            control_guidance_end = [control_guidance_end]

        if len(control_guidance_start) != len(control_guidance_end):
            raise ValueError(
                f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
            )

        if isinstance(self.controlnet, MultiControlNetModel):
            if len(control_guidance_start) != len(self.controlnet.nets):
                raise ValueError(
                    f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}."
                )

        for start, end in zip(control_guidance_start, control_guidance_end):
            if start >= end:
                raise ValueError(
                    f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}."
                )
            if start < 0.0:
                raise ValueError(f"control guidance start: {start} can't be smaller than 0.")
            if end > 1.0:
                raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")

        if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
            raise ValueError(
                "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
            )

        if ip_adapter_image_embeds is not None:
            if not isinstance(ip_adapter_image_embeds, list):
                raise ValueError(
                    f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
                )
            elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
                raise ValueError(
                    f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
                )
825
+
826
+ # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image
827
+ def check_image(self, image, prompt, prompt_embeds):
828
+ image_is_pil = isinstance(image, PIL.Image.Image)
829
+ image_is_tensor = isinstance(image, torch.Tensor)
830
+ image_is_np = isinstance(image, np.ndarray)
831
+ image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
832
+ image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)
833
+ image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray)
834
+
835
+ if (
836
+ not image_is_pil
837
+ and not image_is_tensor
838
+ and not image_is_np
839
+ and not image_is_pil_list
840
+ and not image_is_tensor_list
841
+ and not image_is_np_list
842
+ ):
843
+ raise TypeError(
844
+ f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}"
845
+ )
846
+
847
+ if image_is_pil:
848
+ image_batch_size = 1
849
+ else:
850
+ image_batch_size = len(image)
851
+
852
+ if prompt is not None and isinstance(prompt, str):
853
+ prompt_batch_size = 1
854
+ elif prompt is not None and isinstance(prompt, list):
855
+ prompt_batch_size = len(prompt)
856
+ elif prompt_embeds is not None:
857
+ prompt_batch_size = prompt_embeds.shape[0]
858
+
859
+ if image_batch_size != 1 and image_batch_size != prompt_batch_size:
860
+ raise ValueError(
861
+ f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
862
+ )
863
+
864
+ # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image
865
+ def prepare_image(
866
+ self,
867
+ image,
868
+ width,
869
+ height,
870
+ batch_size,
871
+ num_images_per_prompt,
872
+ device,
873
+ dtype,
874
+ do_classifier_free_guidance=False,
875
+ guess_mode=False,
876
+ ):
877
+ image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
878
+ image_batch_size = image.shape[0]
879
+
880
+ if image_batch_size == 1:
881
+ repeat_by = batch_size
882
+ else:
883
+ # image batch size is the same as prompt batch size
884
+ repeat_by = num_images_per_prompt
885
+
886
+ image = image.repeat_interleave(repeat_by, dim=0)
887
+
888
+ image = image.to(device=device, dtype=dtype)
889
+
890
+ if do_classifier_free_guidance and not guess_mode:
891
+ image = torch.cat([image] * 2)
892
+
893
+ return image
894
+
895
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
896
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
897
+ shape = (
898
+ batch_size,
899
+ num_channels_latents,
900
+ int(height) // self.vae_scale_factor,
901
+ int(width) // self.vae_scale_factor,
902
+ )
903
+ if isinstance(generator, list) and len(generator) != batch_size:
904
+ raise ValueError(
905
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
906
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
907
+ )
908
+
909
+ if latents is None:
910
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
911
+ else:
912
+ latents = latents.to(device)
913
+
914
+ # scale the initial noise by the standard deviation required by the scheduler
915
+ latents = latents * self.scheduler.init_noise_sigma
916
+ return latents
917
+
918
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids
919
+ def _get_add_time_ids(
920
+ self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None
921
+ ):
922
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
923
+
924
+ passed_add_embed_dim = (
925
+ self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
926
+ )
927
+ expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
928
+
929
+ if expected_add_embed_dim != passed_add_embed_dim:
930
+ raise ValueError(
931
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
932
+ )
933
+
934
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
935
+ return add_time_ids
936
+
937
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
938
+ def upcast_vae(self):
939
+ dtype = self.vae.dtype
940
+ self.vae.to(dtype=torch.float32)
941
+ use_torch_2_0_or_xformers = isinstance(
942
+ self.vae.decoder.mid_block.attentions[0].processor,
943
+ (
944
+ AttnProcessor2_0,
945
+ XFormersAttnProcessor,
946
+ ),
947
+ )
948
+ # if xformers or torch_2_0 is used attention block does not need
949
+ # to be in float32 which can save lots of memory
950
+ if use_torch_2_0_or_xformers:
951
+ self.vae.post_quant_conv.to(dtype)
952
+ self.vae.decoder.conv_in.to(dtype)
953
+ self.vae.decoder.mid_block.to(dtype)
954
+
955
+ # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
956
+ def get_guidance_scale_embedding(
957
+ self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32
958
+ ) -> torch.Tensor:
959
+ """
960
+ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
961
+
962
+ Args:
963
+ w (`torch.Tensor`):
964
+ Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings.
965
+ embedding_dim (`int`, *optional*, defaults to 512):
966
+ Dimension of the embeddings to generate.
967
+ dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
968
+ Data type of the generated embeddings.
969
+
970
+ Returns:
971
+ `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`.
972
+ """
973
+ assert len(w.shape) == 1
974
+ w = w * 1000.0
975
+
976
+ half_dim = embedding_dim // 2
977
+ emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
978
+ emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
979
+ emb = w.to(dtype)[:, None] * emb[None, :]
980
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
981
+ if embedding_dim % 2 == 1: # zero pad
982
+ emb = torch.nn.functional.pad(emb, (0, 1))
983
+ assert emb.shape == (w.shape[0], embedding_dim)
984
+ return emb
985
+
986
    @property
    def guidance_scale(self):
        # Classifier-free-guidance weight; read from the private attribute set
        # by `__call__` (presumably from its `guidance_scale` argument — confirm).
        return self._guidance_scale
989
+
990
    @property
    def clip_skip(self):
        # Number of final CLIP layers skipped when encoding prompts; forwarded
        # to `encode_prompt`. Backed by the private attribute set in `__call__`.
        return self._clip_skip
993
+
994
    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    @property
    def do_classifier_free_guidance(self):
        # CFG is also disabled when the UNet has `time_cond_proj_dim` set
        # (guidance-distilled models consume the scale as a timestep condition
        # instead of a doubled batch).
        return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
1000
+
1001
    @property
    def cross_attention_kwargs(self):
        # Extra kwargs forwarded to the attention processors; backed by the
        # private attribute set in `__call__`.
        return self._cross_attention_kwargs
1004
+
1005
    @property
    def denoising_end(self):
        # Optional fraction of the schedule at which denoising stops early
        # (presumably for a base/refiner handoff — confirm against `__call__`).
        return self._denoising_end
1008
+
1009
    @property
    def num_timesteps(self):
        # Number of scheduler timesteps of the last run; set during `__call__`.
        return self._num_timesteps
1012
+
1013
+ @torch.no_grad()
1014
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
1015
+ def __call__(
1016
+ self,
1017
+ prompt: Union[str, List[str]] = None,
1018
+ prompt_2: Optional[Union[str, List[str]]] = None,
1019
+ image: PipelineImageInput = None,
1020
+ height: Optional[int] = None,
1021
+ width: Optional[int] = None,
1022
+ num_inference_steps: int = 50,
1023
+ timesteps: List[int] = None,
1024
+ sigmas: List[float] = None,
1025
+ denoising_end: Optional[float] = None,
1026
+ guidance_scale: float = 5.0,
1027
+ negative_prompt: Optional[Union[str, List[str]]] = None,
1028
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
1029
+ num_images_per_prompt: Optional[int] = 1,
1030
+ eta: float = 0.0,
1031
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
1032
+ latents: Optional[torch.Tensor] = None,
1033
+ prompt_embeds: Optional[torch.Tensor] = None,
1034
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
1035
+ pooled_prompt_embeds: Optional[torch.Tensor] = None,
1036
+ negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
1037
+ ip_adapter_image: Optional[PipelineImageInput] = None,
1038
+ ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
1039
+ output_type: Optional[str] = "pil",
1040
+ return_dict: bool = True,
1041
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
1042
+ controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
1043
+ control_guidance_start: Union[float, List[float]] = 0.0,
1044
+ control_guidance_end: Union[float, List[float]] = 1.0,
1045
+ original_size: Tuple[int, int] = None,
1046
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
1047
+ target_size: Tuple[int, int] = None,
1048
+ negative_original_size: Optional[Tuple[int, int]] = None,
1049
+ negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
1050
+ negative_target_size: Optional[Tuple[int, int]] = None,
1051
+ clip_skip: Optional[int] = None,
1052
+ callback_on_step_end: Optional[
1053
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
1054
+ ] = None,
1055
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
1056
+ pag_scale: float = 3.0,
1057
+ pag_adaptive_scale: float = 0.0,
1058
+ ):
1059
+ r"""
1060
+ The call function to the pipeline for generation.
1061
+
1062
+ Args:
1063
+ prompt (`str` or `List[str]`, *optional*):
1064
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
1065
+ prompt_2 (`str` or `List[str]`, *optional*):
1066
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
1067
+ used in both text-encoders.
1068
+ image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
1069
+ `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
1070
+ The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
1071
+ specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted
1072
+ as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or
1073
+ width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`,
1074
+ images must be passed as a list such that each element of the list can be correctly batched for input
1075
+ to a single ControlNet.
1076
+ height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
1077
+ The height in pixels of the generated image. Anything below 512 pixels won't work well for
1078
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
1079
+ and checkpoints that are not specifically fine-tuned on low resolutions.
1080
+ width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
1081
+ The width in pixels of the generated image. Anything below 512 pixels won't work well for
1082
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
1083
+ and checkpoints that are not specifically fine-tuned on low resolutions.
1084
+ num_inference_steps (`int`, *optional*, defaults to 50):
1085
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
1086
+ expense of slower inference.
1087
+ timesteps (`List[int]`, *optional*):
1088
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
1089
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
1090
+ passed will be used. Must be in descending order.
1091
+ sigmas (`List[float]`, *optional*):
1092
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
1093
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
1094
+ will be used.
1095
+ denoising_end (`float`, *optional*):
1096
+ When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
1097
+ completed before it is intentionally prematurely terminated. As a result, the returned sample will
1098
+ still retain a substantial amount of noise as determined by the discrete timesteps selected by the
1099
+ scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
1100
+ "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
1101
+ Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
1102
+ guidance_scale (`float`, *optional*, defaults to 5.0):
1103
+ A higher guidance scale value encourages the model to generate images closely linked to the text
1104
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
1105
+ negative_prompt (`str` or `List[str]`, *optional*):
1106
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
1107
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
1108
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
1109
+ The prompt or prompts to guide what to not include in image generation. This is sent to `tokenizer_2`
1110
+ and `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders.
1111
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
1112
+ The number of images to generate per prompt.
1113
+ eta (`float`, *optional*, defaults to 0.0):
1114
+ Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only
1115
+ applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
1116
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
1117
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
1118
+ generation deterministic.
1119
+ latents (`torch.Tensor`, *optional*):
1120
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
1121
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
1122
+ tensor is generated by sampling using the supplied random `generator`.
1123
+ prompt_embeds (`torch.Tensor`, *optional*):
1124
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
1125
+ provided, text embeddings are generated from the `prompt` input argument.
1126
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
1127
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
1128
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
1129
+ pooled_prompt_embeds (`torch.Tensor`, *optional*):
1130
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
1131
+ not provided, pooled text embeddings are generated from `prompt` input argument.
1132
+ negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
1133
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs (prompt
1134
+ weighting). If not provided, pooled `negative_prompt_embeds` are generated from `negative_prompt` input
1135
+ argument.
1136
+ ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
1137
+ ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
1138
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
1139
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
1140
+ contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
1141
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
1142
+ output_type (`str`, *optional*, defaults to `"pil"`):
1143
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
1144
+ return_dict (`bool`, *optional*, defaults to `True`):
1145
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
1146
+ plain tuple.
1147
+ cross_attention_kwargs (`dict`, *optional*):
1148
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
1149
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
1150
+ controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
1151
+ The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
1152
+ to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set
1153
+ the corresponding scale as a list.
1154
+ control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
1155
+ The percentage of total steps at which the ControlNet starts applying.
1156
+ control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
1157
+ The percentage of total steps at which the ControlNet stops applying.
1158
+ original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
1159
+ If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
1160
+ `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
1161
+ explained in section 2.2 of
1162
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
1163
+ crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
1164
+ `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
1165
+ `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
1166
+ `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
1167
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
1168
+ target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
1169
+ For most cases, `target_size` should be set to the desired height and width of the generated image. If
1170
+ not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
1171
+ section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
1172
+ negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
1173
+ To negatively condition the generation process based on a specific image resolution. Part of SDXL's
1174
+ micro-conditioning as explained in section 2.2 of
1175
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
1176
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
1177
+ negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
1178
+ To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's
1179
+ micro-conditioning as explained in section 2.2 of
1180
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
1181
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
1182
+ negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
1183
+ To negatively condition the generation process based on a target image resolution. It should be as same
1184
+ as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
1185
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
1186
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
1187
+ clip_skip (`int`, *optional*):
1188
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
1189
+ the output of the pre-final layer will be used for computing the prompt embeddings.
1190
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
1191
+ A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
1192
+ each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
1193
+ DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
1194
+ list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
1195
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
1196
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
1197
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
1198
+ `._callback_tensor_inputs` attribute of your pipeline class.
1199
+ pag_scale (`float`, *optional*, defaults to 3.0):
1200
+ The scale factor for the perturbed attention guidance. If it is set to 0.0, the perturbed attention
1201
+ guidance will not be used.
1202
+ pag_adaptive_scale (`float`, *optional*, defaults to 0.0):
1203
+ The adaptive scale factor for the perturbed attention guidance. If it is set to 0.0, `pag_scale` is
1204
+ used.
1205
+
1206
+ Examples:
1207
+
1208
+ Returns:
1209
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
1210
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
1211
+ otherwise a `tuple` is returned containing the output images.
1212
+ """
1213
+
1214
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
1215
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
1216
+
1217
+ controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
1218
+
1219
+ # align format for control guidance
1220
+ if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
1221
+ control_guidance_start = len(control_guidance_end) * [control_guidance_start]
1222
+ elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
1223
+ control_guidance_end = len(control_guidance_start) * [control_guidance_end]
1224
+ elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
1225
+ mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
1226
+ control_guidance_start, control_guidance_end = (
1227
+ mult * [control_guidance_start],
1228
+ mult * [control_guidance_end],
1229
+ )
1230
+
1231
+ # 1. Check inputs. Raise error if not correct
1232
+ self.check_inputs(
1233
+ prompt,
1234
+ prompt_2,
1235
+ image,
1236
+ None,
1237
+ negative_prompt,
1238
+ negative_prompt_2,
1239
+ prompt_embeds,
1240
+ negative_prompt_embeds,
1241
+ pooled_prompt_embeds,
1242
+ ip_adapter_image,
1243
+ ip_adapter_image_embeds,
1244
+ negative_pooled_prompt_embeds,
1245
+ controlnet_conditioning_scale,
1246
+ control_guidance_start,
1247
+ control_guidance_end,
1248
+ callback_on_step_end_tensor_inputs,
1249
+ )
1250
+
1251
+ self._guidance_scale = guidance_scale
1252
+ self._clip_skip = clip_skip
1253
+ self._cross_attention_kwargs = cross_attention_kwargs
1254
+ self._denoising_end = denoising_end
1255
+ self._pag_scale = pag_scale
1256
+ self._pag_adaptive_scale = pag_adaptive_scale
1257
+
1258
+ # 2. Define call parameters
1259
+ if prompt is not None and isinstance(prompt, str):
1260
+ batch_size = 1
1261
+ elif prompt is not None and isinstance(prompt, list):
1262
+ batch_size = len(prompt)
1263
+ else:
1264
+ batch_size = prompt_embeds.shape[0]
1265
+
1266
+ device = self._execution_device
1267
+
1268
+ if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
1269
+ controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
1270
+
1271
+ # 3.1 Encode input prompt
1272
+ text_encoder_lora_scale = (
1273
+ self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
1274
+ )
1275
+ (
1276
+ prompt_embeds,
1277
+ negative_prompt_embeds,
1278
+ pooled_prompt_embeds,
1279
+ negative_pooled_prompt_embeds,
1280
+ ) = self.encode_prompt(
1281
+ prompt,
1282
+ prompt_2,
1283
+ device,
1284
+ num_images_per_prompt,
1285
+ self.do_classifier_free_guidance,
1286
+ negative_prompt,
1287
+ negative_prompt_2,
1288
+ prompt_embeds=prompt_embeds,
1289
+ negative_prompt_embeds=negative_prompt_embeds,
1290
+ pooled_prompt_embeds=pooled_prompt_embeds,
1291
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
1292
+ lora_scale=text_encoder_lora_scale,
1293
+ clip_skip=self.clip_skip,
1294
+ )
1295
+
1296
+ # 3.2 Encode ip_adapter_image
1297
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
1298
+ ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds(
1299
+ ip_adapter_image,
1300
+ ip_adapter_image_embeds,
1301
+ device,
1302
+ batch_size * num_images_per_prompt,
1303
+ self.do_classifier_free_guidance,
1304
+ )
1305
+
1306
+ # 4. Prepare image
1307
+ if isinstance(controlnet, ControlNetModel):
1308
+ image = self.prepare_image(
1309
+ image=image,
1310
+ width=width,
1311
+ height=height,
1312
+ batch_size=batch_size * num_images_per_prompt,
1313
+ num_images_per_prompt=num_images_per_prompt,
1314
+ device=device,
1315
+ dtype=controlnet.dtype,
1316
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
1317
+ guess_mode=False,
1318
+ )
1319
+ height, width = image.shape[-2:]
1320
+ elif isinstance(controlnet, MultiControlNetModel):
1321
+ images = []
1322
+
1323
+ for image_ in image:
1324
+ image_ = self.prepare_image(
1325
+ image=image_,
1326
+ width=width,
1327
+ height=height,
1328
+ batch_size=batch_size * num_images_per_prompt,
1329
+ num_images_per_prompt=num_images_per_prompt,
1330
+ device=device,
1331
+ dtype=controlnet.dtype,
1332
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
1333
+ guess_mode=False,
1334
+ )
1335
+
1336
+ images.append(image_)
1337
+
1338
+ image = images
1339
+ height, width = image[0].shape[-2:]
1340
+ else:
1341
+ assert False
1342
+
1343
+ # 5. Prepare timesteps
1344
+ timesteps, num_inference_steps = retrieve_timesteps(
1345
+ self.scheduler, num_inference_steps, device, timesteps, sigmas
1346
+ )
1347
+ self._num_timesteps = len(timesteps)
1348
+
1349
+ # 6. Prepare latent variables
1350
+ num_channels_latents = self.unet.config.in_channels
1351
+ latents = self.prepare_latents(
1352
+ batch_size * num_images_per_prompt,
1353
+ num_channels_latents,
1354
+ height,
1355
+ width,
1356
+ prompt_embeds.dtype,
1357
+ device,
1358
+ generator,
1359
+ latents,
1360
+ )
1361
+
1362
+ # 6.1 Optionally get Guidance Scale Embedding
1363
+ timestep_cond = None
1364
+ if self.unet.config.time_cond_proj_dim is not None:
1365
+ guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
1366
+ timestep_cond = self.get_guidance_scale_embedding(
1367
+ guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
1368
+ ).to(device=device, dtype=latents.dtype)
1369
+
1370
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
1371
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
1372
+
1373
+ # 7.1 Create tensor stating which controlnets to keep
1374
+ controlnet_keep = []
1375
+ for i in range(len(timesteps)):
1376
+ keeps = [
1377
+ 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
1378
+ for s, e in zip(control_guidance_start, control_guidance_end)
1379
+ ]
1380
+ controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)
1381
+
1382
+ # 7.2 Prepare added time ids & embeddings
1383
+ if isinstance(image, list):
1384
+ original_size = original_size or image[0].shape[-2:]
1385
+ else:
1386
+ original_size = original_size or image.shape[-2:]
1387
+ target_size = target_size or (height, width)
1388
+
1389
+ add_text_embeds = pooled_prompt_embeds
1390
+ if self.text_encoder_2 is None:
1391
+ text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
1392
+ else:
1393
+ text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
1394
+
1395
+ add_time_ids = self._get_add_time_ids(
1396
+ original_size,
1397
+ crops_coords_top_left,
1398
+ target_size,
1399
+ dtype=prompt_embeds.dtype,
1400
+ text_encoder_projection_dim=text_encoder_projection_dim,
1401
+ )
1402
+
1403
+ if negative_original_size is not None and negative_target_size is not None:
1404
+ negative_add_time_ids = self._get_add_time_ids(
1405
+ negative_original_size,
1406
+ negative_crops_coords_top_left,
1407
+ negative_target_size,
1408
+ dtype=prompt_embeds.dtype,
1409
+ text_encoder_projection_dim=text_encoder_projection_dim,
1410
+ )
1411
+ else:
1412
+ negative_add_time_ids = add_time_ids
1413
+
1414
+ images = image if isinstance(image, list) else [image]
1415
+ for i, single_image in enumerate(images):
1416
+ if self.do_classifier_free_guidance:
1417
+ single_image = single_image.chunk(2)[0]
1418
+
1419
+ if self.do_perturbed_attention_guidance:
1420
+ single_image = self._prepare_perturbed_attention_guidance(
1421
+ single_image, single_image, self.do_classifier_free_guidance
1422
+ )
1423
+ elif self.do_classifier_free_guidance:
1424
+ single_image = torch.cat([single_image] * 2)
1425
+ single_image = single_image.to(device)
1426
+ images[i] = single_image
1427
+
1428
+ image = images if isinstance(image, list) else images[0]
1429
+
1430
+ if ip_adapter_image_embeds is not None:
1431
+ for i, image_embeds in enumerate(ip_adapter_image_embeds):
1432
+ negative_image_embeds = None
1433
+ if self.do_classifier_free_guidance:
1434
+ negative_image_embeds, image_embeds = image_embeds.chunk(2)
1435
+
1436
+ if self.do_perturbed_attention_guidance:
1437
+ image_embeds = self._prepare_perturbed_attention_guidance(
1438
+ image_embeds, negative_image_embeds, self.do_classifier_free_guidance
1439
+ )
1440
+ elif self.do_classifier_free_guidance:
1441
+ image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0)
1442
+ image_embeds = image_embeds.to(device)
1443
+ ip_adapter_image_embeds[i] = image_embeds
1444
+
1445
+ if self.do_perturbed_attention_guidance:
1446
+ prompt_embeds = self._prepare_perturbed_attention_guidance(
1447
+ prompt_embeds, negative_prompt_embeds, self.do_classifier_free_guidance
1448
+ )
1449
+ add_text_embeds = self._prepare_perturbed_attention_guidance(
1450
+ add_text_embeds, negative_pooled_prompt_embeds, self.do_classifier_free_guidance
1451
+ )
1452
+ add_time_ids = self._prepare_perturbed_attention_guidance(
1453
+ add_time_ids, negative_add_time_ids, self.do_classifier_free_guidance
1454
+ )
1455
+ elif self.do_classifier_free_guidance:
1456
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
1457
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
1458
+ add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
1459
+
1460
+ prompt_embeds = prompt_embeds.to(device)
1461
+ add_text_embeds = add_text_embeds.to(device)
1462
+ add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
1463
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
1464
+
1465
+ controlnet_prompt_embeds = prompt_embeds
1466
+ controlnet_added_cond_kwargs = added_cond_kwargs
1467
+
1468
+ # 8. Denoising loop
1469
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
1470
+
1471
+ # 8.1 Apply denoising_end
1472
+ if (
1473
+ self.denoising_end is not None
1474
+ and isinstance(self.denoising_end, float)
1475
+ and self.denoising_end > 0
1476
+ and self.denoising_end < 1
1477
+ ):
1478
+ discrete_timestep_cutoff = int(
1479
+ round(
1480
+ self.scheduler.config.num_train_timesteps
1481
+ - (self.denoising_end * self.scheduler.config.num_train_timesteps)
1482
+ )
1483
+ )
1484
+ num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
1485
+ timesteps = timesteps[:num_inference_steps]
1486
+
1487
+ if self.do_perturbed_attention_guidance:
1488
+ original_attn_proc = self.unet.attn_processors
1489
+ self._set_pag_attn_processor(
1490
+ pag_applied_layers=self.pag_applied_layers,
1491
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
1492
+ )
1493
+
1494
+ is_unet_compiled = is_compiled_module(self.unet)
1495
+ is_controlnet_compiled = is_compiled_module(self.controlnet)
1496
+ is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
1497
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
1498
+ for i, t in enumerate(timesteps):
1499
+ # Relevant thread:
1500
+ # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
1501
+ if (
1502
+ torch.cuda.is_available()
1503
+ and (is_unet_compiled and is_controlnet_compiled)
1504
+ and is_torch_higher_equal_2_1
1505
+ ):
1506
+ torch._inductor.cudagraph_mark_step_begin()
1507
+ # expand the latents if we are doing classifier free guidance
1508
+ latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0]))
1509
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
1510
+
1511
+ # controlnet(s) inference
1512
+ control_model_input = latent_model_input
1513
+
1514
+ if isinstance(controlnet_keep[i], list):
1515
+ cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
1516
+ else:
1517
+ controlnet_cond_scale = controlnet_conditioning_scale
1518
+ if isinstance(controlnet_cond_scale, list):
1519
+ controlnet_cond_scale = controlnet_cond_scale[0]
1520
+ cond_scale = controlnet_cond_scale * controlnet_keep[i]
1521
+
1522
+ down_block_res_samples, mid_block_res_sample = self.controlnet(
1523
+ control_model_input,
1524
+ t,
1525
+ encoder_hidden_states=controlnet_prompt_embeds,
1526
+ controlnet_cond=image,
1527
+ conditioning_scale=cond_scale,
1528
+ guess_mode=False,
1529
+ added_cond_kwargs=controlnet_added_cond_kwargs,
1530
+ return_dict=False,
1531
+ )
1532
+
1533
+ if ip_adapter_image_embeds is not None:
1534
+ added_cond_kwargs["image_embeds"] = ip_adapter_image_embeds
1535
+
1536
+ # predict the noise residual
1537
+ noise_pred = self.unet(
1538
+ latent_model_input,
1539
+ t,
1540
+ encoder_hidden_states=prompt_embeds,
1541
+ timestep_cond=timestep_cond,
1542
+ cross_attention_kwargs=self.cross_attention_kwargs,
1543
+ down_block_additional_residuals=down_block_res_samples,
1544
+ mid_block_additional_residual=mid_block_res_sample,
1545
+ added_cond_kwargs=added_cond_kwargs,
1546
+ return_dict=False,
1547
+ )[0]
1548
+
1549
+ # perform guidance
1550
+ if self.do_perturbed_attention_guidance:
1551
+ noise_pred = self._apply_perturbed_attention_guidance(
1552
+ noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t
1553
+ )
1554
+ elif self.do_classifier_free_guidance:
1555
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1556
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
1557
+
1558
+ # compute the previous noisy sample x_t -> x_t-1
1559
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
1560
+
1561
+ if callback_on_step_end is not None:
1562
+ callback_kwargs = {}
1563
+ for k in callback_on_step_end_tensor_inputs:
1564
+ callback_kwargs[k] = locals()[k]
1565
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
1566
+
1567
+ latents = callback_outputs.pop("latents", latents)
1568
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1569
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
1570
+ add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
1571
+ negative_pooled_prompt_embeds = callback_outputs.pop(
1572
+ "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
1573
+ )
1574
+ add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
1575
+ negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids)
1576
+
1577
+ # call the callback, if provided
1578
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1579
+ progress_bar.update()
1580
+
1581
+ if XLA_AVAILABLE:
1582
+ xm.mark_step()
1583
+
1584
+ if not output_type == "latent":
1585
+ # make sure the VAE is in float32 mode, as it overflows in float16
1586
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
1587
+
1588
+ if needs_upcasting:
1589
+ self.upcast_vae()
1590
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
1591
+
1592
+ # unscale/denormalize the latents
1593
+ # denormalize with the mean and std if available and not None
1594
+ has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
1595
+ has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
1596
+ if has_latents_mean and has_latents_std:
1597
+ latents_mean = (
1598
+ torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
1599
+ )
1600
+ latents_std = (
1601
+ torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
1602
+ )
1603
+ latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
1604
+ else:
1605
+ latents = latents / self.vae.config.scaling_factor
1606
+
1607
+ image = self.vae.decode(latents, return_dict=False)[0]
1608
+
1609
+ # cast back to fp16 if needed
1610
+ if needs_upcasting:
1611
+ self.vae.to(dtype=torch.float16)
1612
+ else:
1613
+ image = latents
1614
+
1615
+ if not output_type == "latent":
1616
+ # apply watermark if available
1617
+ if self.watermark is not None:
1618
+ image = self.watermark.apply_watermark(image)
1619
+
1620
+ image = self.image_processor.postprocess(image, output_type=output_type)
1621
+
1622
+ # Offload all models
1623
+ self.maybe_free_model_hooks()
1624
+
1625
+ if self.do_perturbed_attention_guidance:
1626
+ self.unet.set_attn_processor(original_attn_proc)
1627
+
1628
+ if not return_dict:
1629
+ return (image,)
1630
+
1631
+ return StableDiffusionXLPipelineOutput(images=image)
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/stable_diffusion/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (4.62 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_onnx_stable_diffusion.cpython-310.pyc ADDED
Binary file (16.6 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_onnx_stable_diffusion_img2img.cpython-310.pyc ADDED
Binary file (19.9 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_onnx_stable_diffusion_inpaint.cpython-310.pyc ADDED
Binary file (20 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_onnx_stable_diffusion_upscale.cpython-310.pyc ADDED
Binary file (18.6 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_output.cpython-310.pyc ADDED
Binary file (2.02 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion.cpython-310.pyc ADDED
Binary file (36.6 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_depth2img.cpython-310.pyc ADDED
Binary file (28.9 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_image_variation.cpython-310.pyc ADDED
Binary file (16.4 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_img2img.cpython-310.pyc ADDED
Binary file (39.4 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_inpaint.cpython-310.pyc ADDED
Binary file (44.4 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_instruct_pix2pix.cpython-310.pyc ADDED
Binary file (29.4 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_latent_upscale.cpython-310.pyc ADDED
Binary file (20.5 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_upscale.cpython-310.pyc ADDED
Binary file (25 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_unclip.cpython-310.pyc ADDED
Binary file (25.6 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_unclip_img2img.cpython-310.pyc ADDED
Binary file (24.4 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/stable_diffusion/__pycache__/safety_checker.cpython-310.pyc ADDED
Binary file (3.63 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/stable_diffusion/__pycache__/safety_checker_flax.cpython-310.pyc ADDED
Binary file (3.83 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/stable_diffusion/__pycache__/stable_unclip_image_normalizer.cpython-310.pyc ADDED
Binary file (1.91 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/stable_diffusion_3/__init__.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from typing import TYPE_CHECKING

from ...utils import (
    DIFFUSERS_SLOW_IMPORT,
    OptionalDependencyNotAvailable,
    _LazyModule,
    get_objects_from_module,
    is_flax_available,
    is_torch_available,
    is_transformers_available,
)


# Placeholder objects registered when optional dependencies are missing, plus
# any extra names to expose on the lazily-created module object below.
_dummy_objects = {}
_additional_imports = {}
# Maps submodule name -> list of public names it exports; consumed by _LazyModule.
_import_structure = {"pipeline_output": ["StableDiffusion3PipelineOutput"]}

try:
    if not (is_transformers_available() and is_torch_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    # torch/transformers unavailable: export dummy objects that raise a helpful
    # error message when instantiated, instead of failing at import time.
    from ...utils import dummy_torch_and_transformers_objects  # noqa F403

    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
    _import_structure["pipeline_stable_diffusion_3"] = ["StableDiffusion3Pipeline"]
    _import_structure["pipeline_stable_diffusion_3_img2img"] = ["StableDiffusion3Img2ImgPipeline"]
    _import_structure["pipeline_stable_diffusion_3_inpaint"] = ["StableDiffusion3InpaintPipeline"]

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    # Eager imports: used by static type checkers, or when slow imports are
    # explicitly requested via DIFFUSERS_SLOW_IMPORT.
    try:
        if not (is_transformers_available() and is_torch_available()):
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        from ...utils.dummy_torch_and_transformers_objects import *  # noqa F403
    else:
        from .pipeline_stable_diffusion_3 import StableDiffusion3Pipeline
        from .pipeline_stable_diffusion_3_img2img import StableDiffusion3Img2ImgPipeline
        from .pipeline_stable_diffusion_3_inpaint import StableDiffusion3InpaintPipeline

else:
    import sys

    # Replace this module with a lazy proxy that imports each submodule only on
    # first attribute access, keeping `import diffusers` fast.
    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )

    # Attach dummy/additional objects onto the lazy module so they are reachable
    # even though they are not part of _import_structure.
    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)
    for name, value in _additional_imports.items():
        setattr(sys.modules[__name__], name, value)
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/stable_diffusion_3/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (1.44 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/stable_diffusion_3/__pycache__/pipeline_output.cpython-310.pyc ADDED
Binary file (1.03 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/stable_diffusion_3/__pycache__/pipeline_stable_diffusion_3.cpython-310.pyc ADDED
Binary file (38.2 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/stable_diffusion_3/__pycache__/pipeline_stable_diffusion_3_img2img.cpython-310.pyc ADDED
Binary file (38.3 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/stable_diffusion_3/__pycache__/pipeline_stable_diffusion_3_inpaint.cpython-310.pyc ADDED
Binary file (45.3 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/stable_diffusion_3/pipeline_output.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import List, Union
3
+
4
+ import numpy as np
5
+ import PIL.Image
6
+
7
+ from ...utils import BaseOutput
8
+
9
+
10
@dataclass
class StableDiffusion3PipelineOutput(BaseOutput):
    """
    Output class for Stable Diffusion pipelines.

    Args:
        images (`List[PIL.Image.Image]` or `np.ndarray`)
            List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
            num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
    """

    # Either a list of PIL images or a single numpy array, depending on the
    # pipeline's requested `output_type`.
    images: Union[List[PIL.Image.Image], np.ndarray]
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py ADDED
@@ -0,0 +1,1140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ from typing import Any, Callable, Dict, List, Optional, Union
17
+
18
+ import torch
19
+ from transformers import (
20
+ CLIPTextModelWithProjection,
21
+ CLIPTokenizer,
22
+ SiglipImageProcessor,
23
+ SiglipVisionModel,
24
+ T5EncoderModel,
25
+ T5TokenizerFast,
26
+ )
27
+
28
+ from ...callbacks import MultiPipelineCallbacks, PipelineCallback
29
+ from ...image_processor import PipelineImageInput, VaeImageProcessor
30
+ from ...loaders import FromSingleFileMixin, SD3IPAdapterMixin, SD3LoraLoaderMixin
31
+ from ...models.autoencoders import AutoencoderKL
32
+ from ...models.transformers import SD3Transformer2DModel
33
+ from ...schedulers import FlowMatchEulerDiscreteScheduler
34
+ from ...utils import (
35
+ USE_PEFT_BACKEND,
36
+ is_torch_xla_available,
37
+ logging,
38
+ replace_example_docstring,
39
+ scale_lora_layers,
40
+ unscale_lora_layers,
41
+ )
42
+ from ...utils.torch_utils import randn_tensor
43
+ from ..pipeline_utils import DiffusionPipeline
44
+ from .pipeline_output import StableDiffusion3PipelineOutput
45
+
46
+
47
+ if is_torch_xla_available():
48
+ import torch_xla.core.xla_model as xm
49
+
50
+ XLA_AVAILABLE = True
51
+ else:
52
+ XLA_AVAILABLE = False
53
+
54
+
55
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
56
+
57
+ EXAMPLE_DOC_STRING = """
58
+ Examples:
59
+ ```py
60
+ >>> import torch
61
+ >>> from diffusers import StableDiffusion3Pipeline
62
+
63
+ >>> pipe = StableDiffusion3Pipeline.from_pretrained(
64
+ ... "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
65
+ ... )
66
+ >>> pipe.to("cuda")
67
+ >>> prompt = "A cat holding a sign that says hello world"
68
+ >>> image = pipe(prompt).images[0]
69
+ >>> image.save("sd3.png")
70
+ ```
71
+ """
72
+
73
+
74
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
):
    """Linearly interpolate the timestep-schedule shift `mu` from the image sequence length.

    The shift grows linearly from `base_shift` at `base_seq_len` tokens to
    `max_shift` at `max_seq_len` tokens (and extrapolates outside that range).
    """
    # Line through (base_seq_len, base_shift) and (max_seq_len, max_shift).
    slope = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    intercept = base_shift - slope * base_seq_len
    return image_seq_len * slope + intercept
86
+
87
+
88
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    r"""
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
    # `timesteps` and `sigmas` are mutually exclusive ways to override the schedule.
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    uses_custom_schedule = timesteps is not None or sigmas is not None

    if timesteps is not None:
        # Not every scheduler supports custom timesteps; inspect the signature first.
        supported = set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if "timesteps" not in supported:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    elif sigmas is not None:
        supported = set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if "sigmas" not in supported:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)

    timesteps = scheduler.timesteps
    if uses_custom_schedule:
        # The scheduler may resample the custom schedule; recount the steps.
        num_inference_steps = len(timesteps)
    return timesteps, num_inference_steps
146
+
147
+
148
+ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin, SD3IPAdapterMixin):
149
+ r"""
150
+ Args:
151
+ transformer ([`SD3Transformer2DModel`]):
152
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
153
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
154
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
155
+ vae ([`AutoencoderKL`]):
156
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
157
+ text_encoder ([`CLIPTextModelWithProjection`]):
158
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
159
+ specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant,
160
+ with an additional added projection layer that is initialized with a diagonal matrix with the `hidden_size`
161
+ as its dimension.
162
+ text_encoder_2 ([`CLIPTextModelWithProjection`]):
163
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
164
+ specifically the
165
+ [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
166
+ variant.
167
+ text_encoder_3 ([`T5EncoderModel`]):
168
+ Frozen text-encoder. Stable Diffusion 3 uses
169
+ [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
170
+ [t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
171
+ tokenizer (`CLIPTokenizer`):
172
+ Tokenizer of class
173
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
174
+ tokenizer_2 (`CLIPTokenizer`):
175
+ Second Tokenizer of class
176
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
177
+ tokenizer_3 (`T5TokenizerFast`):
178
+ Tokenizer of class
179
+ [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
180
+ image_encoder (`SiglipVisionModel`, *optional*):
181
+ Pre-trained Vision Model for IP Adapter.
182
+ feature_extractor (`SiglipImageProcessor`, *optional*):
183
+ Image processor for IP Adapter.
184
+ """
185
+
186
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->image_encoder->transformer->vae"
187
+ _optional_components = ["image_encoder", "feature_extractor"]
188
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "pooled_prompt_embeds"]
189
+
190
    def __init__(
        self,
        transformer: SD3Transformer2DModel,
        scheduler: FlowMatchEulerDiscreteScheduler,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModelWithProjection,
        tokenizer: CLIPTokenizer,
        text_encoder_2: CLIPTextModelWithProjection,
        tokenizer_2: CLIPTokenizer,
        text_encoder_3: T5EncoderModel,
        tokenizer_3: T5TokenizerFast,
        image_encoder: SiglipVisionModel = None,
        feature_extractor: SiglipImageProcessor = None,
    ):
        """Register all sub-models on the pipeline and derive size/tokenizer defaults from their configs."""
        super().__init__()

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            text_encoder_2=text_encoder_2,
            text_encoder_3=text_encoder_3,
            tokenizer=tokenizer,
            tokenizer_2=tokenizer_2,
            tokenizer_3=tokenizer_3,
            transformer=transformer,
            scheduler=scheduler,
            image_encoder=image_encoder,
            feature_extractor=feature_extractor,
        )
        # Spatial downsampling factor of the VAE; falls back to 8 when no VAE is registered.
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
        # CLIP tokenizer cap; defaults to 77 when no tokenizer is registered.
        self.tokenizer_max_length = (
            self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
        )
        # Default latent sample size and transformer patch size, read from the
        # transformer config when available.
        self.default_sample_size = (
            self.transformer.config.sample_size
            if hasattr(self, "transformer") and self.transformer is not None
            else 128
        )
        self.patch_size = (
            self.transformer.config.patch_size if hasattr(self, "transformer") and self.transformer is not None else 2
        )
232
+
233
    def _get_t5_prompt_embeds(
        self,
        prompt: Union[str, List[str]] = None,
        num_images_per_prompt: int = 1,
        max_sequence_length: int = 256,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        """Encode `prompt` with the T5 text encoder (`text_encoder_3`).

        Returns embeddings of shape `(batch_size * num_images_per_prompt, seq_len, dim)`.
        When no T5 encoder is loaded, returns zeros sized to the transformer's
        `joint_attention_dim` so the joint prompt embedding keeps its expected width.
        """
        device = device or self._execution_device
        dtype = dtype or self.text_encoder.dtype

        prompt = [prompt] if isinstance(prompt, str) else prompt
        batch_size = len(prompt)

        if self.text_encoder_3 is None:
            # T5 is optional for SD3: substitute zero embeddings of the right shape.
            return torch.zeros(
                (
                    batch_size * num_images_per_prompt,
                    self.tokenizer_max_length,
                    self.transformer.config.joint_attention_dim,
                ),
                device=device,
                dtype=dtype,
            )

        text_inputs = self.tokenizer_3(
            prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        # Tokenize again without truncation only to detect and report truncated text.
        untruncated_ids = self.tokenizer_3(prompt, padding="longest", return_tensors="pt").input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
            removed_text = self.tokenizer_3.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because `max_sequence_length` is set to "
                f" {max_sequence_length} tokens: {removed_text}"
            )

        prompt_embeds = self.text_encoder_3(text_input_ids.to(device))[0]

        # Cast to the T5 encoder's dtype (may differ from the CLIP encoder's).
        dtype = self.text_encoder_3.dtype
        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

        _, seq_len, _ = prompt_embeds.shape

        # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

        return prompt_embeds
288
+
289
    def _get_clip_prompt_embeds(
        self,
        prompt: Union[str, List[str]],
        num_images_per_prompt: int = 1,
        device: Optional[torch.device] = None,
        clip_skip: Optional[int] = None,
        clip_model_index: int = 0,
    ):
        """Encode `prompt` with one of the two CLIP text encoders.

        `clip_model_index` selects the encoder/tokenizer pair: 0 -> (`tokenizer`,
        `text_encoder`), 1 -> (`tokenizer_2`, `text_encoder_2`). Returns a tuple of
        per-token embeddings and the pooled (projected) embedding, both repeated
        `num_images_per_prompt` times along the batch dimension.
        """
        device = device or self._execution_device

        clip_tokenizers = [self.tokenizer, self.tokenizer_2]
        clip_text_encoders = [self.text_encoder, self.text_encoder_2]

        tokenizer = clip_tokenizers[clip_model_index]
        text_encoder = clip_text_encoders[clip_model_index]

        prompt = [prompt] if isinstance(prompt, str) else prompt
        batch_size = len(prompt)

        text_inputs = tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer_max_length,
            truncation=True,
            return_tensors="pt",
        )

        text_input_ids = text_inputs.input_ids
        # Tokenize again without truncation only to detect and warn about clipped text.
        untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
            removed_text = tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer_max_length} tokens: {removed_text}"
            )
        prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
        # First element of the encoder output is the pooled/projected embedding.
        pooled_prompt_embeds = prompt_embeds[0]

        if clip_skip is None:
            # Penultimate hidden state is the standard choice for SD-style pipelines.
            prompt_embeds = prompt_embeds.hidden_states[-2]
        else:
            # clip_skip counts additional layers to skip before the penultimate one.
            prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]

        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

        _, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

        pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
        pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)

        return prompt_embeds, pooled_prompt_embeds
343
+
344
    def encode_prompt(
        self,
        prompt: Union[str, List[str]],
        prompt_2: Union[str, List[str]],
        prompt_3: Union[str, List[str]],
        device: Optional[torch.device] = None,
        num_images_per_prompt: int = 1,
        do_classifier_free_guidance: bool = True,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        negative_prompt_2: Optional[Union[str, List[str]]] = None,
        negative_prompt_3: Optional[Union[str, List[str]]] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        clip_skip: Optional[int] = None,
        max_sequence_length: int = 256,
        lora_scale: Optional[float] = None,
    ):
        r"""
        Encode prompts with all three text encoders and assemble the joint SD3 prompt embedding.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            prompt_2 (`str` or `List[str]`, *optional*):
                The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
                used in all text-encoders
            prompt_3 (`str` or `List[str]`, *optional*):
                The prompt or prompts to be sent to the `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is
                used in all text-encoders
            device: (`torch.device`):
                torch device
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            negative_prompt_2 (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
                `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
            negative_prompt_3 (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and
                `text_encoder_3`. If not defined, `negative_prompt` is used in all the text-encoders.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
                If not provided, pooled text embeddings will be generated from `prompt` input argument.
            negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
                input argument.
            clip_skip (`int`, *optional*):
                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                the output of the pre-final layer will be used for computing the prompt embeddings.
            lora_scale (`float`, *optional*):
                A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
        """
        device = device or self._execution_device

        # set lora scale so that monkey patched LoRA
        # function of text encoder can correctly access it
        if lora_scale is not None and isinstance(self, SD3LoraLoaderMixin):
            self._lora_scale = lora_scale

            # dynamically adjust the LoRA scale
            if self.text_encoder is not None and USE_PEFT_BACKEND:
                scale_lora_layers(self.text_encoder, lora_scale)
            if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
                scale_lora_layers(self.text_encoder_2, lora_scale)

        prompt = [prompt] if isinstance(prompt, str) else prompt
        if prompt is not None:
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            # Fall back to the primary prompt for the secondary/tertiary encoders.
            prompt_2 = prompt_2 or prompt
            prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2

            prompt_3 = prompt_3 or prompt
            prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3

            prompt_embed, pooled_prompt_embed = self._get_clip_prompt_embeds(
                prompt=prompt,
                device=device,
                num_images_per_prompt=num_images_per_prompt,
                clip_skip=clip_skip,
                clip_model_index=0,
            )
            prompt_2_embed, pooled_prompt_2_embed = self._get_clip_prompt_embeds(
                prompt=prompt_2,
                device=device,
                num_images_per_prompt=num_images_per_prompt,
                clip_skip=clip_skip,
                clip_model_index=1,
            )
            # Concatenate the two CLIP embeddings along the feature dimension...
            clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1)

            t5_prompt_embed = self._get_t5_prompt_embeds(
                prompt=prompt_3,
                num_images_per_prompt=num_images_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
            )

            # ...zero-pad them to the T5 feature width...
            clip_prompt_embeds = torch.nn.functional.pad(
                clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1])
            )

            # ...and stack the T5 tokens along the sequence dimension.
            prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2)
            pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1)

        if do_classifier_free_guidance and negative_prompt_embeds is None:
            negative_prompt = negative_prompt or ""
            negative_prompt_2 = negative_prompt_2 or negative_prompt
            negative_prompt_3 = negative_prompt_3 or negative_prompt

            # normalize str to list
            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
            negative_prompt_2 = (
                batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
            )
            negative_prompt_3 = (
                batch_size * [negative_prompt_3] if isinstance(negative_prompt_3, str) else negative_prompt_3
            )

            if prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )

            # Negative branch mirrors the positive branch: two CLIPs + T5,
            # concatenated/padded/stacked identically. Note clip_skip is not applied here.
            negative_prompt_embed, negative_pooled_prompt_embed = self._get_clip_prompt_embeds(
                negative_prompt,
                device=device,
                num_images_per_prompt=num_images_per_prompt,
                clip_skip=None,
                clip_model_index=0,
            )
            negative_prompt_2_embed, negative_pooled_prompt_2_embed = self._get_clip_prompt_embeds(
                negative_prompt_2,
                device=device,
                num_images_per_prompt=num_images_per_prompt,
                clip_skip=None,
                clip_model_index=1,
            )
            negative_clip_prompt_embeds = torch.cat([negative_prompt_embed, negative_prompt_2_embed], dim=-1)

            t5_negative_prompt_embed = self._get_t5_prompt_embeds(
                prompt=negative_prompt_3,
                num_images_per_prompt=num_images_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
            )

            negative_clip_prompt_embeds = torch.nn.functional.pad(
                negative_clip_prompt_embeds,
                (0, t5_negative_prompt_embed.shape[-1] - negative_clip_prompt_embeds.shape[-1]),
            )

            negative_prompt_embeds = torch.cat([negative_clip_prompt_embeds, t5_negative_prompt_embed], dim=-2)
            negative_pooled_prompt_embeds = torch.cat(
                [negative_pooled_prompt_embed, negative_pooled_prompt_2_embed], dim=-1
            )

        if self.text_encoder is not None:
            if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND:
                # Retrieve the original scale by scaling back the LoRA layers
                unscale_lora_layers(self.text_encoder, lora_scale)

        if self.text_encoder_2 is not None:
            if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND:
                # Retrieve the original scale by scaling back the LoRA layers
                unscale_lora_layers(self.text_encoder_2, lora_scale)

        return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
536
+
537
+ def check_inputs(
538
+ self,
539
+ prompt,
540
+ prompt_2,
541
+ prompt_3,
542
+ height,
543
+ width,
544
+ negative_prompt=None,
545
+ negative_prompt_2=None,
546
+ negative_prompt_3=None,
547
+ prompt_embeds=None,
548
+ negative_prompt_embeds=None,
549
+ pooled_prompt_embeds=None,
550
+ negative_pooled_prompt_embeds=None,
551
+ callback_on_step_end_tensor_inputs=None,
552
+ max_sequence_length=None,
553
+ ):
554
+ if (
555
+ height % (self.vae_scale_factor * self.patch_size) != 0
556
+ or width % (self.vae_scale_factor * self.patch_size) != 0
557
+ ):
558
+ raise ValueError(
559
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor * self.patch_size} but are {height} and {width}."
560
+ f"You can use height {height - height % (self.vae_scale_factor * self.patch_size)} and width {width - width % (self.vae_scale_factor * self.patch_size)}."
561
+ )
562
+
563
+ if callback_on_step_end_tensor_inputs is not None and not all(
564
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
565
+ ):
566
+ raise ValueError(
567
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
568
+ )
569
+
570
+ if prompt is not None and prompt_embeds is not None:
571
+ raise ValueError(
572
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
573
+ " only forward one of the two."
574
+ )
575
+ elif prompt_2 is not None and prompt_embeds is not None:
576
+ raise ValueError(
577
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
578
+ " only forward one of the two."
579
+ )
580
+ elif prompt_3 is not None and prompt_embeds is not None:
581
+ raise ValueError(
582
+ f"Cannot forward both `prompt_3`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
583
+ " only forward one of the two."
584
+ )
585
+ elif prompt is None and prompt_embeds is None:
586
+ raise ValueError(
587
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
588
+ )
589
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
590
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
591
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
592
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
593
+ elif prompt_3 is not None and (not isinstance(prompt_3, str) and not isinstance(prompt_3, list)):
594
+ raise ValueError(f"`prompt_3` has to be of type `str` or `list` but is {type(prompt_3)}")
595
+
596
+ if negative_prompt is not None and negative_prompt_embeds is not None:
597
+ raise ValueError(
598
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
599
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
600
+ )
601
+ elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
602
+ raise ValueError(
603
+ f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
604
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
605
+ )
606
+ elif negative_prompt_3 is not None and negative_prompt_embeds is not None:
607
+ raise ValueError(
608
+ f"Cannot forward both `negative_prompt_3`: {negative_prompt_3} and `negative_prompt_embeds`:"
609
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
610
+ )
611
+
612
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
613
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
614
+ raise ValueError(
615
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
616
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
617
+ f" {negative_prompt_embeds.shape}."
618
+ )
619
+
620
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
621
+ raise ValueError(
622
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
623
+ )
624
+
625
+ if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
626
+ raise ValueError(
627
+ "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
628
+ )
629
+
630
+ if max_sequence_length is not None and max_sequence_length > 512:
631
+ raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
632
+
633
+ def prepare_latents(
634
+ self,
635
+ batch_size,
636
+ num_channels_latents,
637
+ height,
638
+ width,
639
+ dtype,
640
+ device,
641
+ generator,
642
+ latents=None,
643
+ ):
644
+ if latents is not None:
645
+ return latents.to(device=device, dtype=dtype)
646
+
647
+ shape = (
648
+ batch_size,
649
+ num_channels_latents,
650
+ int(height) // self.vae_scale_factor,
651
+ int(width) // self.vae_scale_factor,
652
+ )
653
+
654
+ if isinstance(generator, list) and len(generator) != batch_size:
655
+ raise ValueError(
656
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
657
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
658
+ )
659
+
660
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
661
+
662
+ return latents
663
+
664
    @property
    def guidance_scale(self):
        # Classifier-free-guidance weight captured at the start of `__call__`.
        return self._guidance_scale

    @property
    def skip_guidance_layers(self):
        # Configured skip-layer-guidance layer indices (may be unset/None).
        return self._skip_guidance_layers

    @property
    def clip_skip(self):
        # `clip_skip` value captured at the start of `__call__`.
        return self._clip_skip

    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    @property
    def do_classifier_free_guidance(self):
        return self._guidance_scale > 1

    @property
    def joint_attention_kwargs(self):
        # Extra kwargs forwarded to the transformer's attention processors.
        return self._joint_attention_kwargs

    @property
    def num_timesteps(self):
        # Number of timesteps of the most recent denoising run (set in `__call__`).
        return self._num_timesteps

    @property
    def interrupt(self):
        # Cooperative-cancellation flag checked once per denoising step.
        return self._interrupt
694
+
695
+ # Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_image
696
+ def encode_image(self, image: PipelineImageInput, device: torch.device) -> torch.Tensor:
697
+ """Encodes the given image into a feature representation using a pre-trained image encoder.
698
+
699
+ Args:
700
+ image (`PipelineImageInput`):
701
+ Input image to be encoded.
702
+ device: (`torch.device`):
703
+ Torch device.
704
+
705
+ Returns:
706
+ `torch.Tensor`: The encoded image feature representation.
707
+ """
708
+ if not isinstance(image, torch.Tensor):
709
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
710
+
711
+ image = image.to(device=device, dtype=self.dtype)
712
+
713
+ return self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
714
+
715
+ # Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.prepare_ip_adapter_image_embeds
716
+ def prepare_ip_adapter_image_embeds(
717
+ self,
718
+ ip_adapter_image: Optional[PipelineImageInput] = None,
719
+ ip_adapter_image_embeds: Optional[torch.Tensor] = None,
720
+ device: Optional[torch.device] = None,
721
+ num_images_per_prompt: int = 1,
722
+ do_classifier_free_guidance: bool = True,
723
+ ) -> torch.Tensor:
724
+ """Prepares image embeddings for use in the IP-Adapter.
725
+
726
+ Either `ip_adapter_image` or `ip_adapter_image_embeds` must be passed.
727
+
728
+ Args:
729
+ ip_adapter_image (`PipelineImageInput`, *optional*):
730
+ The input image to extract features from for IP-Adapter.
731
+ ip_adapter_image_embeds (`torch.Tensor`, *optional*):
732
+ Precomputed image embeddings.
733
+ device: (`torch.device`, *optional*):
734
+ Torch device.
735
+ num_images_per_prompt (`int`, defaults to 1):
736
+ Number of images that should be generated per prompt.
737
+ do_classifier_free_guidance (`bool`, defaults to True):
738
+ Whether to use classifier free guidance or not.
739
+ """
740
+ device = device or self._execution_device
741
+
742
+ if ip_adapter_image_embeds is not None:
743
+ if do_classifier_free_guidance:
744
+ single_negative_image_embeds, single_image_embeds = ip_adapter_image_embeds.chunk(2)
745
+ else:
746
+ single_image_embeds = ip_adapter_image_embeds
747
+ elif ip_adapter_image is not None:
748
+ single_image_embeds = self.encode_image(ip_adapter_image, device)
749
+ if do_classifier_free_guidance:
750
+ single_negative_image_embeds = torch.zeros_like(single_image_embeds)
751
+ else:
752
+ raise ValueError("Neither `ip_adapter_image_embeds` or `ip_adapter_image_embeds` were provided.")
753
+
754
+ image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
755
+
756
+ if do_classifier_free_guidance:
757
+ negative_image_embeds = torch.cat([single_negative_image_embeds] * num_images_per_prompt, dim=0)
758
+ image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0)
759
+
760
+ return image_embeds.to(device=device)
761
+
762
    def enable_sequential_cpu_offload(self, *args, **kwargs):
        """Enable sequential CPU offload, warning first about a known `image_encoder` incompatibility.

        Encoders built on `torch.nn.MultiheadAttention` may fail under sequential offload, so the
        user is pointed at the `_exclude_from_cpu_offload` workaround before delegating to the base
        implementation.
        """
        # Warn only when the encoder exists and is not already excluded from offloading.
        if self.image_encoder is not None and "image_encoder" not in self._exclude_from_cpu_offload:
            logger.warning(
                "`pipe.enable_sequential_cpu_offload()` might fail for `image_encoder` if it uses "
                "`torch.nn.MultiheadAttention`. You can exclude `image_encoder` from CPU offloading by calling "
                "`pipe._exclude_from_cpu_offload.append('image_encoder')` before `pipe.enable_sequential_cpu_offload()`."
            )

        super().enable_sequential_cpu_offload(*args, **kwargs)
771
+
772
    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        prompt_2: Optional[Union[str, List[str]]] = None,
        prompt_3: Optional[Union[str, List[str]]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 28,
        sigmas: Optional[List[float]] = None,
        guidance_scale: float = 7.0,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        negative_prompt_2: Optional[Union[str, List[str]]] = None,
        negative_prompt_3: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        ip_adapter_image: Optional[PipelineImageInput] = None,
        ip_adapter_image_embeds: Optional[torch.Tensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        clip_skip: Optional[int] = None,
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        max_sequence_length: int = 256,
        skip_guidance_layers: List[int] = None,
        skip_layer_guidance_scale: float = 2.8,
        skip_layer_guidance_stop: float = 0.2,
        skip_layer_guidance_start: float = 0.01,
        mu: Optional[float] = None,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
                instead.
            prompt_2 (`str` or `List[str]`, *optional*):
                The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
                will be used instead.
            prompt_3 (`str` or `List[str]`, *optional*):
                The prompt or prompts to be sent to `tokenizer_3` and `text_encoder_3`. If not defined, `prompt`
                will be used instead.
            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The height in pixels of the generated image. This is set to 1024 by default for the best results.
            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The width in pixels of the generated image. This is set to 1024 by default for the best results.
            num_inference_steps (`int`, *optional*, defaults to 28):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            sigmas (`List[float]`, *optional*):
                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
                their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
                will be used.
            guidance_scale (`float`, *optional*, defaults to 7.0):
                Guidance scale as defined in [Classifier-Free Diffusion
                Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
                of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
                `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
                the text `prompt`, usually at the expense of lower image quality.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            negative_prompt_2 (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
                `text_encoder_2`. If not defined, `negative_prompt` is used instead.
            negative_prompt_3 (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and
                `text_encoder_3`. If not defined, `negative_prompt` is used instead.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
                If not provided, pooled text embeddings will be generated from `prompt` input argument.
            negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
                input argument.
            ip_adapter_image (`PipelineImageInput`, *optional*):
                Optional image input to work with IP Adapters.
            ip_adapter_image_embeds (`torch.Tensor`, *optional*):
                Pre-generated image embeddings for IP-Adapter. Should be a tensor of shape `(batch_size, num_images,
                emb_dim)`. It should contain the negative image embedding if `do_classifier_free_guidance` is set to
                `True`. If not provided, embeddings are computed from the `ip_adapter_image` input argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generate image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] instead of
                a plain tuple.
            joint_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            callback_on_step_end (`Callable`, *optional*):
                A function that calls at the end of each denoising steps during the inference. The function is called
                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
                `callback_on_step_end_tensor_inputs`.
            callback_on_step_end_tensor_inputs (`List`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                `._callback_tensor_inputs` attribute of your pipeline class.
            max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`.
            skip_guidance_layers (`List[int]`, *optional*):
                A list of integers that specify layers to skip during guidance. If not provided, all layers will be
                used for guidance. If provided, the guidance will only be applied to the layers specified in the list.
                Recommended value by StabilityAI for Stable Diffusion 3.5 Medium is [7, 8, 9].
            skip_layer_guidance_scale (`int`, *optional*): The scale of the guidance for the layers specified in
                `skip_guidance_layers`. The guidance will be applied to the layers specified in `skip_guidance_layers`
                with a scale of `skip_layer_guidance_scale`. The guidance will be applied to the rest of the layers
                with a scale of `1`.
            skip_layer_guidance_stop (`int`, *optional*): The step at which the guidance for the layers specified in
                `skip_guidance_layers` will stop. The guidance will be applied to the layers specified in
                `skip_guidance_layers` until the fraction specified in `skip_layer_guidance_stop`. Recommended value by
                StabilityAI for Stable Diffusion 3.5 Medium is 0.2.
            skip_layer_guidance_start (`int`, *optional*): The step at which the guidance for the layers specified in
                `skip_guidance_layers` will start. The guidance will be applied to the layers specified in
                `skip_guidance_layers` from the fraction specified in `skip_layer_guidance_start`. Recommended value by
                StabilityAI for Stable Diffusion 3.5 Medium is 0.01.
            mu (`float`, *optional*): `mu` value used for `dynamic_shifting`.

        Examples:

        Returns:
            [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] if `return_dict` is True, otherwise a
            `tuple`. When returning a tuple, the first element is a list with the generated images.
        """

        height = height or self.default_sample_size * self.vae_scale_factor
        width = width or self.default_sample_size * self.vae_scale_factor

        # Callback objects carry their own declared tensor-input list.
        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            prompt_2,
            prompt_3,
            height,
            width,
            negative_prompt=negative_prompt,
            negative_prompt_2=negative_prompt_2,
            negative_prompt_3=negative_prompt_3,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
            max_sequence_length=max_sequence_length,
        )

        # Stash per-call state read back through the pipeline properties.
        self._guidance_scale = guidance_scale
        self._skip_layer_guidance_scale = skip_layer_guidance_scale
        self._clip_skip = clip_skip
        self._joint_attention_kwargs = joint_attention_kwargs
        self._interrupt = False

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        lora_scale = (
            self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
        )
        # 3. Encode input prompt(s) into (pooled) text embeddings
        (
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        ) = self.encode_prompt(
            prompt=prompt,
            prompt_2=prompt_2,
            prompt_3=prompt_3,
            negative_prompt=negative_prompt,
            negative_prompt_2=negative_prompt_2,
            negative_prompt_3=negative_prompt_3,
            do_classifier_free_guidance=self.do_classifier_free_guidance,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            device=device,
            clip_skip=self.clip_skip,
            num_images_per_prompt=num_images_per_prompt,
            max_sequence_length=max_sequence_length,
            lora_scale=lora_scale,
        )

        if self.do_classifier_free_guidance:
            # Keep the unbatched positive embeddings around: the extra
            # skip-layer-guidance forward pass below runs without the negative half.
            if skip_guidance_layers is not None:
                original_prompt_embeds = prompt_embeds
                original_pooled_prompt_embeds = pooled_prompt_embeds
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
            pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)

        # 4. Prepare latent variables
        num_channels_latents = self.transformer.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 5. Prepare timesteps
        scheduler_kwargs = {}
        if self.scheduler.config.get("use_dynamic_shifting", None) and mu is None:
            # Derive `mu` from the latent sequence length when the scheduler
            # uses dynamic shifting and no explicit value was passed.
            _, _, height, width = latents.shape
            image_seq_len = (height // self.transformer.config.patch_size) * (
                width // self.transformer.config.patch_size
            )
            mu = calculate_shift(
                image_seq_len,
                self.scheduler.config.get("base_image_seq_len", 256),
                self.scheduler.config.get("max_image_seq_len", 4096),
                self.scheduler.config.get("base_shift", 0.5),
                self.scheduler.config.get("max_shift", 1.16),
            )
            scheduler_kwargs["mu"] = mu
        elif mu is not None:
            scheduler_kwargs["mu"] = mu
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler,
            num_inference_steps,
            device,
            sigmas=sigmas,
            **scheduler_kwargs,
        )
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
        self._num_timesteps = len(timesteps)

        # 6. Prepare image embeddings
        if (ip_adapter_image is not None and self.is_ip_adapter_active) or ip_adapter_image_embeds is not None:
            ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds(
                ip_adapter_image,
                ip_adapter_image_embeds,
                device,
                batch_size * num_images_per_prompt,
                self.do_classifier_free_guidance,
            )

            # Route the image embeds to the attention processors via joint_attention_kwargs.
            if self.joint_attention_kwargs is None:
                self._joint_attention_kwargs = {"ip_adapter_image_embeds": ip_adapter_image_embeds}
            else:
                self._joint_attention_kwargs.update(ip_adapter_image_embeds=ip_adapter_image_embeds)

        # 7. Denoising loop
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                timestep = t.expand(latent_model_input.shape[0])

                noise_pred = self.transformer(
                    hidden_states=latent_model_input,
                    timestep=timestep,
                    encoder_hidden_states=prompt_embeds,
                    pooled_projections=pooled_prompt_embeds,
                    joint_attention_kwargs=self.joint_attention_kwargs,
                    return_dict=False,
                )[0]

                # perform guidance
                if self.do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
                    # Skip-layer guidance is only active within the configured
                    # [start, stop) fraction of the inference steps.
                    should_skip_layers = (
                        True
                        if i > num_inference_steps * skip_layer_guidance_start
                        and i < num_inference_steps * skip_layer_guidance_stop
                        else False
                    )
                    if skip_guidance_layers is not None and should_skip_layers:
                        timestep = t.expand(latents.shape[0])
                        latent_model_input = latents
                        # Extra forward pass with the positive embeddings only and
                        # the configured transformer layers skipped.
                        noise_pred_skip_layers = self.transformer(
                            hidden_states=latent_model_input,
                            timestep=timestep,
                            encoder_hidden_states=original_prompt_embeds,
                            pooled_projections=original_pooled_prompt_embeds,
                            joint_attention_kwargs=self.joint_attention_kwargs,
                            return_dict=False,
                            skip_layers=skip_guidance_layers,
                        )[0]
                        noise_pred = (
                            noise_pred + (noise_pred_text - noise_pred_skip_layers) * self._skip_layer_guidance_scale
                        )

                # compute the previous noisy sample x_t -> x_t-1
                latents_dtype = latents.dtype
                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

                if latents.dtype != latents_dtype:
                    if torch.backends.mps.is_available():
                        # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                        latents = latents.to(latents_dtype)

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    # NOTE: relies on `locals()` lookup — the names in
                    # `callback_on_step_end_tensor_inputs` must be local variables here.
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                    pooled_prompt_embeds = callback_outputs.pop("pooled_prompt_embeds", pooled_prompt_embeds)

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

                if XLA_AVAILABLE:
                    xm.mark_step()

        if output_type == "latent":
            image = latents

        else:
            # Undo the VAE latent normalization before decoding.
            latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor

            image = self.vae.decode(latents, return_dict=False)[0]
            image = self.image_processor.postprocess(image, output_type=output_type)

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image,)

        return StableDiffusion3PipelineOutput(images=image)
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py ADDED
@@ -0,0 +1,1154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 Stability AI and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ from typing import Any, Callable, Dict, List, Optional, Union
17
+
18
+ import PIL.Image
19
+ import torch
20
+ from transformers import (
21
+ CLIPTextModelWithProjection,
22
+ CLIPTokenizer,
23
+ SiglipImageProcessor,
24
+ SiglipVisionModel,
25
+ T5EncoderModel,
26
+ T5TokenizerFast,
27
+ )
28
+
29
+ from ...image_processor import PipelineImageInput, VaeImageProcessor
30
+ from ...loaders import FromSingleFileMixin, SD3IPAdapterMixin, SD3LoraLoaderMixin
31
+ from ...models.autoencoders import AutoencoderKL
32
+ from ...models.transformers import SD3Transformer2DModel
33
+ from ...schedulers import FlowMatchEulerDiscreteScheduler
34
+ from ...utils import (
35
+ USE_PEFT_BACKEND,
36
+ is_torch_xla_available,
37
+ logging,
38
+ replace_example_docstring,
39
+ scale_lora_layers,
40
+ unscale_lora_layers,
41
+ )
42
+ from ...utils.torch_utils import randn_tensor
43
+ from ..pipeline_utils import DiffusionPipeline
44
+ from .pipeline_output import StableDiffusion3PipelineOutput
45
+
46
+
47
# Optional TPU support: `xm.mark_step()` is called per denoising step when XLA is present.
if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
56
+
57
+ EXAMPLE_DOC_STRING = """
58
+ Examples:
59
+ ```py
60
+ >>> import torch
61
+
62
+ >>> from diffusers import AutoPipelineForImage2Image
63
+ >>> from diffusers.utils import load_image
64
+
65
+ >>> device = "cuda"
66
+ >>> model_id_or_path = "stabilityai/stable-diffusion-3-medium-diffusers"
67
+ >>> pipe = AutoPipelineForImage2Image.from_pretrained(model_id_or_path, torch_dtype=torch.float16)
68
+ >>> pipe = pipe.to(device)
69
+
70
+ >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
71
+ >>> init_image = load_image(url).resize((1024, 1024))
72
+
73
+ >>> prompt = "cat wizard, gandalf, lord of the rings, detailed, fantasy, cute, adorable, Pixar, Disney, 8k"
74
+
75
+ >>> images = pipe(prompt=prompt, image=init_image, strength=0.95, guidance_scale=7.5).images[0]
76
+ ```
77
+ """
78
+
79
+
80
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
):
    """Linearly interpolate the timestep-shift parameter `mu` from the image sequence length.

    The line is defined by the two anchor points (base_seq_len, base_shift) and
    (max_seq_len, max_shift).
    """
    slope = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    intercept = base_shift - slope * base_seq_len
    return image_seq_len * slope + intercept
92
+
93
+
94
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
    """Pull a latent tensor out of a VAE encoder output.

    Handles outputs that carry a ``latent_dist`` (sampled with ``generator`` or
    taken at the distribution mode, depending on ``sample_mode``) as well as
    outputs exposing precomputed ``latents`` directly; anything else raises.
    """
    if hasattr(encoder_output, "latent_dist"):
        if sample_mode == "sample":
            return encoder_output.latent_dist.sample(generator)
        if sample_mode == "argmax":
            return encoder_output.latent_dist.mode()
    if hasattr(encoder_output, "latents"):
        return encoder_output.latents
    raise AttributeError("Could not access latents of provided encoder_output")
106
+
107
+
108
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    r"""Configure `scheduler` and hand back its timestep schedule.

    Exactly one of `timesteps`, `sigmas` or `num_inference_steps` drives the
    schedule: custom `timesteps`/`sigmas` override the scheduler's own spacing
    strategy (and are rejected when the scheduler's `set_timesteps` does not
    accept them), otherwise the scheduler spaces `num_inference_steps` steps
    itself. Extra `kwargs` are forwarded to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used,
            `timesteps` must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is
            passed, `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler
        and the second element is the number of inference steps.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    def _accepts(param_name):
        # True when `scheduler.set_timesteps` declares the given keyword argument.
        return param_name in inspect.signature(scheduler.set_timesteps).parameters

    if timesteps is not None:
        if not _accepts("timesteps"):
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    elif sigmas is not None:
        if not _accepts("sigmas"):
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps
166
+
167
+
168
+ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin, SD3IPAdapterMixin):
169
+ r"""
170
+ Args:
171
+ transformer ([`SD3Transformer2DModel`]):
172
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
173
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
174
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
175
+ vae ([`AutoencoderKL`]):
176
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
177
+ text_encoder ([`CLIPTextModelWithProjection`]):
178
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
179
+ specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant,
180
+ with an additional added projection layer that is initialized with a diagonal matrix with the `hidden_size`
181
+ as its dimension.
182
+ text_encoder_2 ([`CLIPTextModelWithProjection`]):
183
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
184
+ specifically the
185
+ [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
186
+ variant.
187
+ text_encoder_3 ([`T5EncoderModel`]):
188
+ Frozen text-encoder. Stable Diffusion 3 uses
189
+ [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
190
+ [t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
191
+ tokenizer (`CLIPTokenizer`):
192
+ Tokenizer of class
193
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
194
+ tokenizer_2 (`CLIPTokenizer`):
195
+ Second Tokenizer of class
196
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
197
+ tokenizer_3 (`T5TokenizerFast`):
198
+ Tokenizer of class
199
+ [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
200
+ image_encoder (`SiglipVisionModel`, *optional*):
201
+ Pre-trained Vision Model for IP Adapter.
202
+ feature_extractor (`SiglipImageProcessor`, *optional*):
203
+ Image processor for IP Adapter.
204
+ """
205
+
206
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->image_encoder->transformer->vae"
207
+ _optional_components = ["image_encoder", "feature_extractor"]
208
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"]
209
+
210
    def __init__(
        self,
        transformer: SD3Transformer2DModel,
        scheduler: FlowMatchEulerDiscreteScheduler,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModelWithProjection,
        tokenizer: CLIPTokenizer,
        text_encoder_2: CLIPTextModelWithProjection,
        tokenizer_2: CLIPTokenizer,
        text_encoder_3: T5EncoderModel,
        tokenizer_3: T5TokenizerFast,
        image_encoder: Optional[SiglipVisionModel] = None,
        feature_extractor: Optional[SiglipImageProcessor] = None,
    ):
        """Register all sub-models and derive processing constants from their configs."""
        super().__init__()

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            text_encoder_2=text_encoder_2,
            text_encoder_3=text_encoder_3,
            tokenizer=tokenizer,
            tokenizer_2=tokenizer_2,
            tokenizer_3=tokenizer_3,
            transformer=transformer,
            scheduler=scheduler,
            image_encoder=image_encoder,
            feature_extractor=feature_extractor,
        )
        # Spatial downsampling factor of the VAE; the fallbacks (8 / 16 / 128 / 2 below)
        # keep the pipeline constructible even when a sub-model is absent.
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
        latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
        self.image_processor = VaeImageProcessor(
            vae_scale_factor=self.vae_scale_factor, vae_latent_channels=latent_channels
        )
        # Token cap used by both CLIP tokenizers (77 is the usual CLIP default).
        self.tokenizer_max_length = (
            self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
        )
        self.default_sample_size = (
            self.transformer.config.sample_size
            if hasattr(self, "transformer") and self.transformer is not None
            else 128
        )
        # MMDiT patch size; together with vae_scale_factor it constrains which
        # heights/widths are valid (validated in check_inputs).
        self.patch_size = (
            self.transformer.config.patch_size if hasattr(self, "transformer") and self.transformer is not None else 2
        )
255
+
256
    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline._get_t5_prompt_embeds
    def _get_t5_prompt_embeds(
        self,
        prompt: Union[str, List[str]] = None,
        num_images_per_prompt: int = 1,
        max_sequence_length: int = 256,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        """Encode `prompt` with the T5 encoder (`text_encoder_3`).

        Returns embeddings of shape `(batch_size * num_images_per_prompt, seq_len, dim)`.
        When the pipeline was loaded without a T5 encoder, an all-zero tensor sized to the
        transformer's `joint_attention_dim` is returned so downstream concatenation keeps
        its layout.
        """
        device = device or self._execution_device
        dtype = dtype or self.text_encoder.dtype

        prompt = [prompt] if isinstance(prompt, str) else prompt
        batch_size = len(prompt)

        if self.text_encoder_3 is None:
            # No T5 available: substitute zeros of the expected shape.
            return torch.zeros(
                (
                    batch_size * num_images_per_prompt,
                    self.tokenizer_max_length,
                    self.transformer.config.joint_attention_dim,
                ),
                device=device,
                dtype=dtype,
            )

        text_inputs = self.tokenizer_3(
            prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        untruncated_ids = self.tokenizer_3(prompt, padding="longest", return_tensors="pt").input_ids

        # Warn the user when the prompt was silently truncated to `max_sequence_length`.
        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
            removed_text = self.tokenizer_3.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because `max_sequence_length` is set to "
                f" {max_sequence_length} tokens: {removed_text}"
            )

        prompt_embeds = self.text_encoder_3(text_input_ids.to(device))[0]

        dtype = self.text_encoder_3.dtype
        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

        _, seq_len, _ = prompt_embeds.shape

        # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

        return prompt_embeds
312
+
313
    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline._get_clip_prompt_embeds
    def _get_clip_prompt_embeds(
        self,
        prompt: Union[str, List[str]],
        num_images_per_prompt: int = 1,
        device: Optional[torch.device] = None,
        clip_skip: Optional[int] = None,
        clip_model_index: int = 0,
    ):
        """Encode `prompt` with one of the two CLIP text encoders.

        `clip_model_index` selects pair 0 (`tokenizer`/`text_encoder`) or pair 1
        (`tokenizer_2`/`text_encoder_2`). Returns per-token embeddings of shape
        `(batch_size * num_images_per_prompt, seq_len, dim)` plus the pooled embedding
        `(batch_size * num_images_per_prompt, pooled_dim)`.
        """
        device = device or self._execution_device

        clip_tokenizers = [self.tokenizer, self.tokenizer_2]
        clip_text_encoders = [self.text_encoder, self.text_encoder_2]

        tokenizer = clip_tokenizers[clip_model_index]
        text_encoder = clip_text_encoders[clip_model_index]

        prompt = [prompt] if isinstance(prompt, str) else prompt
        batch_size = len(prompt)

        text_inputs = tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer_max_length,
            truncation=True,
            return_tensors="pt",
        )

        text_input_ids = text_inputs.input_ids
        untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
        # Warn when the prompt exceeds the CLIP token limit and was truncated.
        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
            removed_text = tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer_max_length} tokens: {removed_text}"
            )
        prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
        pooled_prompt_embeds = prompt_embeds[0]

        # `clip_skip` counts layers back from the end; the default takes the penultimate
        # hidden state.
        if clip_skip is None:
            prompt_embeds = prompt_embeds.hidden_states[-2]
        else:
            prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]

        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

        _, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

        pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
        pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)

        return prompt_embeds, pooled_prompt_embeds
368
+
369
    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.encode_prompt
    def encode_prompt(
        self,
        prompt: Union[str, List[str]],
        prompt_2: Union[str, List[str]],
        prompt_3: Union[str, List[str]],
        device: Optional[torch.device] = None,
        num_images_per_prompt: int = 1,
        do_classifier_free_guidance: bool = True,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        negative_prompt_2: Optional[Union[str, List[str]]] = None,
        negative_prompt_3: Optional[Union[str, List[str]]] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        clip_skip: Optional[int] = None,
        max_sequence_length: int = 256,
        lora_scale: Optional[float] = None,
    ):
        r"""Encode positive (and, under classifier-free guidance, negative) prompts with
        both CLIP encoders and the T5 encoder, returning joint sequence embeddings plus
        pooled CLIP embeddings.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            prompt_2 (`str` or `List[str]`, *optional*):
                The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
                used in all text-encoders
            prompt_3 (`str` or `List[str]`, *optional*):
                The prompt or prompts to be sent to the `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is
                used in all text-encoders
            device: (`torch.device`):
                torch device
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            negative_prompt_2 (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
                `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
            negative_prompt_3 (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and
                `text_encoder_3`. If not defined, `negative_prompt` is used in all the text-encoders.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
                If not provided, pooled text embeddings will be generated from `prompt` input argument.
            negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
                input argument.
            clip_skip (`int`, *optional*):
                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                the output of the pre-final layer will be used for computing the prompt embeddings.
            lora_scale (`float`, *optional*):
                A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
        """
        device = device or self._execution_device

        # set lora scale so that monkey patched LoRA
        # function of text encoder can correctly access it
        if lora_scale is not None and isinstance(self, SD3LoraLoaderMixin):
            self._lora_scale = lora_scale

            # dynamically adjust the LoRA scale
            if self.text_encoder is not None and USE_PEFT_BACKEND:
                scale_lora_layers(self.text_encoder, lora_scale)
            if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
                scale_lora_layers(self.text_encoder_2, lora_scale)

        prompt = [prompt] if isinstance(prompt, str) else prompt
        if prompt is not None:
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            # Secondary/tertiary prompts default to the primary one.
            prompt_2 = prompt_2 or prompt
            prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2

            prompt_3 = prompt_3 or prompt
            prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3

            prompt_embed, pooled_prompt_embed = self._get_clip_prompt_embeds(
                prompt=prompt,
                device=device,
                num_images_per_prompt=num_images_per_prompt,
                clip_skip=clip_skip,
                clip_model_index=0,
            )
            prompt_2_embed, pooled_prompt_2_embed = self._get_clip_prompt_embeds(
                prompt=prompt_2,
                device=device,
                num_images_per_prompt=num_images_per_prompt,
                clip_skip=clip_skip,
                clip_model_index=1,
            )
            # Join the two CLIP sequence embeddings along the feature dimension.
            clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1)

            t5_prompt_embed = self._get_t5_prompt_embeds(
                prompt=prompt_3,
                num_images_per_prompt=num_images_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
            )

            # Zero-pad the CLIP features to the T5 feature width, then append the T5
            # tokens along the sequence axis.
            clip_prompt_embeds = torch.nn.functional.pad(
                clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1])
            )

            prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2)
            pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1)

        if do_classifier_free_guidance and negative_prompt_embeds is None:
            negative_prompt = negative_prompt or ""
            negative_prompt_2 = negative_prompt_2 or negative_prompt
            negative_prompt_3 = negative_prompt_3 or negative_prompt

            # normalize str to list
            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
            negative_prompt_2 = (
                batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
            )
            negative_prompt_3 = (
                batch_size * [negative_prompt_3] if isinstance(negative_prompt_3, str) else negative_prompt_3
            )

            if prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )

            negative_prompt_embed, negative_pooled_prompt_embed = self._get_clip_prompt_embeds(
                negative_prompt,
                device=device,
                num_images_per_prompt=num_images_per_prompt,
                clip_skip=None,
                clip_model_index=0,
            )
            negative_prompt_2_embed, negative_pooled_prompt_2_embed = self._get_clip_prompt_embeds(
                negative_prompt_2,
                device=device,
                num_images_per_prompt=num_images_per_prompt,
                clip_skip=None,
                clip_model_index=1,
            )
            negative_clip_prompt_embeds = torch.cat([negative_prompt_embed, negative_prompt_2_embed], dim=-1)

            t5_negative_prompt_embed = self._get_t5_prompt_embeds(
                prompt=negative_prompt_3,
                num_images_per_prompt=num_images_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
            )

            # Same pad-and-concat layout as the positive branch.
            negative_clip_prompt_embeds = torch.nn.functional.pad(
                negative_clip_prompt_embeds,
                (0, t5_negative_prompt_embed.shape[-1] - negative_clip_prompt_embeds.shape[-1]),
            )

            negative_prompt_embeds = torch.cat([negative_clip_prompt_embeds, t5_negative_prompt_embed], dim=-2)
            negative_pooled_prompt_embeds = torch.cat(
                [negative_pooled_prompt_embed, negative_pooled_prompt_2_embed], dim=-1
            )

        if self.text_encoder is not None:
            if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND:
                # Retrieve the original scale by scaling back the LoRA layers
                unscale_lora_layers(self.text_encoder, lora_scale)

        if self.text_encoder_2 is not None:
            if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND:
                # Retrieve the original scale by scaling back the LoRA layers
                unscale_lora_layers(self.text_encoder_2, lora_scale)

        return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
562
+
563
+ def check_inputs(
564
+ self,
565
+ prompt,
566
+ prompt_2,
567
+ prompt_3,
568
+ height,
569
+ width,
570
+ strength,
571
+ negative_prompt=None,
572
+ negative_prompt_2=None,
573
+ negative_prompt_3=None,
574
+ prompt_embeds=None,
575
+ negative_prompt_embeds=None,
576
+ pooled_prompt_embeds=None,
577
+ negative_pooled_prompt_embeds=None,
578
+ callback_on_step_end_tensor_inputs=None,
579
+ max_sequence_length=None,
580
+ ):
581
+ if (
582
+ height % (self.vae_scale_factor * self.patch_size) != 0
583
+ or width % (self.vae_scale_factor * self.patch_size) != 0
584
+ ):
585
+ raise ValueError(
586
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor * self.patch_size} but are {height} and {width}."
587
+ f"You can use height {height - height % (self.vae_scale_factor * self.patch_size)} and width {width - width % (self.vae_scale_factor * self.patch_size)}."
588
+ )
589
+
590
+ if strength < 0 or strength > 1:
591
+ raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
592
+
593
+ if callback_on_step_end_tensor_inputs is not None and not all(
594
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
595
+ ):
596
+ raise ValueError(
597
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
598
+ )
599
+
600
+ if prompt is not None and prompt_embeds is not None:
601
+ raise ValueError(
602
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
603
+ " only forward one of the two."
604
+ )
605
+ elif prompt_2 is not None and prompt_embeds is not None:
606
+ raise ValueError(
607
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
608
+ " only forward one of the two."
609
+ )
610
+ elif prompt_3 is not None and prompt_embeds is not None:
611
+ raise ValueError(
612
+ f"Cannot forward both `prompt_3`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
613
+ " only forward one of the two."
614
+ )
615
+ elif prompt is None and prompt_embeds is None:
616
+ raise ValueError(
617
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
618
+ )
619
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
620
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
621
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
622
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
623
+ elif prompt_3 is not None and (not isinstance(prompt_3, str) and not isinstance(prompt_3, list)):
624
+ raise ValueError(f"`prompt_3` has to be of type `str` or `list` but is {type(prompt_3)}")
625
+
626
+ if negative_prompt is not None and negative_prompt_embeds is not None:
627
+ raise ValueError(
628
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
629
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
630
+ )
631
+ elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
632
+ raise ValueError(
633
+ f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
634
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
635
+ )
636
+ elif negative_prompt_3 is not None and negative_prompt_embeds is not None:
637
+ raise ValueError(
638
+ f"Cannot forward both `negative_prompt_3`: {negative_prompt_3} and `negative_prompt_embeds`:"
639
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
640
+ )
641
+
642
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
643
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
644
+ raise ValueError(
645
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
646
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
647
+ f" {negative_prompt_embeds.shape}."
648
+ )
649
+
650
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
651
+ raise ValueError(
652
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
653
+ )
654
+
655
+ if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
656
+ raise ValueError(
657
+ "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
658
+ )
659
+
660
+ if max_sequence_length is not None and max_sequence_length > 512:
661
+ raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
662
+
663
+ def get_timesteps(self, num_inference_steps, strength, device):
664
+ # get the original timestep using init_timestep
665
+ init_timestep = min(num_inference_steps * strength, num_inference_steps)
666
+
667
+ t_start = int(max(num_inference_steps - init_timestep, 0))
668
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
669
+ if hasattr(self.scheduler, "set_begin_index"):
670
+ self.scheduler.set_begin_index(t_start * self.scheduler.order)
671
+
672
+ return timesteps, num_inference_steps - t_start
673
+
674
    def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
        """Encode `image` into VAE latents and noise them to `timestep` for img2img.

        `image` may already be latents (its channel count matching
        `vae.config.latent_channels`), in which case VAE encoding is skipped. Encoded
        latents are shift/scale-normalized per the VAE config, optionally broadcast to
        the effective batch size, and finally noised with `scheduler.scale_noise`.
        """
        if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
            raise ValueError(
                f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
            )

        image = image.to(device=device, dtype=dtype)

        # Effective batch size after per-prompt duplication.
        batch_size = batch_size * num_images_per_prompt
        if image.shape[1] == self.vae.config.latent_channels:
            # Caller passed latents directly; skip VAE encoding.
            init_latents = image

        else:
            if isinstance(generator, list) and len(generator) != batch_size:
                raise ValueError(
                    f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                    f" size of {batch_size}. Make sure the batch size matches the length of the generators."
                )

            elif isinstance(generator, list):
                # One generator per batch element: encode each slice with its own generator.
                init_latents = [
                    retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
                    for i in range(batch_size)
                ]
                init_latents = torch.cat(init_latents, dim=0)
            else:
                init_latents = retrieve_latents(self.vae.encode(image), generator=generator)

            # Normalize with the VAE's configured shift/scale factors.
            init_latents = (init_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor

        if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
            # expand init_latents for batch_size
            additional_image_per_prompt = batch_size // init_latents.shape[0]
            init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
        elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
            raise ValueError(
                f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
            )
        else:
            init_latents = torch.cat([init_latents], dim=0)

        shape = init_latents.shape
        noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)

        # get latents
        # Flow-matching "noising": blend clean latents with noise at `timestep`.
        init_latents = self.scheduler.scale_noise(init_latents, timestep, noise)
        latents = init_latents.to(device=device, dtype=dtype)

        return latents
723
+
724
    @property
    def guidance_scale(self):
        # Classifier-free guidance weight for the current call.
        return self._guidance_scale

    @property
    def joint_attention_kwargs(self):
        # Extra kwargs forwarded to the transformer's attention processors.
        return self._joint_attention_kwargs

    @property
    def clip_skip(self):
        # Number of final CLIP layers skipped when computing prompt embeddings.
        return self._clip_skip

    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    @property
    def do_classifier_free_guidance(self):
        return self._guidance_scale > 1

    @property
    def num_timesteps(self):
        # Length of the timestep schedule for the current run.
        return self._num_timesteps

    @property
    def interrupt(self):
        # Cancellation flag; presumably polled by the denoising loop to stop early
        # (the loop is not visible in this chunk — confirm before relying on it).
        return self._interrupt
750
+
751
    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.encode_image
    def encode_image(self, image: PipelineImageInput, device: torch.device) -> torch.Tensor:
        """Encodes the given image into a feature representation using a pre-trained image encoder.

        Args:
            image (`PipelineImageInput`):
                Input image to be encoded.
            device: (`torch.device`):
                Torch device.

        Returns:
            `torch.Tensor`: The encoded image feature representation.
        """
        # Non-tensor inputs (e.g. PIL) are preprocessed by the feature extractor first.
        if not isinstance(image, torch.Tensor):
            image = self.feature_extractor(image, return_tensors="pt").pixel_values

        image = image.to(device=device, dtype=self.dtype)

        # The penultimate hidden state serves as the image feature representation.
        return self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
770
+
771
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.prepare_ip_adapter_image_embeds
772
+ def prepare_ip_adapter_image_embeds(
773
+ self,
774
+ ip_adapter_image: Optional[PipelineImageInput] = None,
775
+ ip_adapter_image_embeds: Optional[torch.Tensor] = None,
776
+ device: Optional[torch.device] = None,
777
+ num_images_per_prompt: int = 1,
778
+ do_classifier_free_guidance: bool = True,
779
+ ) -> torch.Tensor:
780
+ """Prepares image embeddings for use in the IP-Adapter.
781
+
782
+ Either `ip_adapter_image` or `ip_adapter_image_embeds` must be passed.
783
+
784
+ Args:
785
+ ip_adapter_image (`PipelineImageInput`, *optional*):
786
+ The input image to extract features from for IP-Adapter.
787
+ ip_adapter_image_embeds (`torch.Tensor`, *optional*):
788
+ Precomputed image embeddings.
789
+ device: (`torch.device`, *optional*):
790
+ Torch device.
791
+ num_images_per_prompt (`int`, defaults to 1):
792
+ Number of images that should be generated per prompt.
793
+ do_classifier_free_guidance (`bool`, defaults to True):
794
+ Whether to use classifier free guidance or not.
795
+ """
796
+ device = device or self._execution_device
797
+
798
+ if ip_adapter_image_embeds is not None:
799
+ if do_classifier_free_guidance:
800
+ single_negative_image_embeds, single_image_embeds = ip_adapter_image_embeds.chunk(2)
801
+ else:
802
+ single_image_embeds = ip_adapter_image_embeds
803
+ elif ip_adapter_image is not None:
804
+ single_image_embeds = self.encode_image(ip_adapter_image, device)
805
+ if do_classifier_free_guidance:
806
+ single_negative_image_embeds = torch.zeros_like(single_image_embeds)
807
+ else:
808
+ raise ValueError("Neither `ip_adapter_image_embeds` or `ip_adapter_image_embeds` were provided.")
809
+
810
+ image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
811
+
812
+ if do_classifier_free_guidance:
813
+ negative_image_embeds = torch.cat([single_negative_image_embeds] * num_images_per_prompt, dim=0)
814
+ image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0)
815
+
816
+ return image_embeds.to(device=device)
817
+
818
    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.enable_sequential_cpu_offload
    def enable_sequential_cpu_offload(self, *args, **kwargs):
        """Wrapper around the base `enable_sequential_cpu_offload` that first warns about `image_encoder`.

        The image encoder may fail under sequential offloading (see the warning text); all arguments are
        forwarded unchanged to the parent implementation.
        """
        if self.image_encoder is not None and "image_encoder" not in self._exclude_from_cpu_offload:
            logger.warning(
                "`pipe.enable_sequential_cpu_offload()` might fail for `image_encoder` if it uses "
                "`torch.nn.MultiheadAttention`. You can exclude `image_encoder` from CPU offloading by calling "
                "`pipe._exclude_from_cpu_offload.append('image_encoder')` before `pipe.enable_sequential_cpu_offload()`."
            )

        super().enable_sequential_cpu_offload(*args, **kwargs)
828
+
829
    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        prompt_2: Optional[Union[str, List[str]]] = None,
        prompt_3: Optional[Union[str, List[str]]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        image: PipelineImageInput = None,
        strength: float = 0.6,
        num_inference_steps: int = 50,
        sigmas: Optional[List[float]] = None,
        guidance_scale: float = 7.0,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        negative_prompt_2: Optional[Union[str, List[str]]] = None,
        negative_prompt_3: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        ip_adapter_image: Optional[PipelineImageInput] = None,
        ip_adapter_image_embeds: Optional[torch.Tensor] = None,
        return_dict: bool = True,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        clip_skip: Optional[int] = None,
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        max_sequence_length: int = 256,
        mu: Optional[float] = None,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
                instead.
            prompt_2 (`str` or `List[str]`, *optional*):
                The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
                will be used instead.
            prompt_3 (`str` or `List[str]`, *optional*):
                The prompt or prompts to be sent to `tokenizer_3` and `text_encoder_3`. If not defined, `prompt`
                will be used instead.
            height (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
                The height in pixels of the generated image. This is set to 1024 by default for the best results.
            width (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
                The width in pixels of the generated image. This is set to 1024 by default for the best results.
            image (`PipelineImageInput`):
                The image(s) to use as the starting point for generation; it is encoded into latents and partially
                noised according to `strength` before denoising.
            strength (`float`, *optional*, defaults to 0.6):
                Conceptually, how much to transform the reference `image` (validated in `check_inputs`). Larger
                values keep fewer of the initial timesteps, so the output departs further from `image`.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            sigmas (`List[float]`, *optional*):
                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
                their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
                will be used.
            guidance_scale (`float`, *optional*, defaults to 7.0):
                Guidance scale as defined in [Classifier-Free Diffusion
                Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
                of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
                `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
                the text `prompt`, usually at the expense of lower image quality.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            negative_prompt_2 (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
                `text_encoder_2`. If not defined, `negative_prompt` is used instead
            negative_prompt_3 (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and
                `text_encoder_3`. If not defined, `negative_prompt` is used instead
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
                If not provided, pooled text embeddings will be generated from `prompt` input argument.
            negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
                input argument.
            ip_adapter_image (`PipelineImageInput`, *optional*):
                Optional image input to work with IP Adapters.
            ip_adapter_image_embeds (`torch.Tensor`, *optional*):
                Pre-generated image embeddings for IP-Adapter. Should be a tensor of shape `(batch_size, num_images,
                emb_dim)`. It should contain the negative image embedding if `do_classifier_free_guidance` is set to
                `True`. If not provided, embeddings are computed from the `ip_adapter_image` input argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generate image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] instead of
                a plain tuple.
            joint_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            clip_skip (`int`, *optional*):
                Number of layers to be skipped from CLIP while computing the prompt embeddings; forwarded to
                `encode_prompt`.
            callback_on_step_end (`Callable`, *optional*):
                A function that calls at the end of each denoising steps during the inference. The function is called
                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
                `callback_on_step_end_tensor_inputs`.
            callback_on_step_end_tensor_inputs (`List`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                `._callback_tensor_inputs` attribute of your pipeline class.
            max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`.
            mu (`float`, *optional*): `mu` value used for `dynamic_shifting`.

        Examples:

        Returns:
            [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] if `return_dict` is True, otherwise a
            `tuple`. When returning a tuple, the first element is a list with the generated images.
        """
        height = height or self.default_sample_size * self.vae_scale_factor
        width = width or self.default_sample_size * self.vae_scale_factor

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            prompt_2,
            prompt_3,
            height,
            width,
            strength,
            negative_prompt=negative_prompt,
            negative_prompt_2=negative_prompt_2,
            negative_prompt_3=negative_prompt_3,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
            max_sequence_length=max_sequence_length,
        )

        # Stash per-call state so the properties (`guidance_scale`, `clip_skip`, ...) reflect this run.
        self._guidance_scale = guidance_scale
        self._clip_skip = clip_skip
        self._joint_attention_kwargs = joint_attention_kwargs
        self._interrupt = False

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        lora_scale = (
            self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
        )

        (
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        ) = self.encode_prompt(
            prompt=prompt,
            prompt_2=prompt_2,
            prompt_3=prompt_3,
            negative_prompt=negative_prompt,
            negative_prompt_2=negative_prompt_2,
            negative_prompt_3=negative_prompt_3,
            do_classifier_free_guidance=self.do_classifier_free_guidance,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            device=device,
            clip_skip=self.clip_skip,
            num_images_per_prompt=num_images_per_prompt,
            max_sequence_length=max_sequence_length,
            lora_scale=lora_scale,
        )

        # Concatenate negative + positive embeddings so one transformer pass covers both CFG branches.
        if self.do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
            pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)

        # 3. Preprocess image
        image = self.image_processor.preprocess(image, height=height, width=width)

        # 4. Prepare timesteps
        scheduler_kwargs = {}
        if self.scheduler.config.get("use_dynamic_shifting", None) and mu is None:
            image_seq_len = (int(height) // self.vae_scale_factor // self.transformer.config.patch_size) * (
                int(width) // self.vae_scale_factor // self.transformer.config.patch_size
            )
            mu = calculate_shift(
                image_seq_len,
                self.scheduler.config.get("base_image_seq_len", 256),
                self.scheduler.config.get("max_image_seq_len", 4096),
                self.scheduler.config.get("base_shift", 0.5),
                self.scheduler.config.get("max_shift", 1.16),
            )
            scheduler_kwargs["mu"] = mu
        elif mu is not None:
            scheduler_kwargs["mu"] = mu
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler, num_inference_steps, device, sigmas=sigmas, **scheduler_kwargs
        )
        # Trim the schedule according to `strength` (image-to-image starts part-way through denoising).
        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)

        # 5. Prepare latent variables
        if latents is None:
            latents = self.prepare_latents(
                image,
                latent_timestep,
                batch_size,
                num_images_per_prompt,
                prompt_embeds.dtype,
                device,
                generator,
            )

        # 6. Prepare image embeddings
        if (ip_adapter_image is not None and self.is_ip_adapter_active) or ip_adapter_image_embeds is not None:
            ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds(
                ip_adapter_image,
                ip_adapter_image_embeds,
                device,
                batch_size * num_images_per_prompt,
                self.do_classifier_free_guidance,
            )

            if self.joint_attention_kwargs is None:
                self._joint_attention_kwargs = {"ip_adapter_image_embeds": ip_adapter_image_embeds}
            else:
                self._joint_attention_kwargs.update(ip_adapter_image_embeds=ip_adapter_image_embeds)

        # 7. Denoising loop
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
        self._num_timesteps = len(timesteps)
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                timestep = t.expand(latent_model_input.shape[0])

                noise_pred = self.transformer(
                    hidden_states=latent_model_input,
                    timestep=timestep,
                    encoder_hidden_states=prompt_embeds,
                    pooled_projections=pooled_prompt_embeds,
                    joint_attention_kwargs=self.joint_attention_kwargs,
                    return_dict=False,
                )[0]

                # perform guidance
                if self.do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents_dtype = latents.dtype
                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

                if latents.dtype != latents_dtype:
                    if torch.backends.mps.is_available():
                        # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                        latents = latents.to(latents_dtype)

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    # The callback may replace any of these tensors for the next step.
                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
                    negative_pooled_prompt_embeds = callback_outputs.pop(
                        "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
                    )

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

                if XLA_AVAILABLE:
                    xm.mark_step()

        if output_type == "latent":
            image = latents

        else:
            # Undo the VAE latent normalization before decoding.
            latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor

            image = self.vae.decode(latents, return_dict=False)[0]
            image = self.image_processor.postprocess(image, output_type=output_type)

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image,)

        return StableDiffusion3PipelineOutput(images=image)
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py ADDED
@@ -0,0 +1,1379 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 Stability AI and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ from typing import Any, Callable, Dict, List, Optional, Union
17
+
18
+ import torch
19
+ from transformers import (
20
+ CLIPTextModelWithProjection,
21
+ CLIPTokenizer,
22
+ SiglipImageProcessor,
23
+ SiglipVisionModel,
24
+ T5EncoderModel,
25
+ T5TokenizerFast,
26
+ )
27
+
28
+ from ...callbacks import MultiPipelineCallbacks, PipelineCallback
29
+ from ...image_processor import PipelineImageInput, VaeImageProcessor
30
+ from ...loaders import FromSingleFileMixin, SD3IPAdapterMixin, SD3LoraLoaderMixin
31
+ from ...models.autoencoders import AutoencoderKL
32
+ from ...models.transformers import SD3Transformer2DModel
33
+ from ...schedulers import FlowMatchEulerDiscreteScheduler
34
+ from ...utils import (
35
+ USE_PEFT_BACKEND,
36
+ is_torch_xla_available,
37
+ logging,
38
+ replace_example_docstring,
39
+ scale_lora_layers,
40
+ unscale_lora_layers,
41
+ )
42
+ from ...utils.torch_utils import randn_tensor
43
+ from ..pipeline_utils import DiffusionPipeline
44
+ from .pipeline_output import StableDiffusion3PipelineOutput
45
+
46
+
47
# Import the XLA step-marking helper only when a TPU/XLA runtime is present.
if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

# Spliced into `__call__`'s docstring via the `replace_example_docstring` decorator.
EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> import torch
        >>> from diffusers import StableDiffusion3InpaintPipeline
        >>> from diffusers.utils import load_image

        >>> pipe = StableDiffusion3InpaintPipeline.from_pretrained(
        ...     "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
        ... )
        >>> pipe.to("cuda")
        >>> prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
        >>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
        >>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
        >>> source = load_image(img_url)
        >>> mask = load_image(mask_url)
        >>> image = pipe(prompt=prompt, image=source, mask_image=mask).images[0]
        >>> image.save("sd3_inpainting.png")
        ```
"""
77
+
78
+
79
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
):
    """Linearly interpolate the timestep-shift value `mu` from the image sequence length.

    The mapping is the straight line through (`base_seq_len`, `base_shift`) and
    (`max_seq_len`, `max_shift`), evaluated at `image_seq_len` without clamping.
    """
    slope = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    intercept = base_shift - slope * base_seq_len
    return image_seq_len * slope + intercept
91
+
92
+
93
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
    """Extract a latent tensor from a VAE encoder output.

    Prefers the `latent_dist` attribute (sampling with `generator`, or taking the mode when
    `sample_mode == "argmax"`), falling back to a plain `latents` attribute.
    """
    dist = getattr(encoder_output, "latent_dist", None)
    if dist is not None:
        if sample_mode == "sample":
            return dist.sample(generator)
        if sample_mode == "argmax":
            return dist.mode()
    if hasattr(encoder_output, "latents"):
        return encoder_output.latents
    raise AttributeError("Could not access latents of provided encoder_output")
105
+
106
+
107
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    r"""
    Configure `scheduler` via `set_timesteps` and return the resulting schedule.

    Exactly one of `timesteps` or `sigmas` may be used to override the scheduler's default spacing;
    otherwise `num_inference_steps` drives the schedule. Extra `kwargs` are forwarded to
    `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used,
            `timesteps` must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is
            passed, `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler
        and the second element is the number of inference steps.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    def _supports(param_name: str) -> bool:
        # Not every scheduler's `set_timesteps` accepts custom schedules.
        return param_name in set(inspect.signature(scheduler.set_timesteps).parameters.keys())

    if timesteps is not None:
        if not _supports("timesteps"):
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    elif sigmas is not None:
        if not _supports("sigmas"):
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    # Custom schedule path: the effective step count is whatever the scheduler produced.
    timesteps = scheduler.timesteps
    return timesteps, len(timesteps)
165
+
166
+
167
+ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin, SD3IPAdapterMixin):
168
+ r"""
169
+ Args:
170
+ transformer ([`SD3Transformer2DModel`]):
171
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
172
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
173
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
174
+ vae ([`AutoencoderKL`]):
175
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
176
+ text_encoder ([`CLIPTextModelWithProjection`]):
177
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
178
+ specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant,
179
+ with an additional added projection layer that is initialized with a diagonal matrix with the `hidden_size`
180
+ as its dimension.
181
+ text_encoder_2 ([`CLIPTextModelWithProjection`]):
182
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
183
+ specifically the
184
+ [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
185
+ variant.
186
+ text_encoder_3 ([`T5EncoderModel`]):
187
+ Frozen text-encoder. Stable Diffusion 3 uses
188
+ [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
189
+ [t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
190
+ tokenizer (`CLIPTokenizer`):
191
+ Tokenizer of class
192
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
193
+ tokenizer_2 (`CLIPTokenizer`):
194
+ Second Tokenizer of class
195
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
196
+ tokenizer_3 (`T5TokenizerFast`):
197
+ Tokenizer of class
198
+ [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
199
+ image_encoder (`SiglipVisionModel`, *optional*):
200
+ Pre-trained Vision Model for IP Adapter.
201
+ feature_extractor (`SiglipImageProcessor`, *optional*):
202
+ Image processor for IP Adapter.
203
+ """
204
+
205
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->image_encoder->transformer->vae"
206
+ _optional_components = ["image_encoder", "feature_extractor"]
207
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"]
208
+
209
    def __init__(
        self,
        transformer: SD3Transformer2DModel,
        scheduler: FlowMatchEulerDiscreteScheduler,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModelWithProjection,
        tokenizer: CLIPTokenizer,
        text_encoder_2: CLIPTextModelWithProjection,
        tokenizer_2: CLIPTokenizer,
        text_encoder_3: T5EncoderModel,
        tokenizer_3: T5TokenizerFast,
        image_encoder: Optional[SiglipVisionModel] = None,
        feature_extractor: Optional[SiglipImageProcessor] = None,
    ):
        """Register the pipeline's sub-models and derive image/mask processing configuration.

        All sub-models are registered with the base `DiffusionPipeline`; the VAE/transformer-derived
        values fall back to SD3 defaults (scale factor 8, 16 latent channels, sample size 128,
        patch size 2) when the corresponding module is absent.
        """
        super().__init__()

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            text_encoder_2=text_encoder_2,
            text_encoder_3=text_encoder_3,
            tokenizer=tokenizer,
            tokenizer_2=tokenizer_2,
            tokenizer_3=tokenizer_3,
            transformer=transformer,
            scheduler=scheduler,
            image_encoder=image_encoder,
            feature_extractor=feature_extractor,
        )
        # Spatial downscale of the VAE: 2^(num_blocks - 1); default 8 when no VAE is registered.
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
        latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
        self.image_processor = VaeImageProcessor(
            vae_scale_factor=self.vae_scale_factor, vae_latent_channels=latent_channels
        )
        # Masks are binarized grayscale and must NOT be normalized to [-1, 1] like images.
        self.mask_processor = VaeImageProcessor(
            vae_scale_factor=self.vae_scale_factor,
            vae_latent_channels=latent_channels,
            do_normalize=False,
            do_binarize=True,
            do_convert_grayscale=True,
        )
        self.tokenizer_max_length = (
            self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
        )
        self.default_sample_size = (
            self.transformer.config.sample_size
            if hasattr(self, "transformer") and self.transformer is not None
            else 128
        )
        self.patch_size = (
            self.transformer.config.patch_size if hasattr(self, "transformer") and self.transformer is not None else 2
        )
261
+
262
    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline._get_t5_prompt_embeds
    def _get_t5_prompt_embeds(
        self,
        prompt: Union[str, List[str]] = None,
        num_images_per_prompt: int = 1,
        max_sequence_length: int = 256,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        """Encode `prompt` with the optional T5 text encoder.

        Returns a tensor of shape `(batch_size * num_images_per_prompt, seq_len, dim)`.
        When `text_encoder_3` is not loaded, returns all-zeros of the expected shape so
        downstream concatenation with the CLIP embeddings still works.
        """
        device = device or self._execution_device
        dtype = dtype or self.text_encoder.dtype

        prompt = [prompt] if isinstance(prompt, str) else prompt
        batch_size = len(prompt)

        if self.text_encoder_3 is None:
            # T5 is optional; a zero tensor stands in for its embeddings.
            return torch.zeros(
                (
                    batch_size * num_images_per_prompt,
                    self.tokenizer_max_length,
                    self.transformer.config.joint_attention_dim,
                ),
                device=device,
                dtype=dtype,
            )

        text_inputs = self.tokenizer_3(
            prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        # Re-tokenize without truncation to detect (and warn about) dropped tokens.
        untruncated_ids = self.tokenizer_3(prompt, padding="longest", return_tensors="pt").input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
            removed_text = self.tokenizer_3.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because `max_sequence_length` is set to "
                f" {max_sequence_length} tokens: {removed_text}"
            )

        prompt_embeds = self.text_encoder_3(text_input_ids.to(device))[0]

        dtype = self.text_encoder_3.dtype
        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

        _, seq_len, _ = prompt_embeds.shape

        # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

        return prompt_embeds
318
+
319
    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline._get_clip_prompt_embeds
    def _get_clip_prompt_embeds(
        self,
        prompt: Union[str, List[str]],
        num_images_per_prompt: int = 1,
        device: Optional[torch.device] = None,
        clip_skip: Optional[int] = None,
        clip_model_index: int = 0,
    ):
        """Encode `prompt` with one of the two CLIP text encoders.

        `clip_model_index` selects tokenizer/encoder pair 0 (`tokenizer`/`text_encoder`)
        or 1 (`tokenizer_2`/`text_encoder_2`). Returns `(prompt_embeds, pooled_prompt_embeds)`
        with batch dimension `batch_size * num_images_per_prompt`.
        """
        device = device or self._execution_device

        clip_tokenizers = [self.tokenizer, self.tokenizer_2]
        clip_text_encoders = [self.text_encoder, self.text_encoder_2]

        tokenizer = clip_tokenizers[clip_model_index]
        text_encoder = clip_text_encoders[clip_model_index]

        prompt = [prompt] if isinstance(prompt, str) else prompt
        batch_size = len(prompt)

        text_inputs = tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer_max_length,
            truncation=True,
            return_tensors="pt",
        )

        text_input_ids = text_inputs.input_ids
        # Re-tokenize without truncation to detect (and warn about) dropped tokens.
        untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
            removed_text = tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer_max_length} tokens: {removed_text}"
            )
        prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
        pooled_prompt_embeds = prompt_embeds[0]

        if clip_skip is None:
            # Penultimate hidden state is the conventional choice for SD-family models.
            prompt_embeds = prompt_embeds.hidden_states[-2]
        else:
            # "+2" accounts for the final layer plus the penultimate-layer convention above.
            prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]

        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

        _, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

        pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
        pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)

        return prompt_embeds, pooled_prompt_embeds
374
+
375
    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.encode_prompt
    def encode_prompt(
        self,
        prompt: Union[str, List[str]],
        prompt_2: Union[str, List[str]],
        prompt_3: Union[str, List[str]],
        device: Optional[torch.device] = None,
        num_images_per_prompt: int = 1,
        do_classifier_free_guidance: bool = True,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        negative_prompt_2: Optional[Union[str, List[str]]] = None,
        negative_prompt_3: Optional[Union[str, List[str]]] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        clip_skip: Optional[int] = None,
        max_sequence_length: int = 256,
        lora_scale: Optional[float] = None,
    ):
        r"""
        Encode positive (and, with classifier-free guidance, negative) prompts with the two CLIP
        encoders and the optional T5 encoder, returning the combined sequence embeddings and the
        pooled CLIP embeddings.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            prompt_2 (`str` or `List[str]`, *optional*):
                The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
                used in all text-encoders
            prompt_3 (`str` or `List[str]`, *optional*):
                The prompt or prompts to be sent to the `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is
                used in all text-encoders
            device: (`torch.device`):
                torch device
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            negative_prompt_2 (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
                `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
            negative_prompt_3 (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and
                `text_encoder_3`. If not defined, `negative_prompt` is used in all the text-encoders.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
                If not provided, pooled text embeddings will be generated from `prompt` input argument.
            negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
                input argument.
            clip_skip (`int`, *optional*):
                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                the output of the pre-final layer will be used for computing the prompt embeddings.
            lora_scale (`float`, *optional*):
                A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
        """
        device = device or self._execution_device

        # set lora scale so that monkey patched LoRA
        # function of text encoder can correctly access it
        if lora_scale is not None and isinstance(self, SD3LoraLoaderMixin):
            self._lora_scale = lora_scale

            # dynamically adjust the LoRA scale
            if self.text_encoder is not None and USE_PEFT_BACKEND:
                scale_lora_layers(self.text_encoder, lora_scale)
            if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
                scale_lora_layers(self.text_encoder_2, lora_scale)

        prompt = [prompt] if isinstance(prompt, str) else prompt
        if prompt is not None:
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            # Fall back to `prompt` for the secondary encoders when not given explicitly.
            prompt_2 = prompt_2 or prompt
            prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2

            prompt_3 = prompt_3 or prompt
            prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3

            prompt_embed, pooled_prompt_embed = self._get_clip_prompt_embeds(
                prompt=prompt,
                device=device,
                num_images_per_prompt=num_images_per_prompt,
                clip_skip=clip_skip,
                clip_model_index=0,
            )
            prompt_2_embed, pooled_prompt_2_embed = self._get_clip_prompt_embeds(
                prompt=prompt_2,
                device=device,
                num_images_per_prompt=num_images_per_prompt,
                clip_skip=clip_skip,
                clip_model_index=1,
            )
            clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1)

            t5_prompt_embed = self._get_t5_prompt_embeds(
                prompt=prompt_3,
                num_images_per_prompt=num_images_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
            )

            # Pad the CLIP embedding channels up to the T5 width, then concatenate along
            # the token axis so the transformer sees one combined sequence.
            clip_prompt_embeds = torch.nn.functional.pad(
                clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1])
            )

            prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2)
            pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1)

        if do_classifier_free_guidance and negative_prompt_embeds is None:
            negative_prompt = negative_prompt or ""
            negative_prompt_2 = negative_prompt_2 or negative_prompt
            negative_prompt_3 = negative_prompt_3 or negative_prompt

            # normalize str to list
            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
            negative_prompt_2 = (
                batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
            )
            negative_prompt_3 = (
                batch_size * [negative_prompt_3] if isinstance(negative_prompt_3, str) else negative_prompt_3
            )

            if prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )

            negative_prompt_embed, negative_pooled_prompt_embed = self._get_clip_prompt_embeds(
                negative_prompt,
                device=device,
                num_images_per_prompt=num_images_per_prompt,
                clip_skip=None,
                clip_model_index=0,
            )
            negative_prompt_2_embed, negative_pooled_prompt_2_embed = self._get_clip_prompt_embeds(
                negative_prompt_2,
                device=device,
                num_images_per_prompt=num_images_per_prompt,
                clip_skip=None,
                clip_model_index=1,
            )
            negative_clip_prompt_embeds = torch.cat([negative_prompt_embed, negative_prompt_2_embed], dim=-1)

            t5_negative_prompt_embed = self._get_t5_prompt_embeds(
                prompt=negative_prompt_3,
                num_images_per_prompt=num_images_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
            )

            # Same pad-and-concat layout as the positive branch.
            negative_clip_prompt_embeds = torch.nn.functional.pad(
                negative_clip_prompt_embeds,
                (0, t5_negative_prompt_embed.shape[-1] - negative_clip_prompt_embeds.shape[-1]),
            )

            negative_prompt_embeds = torch.cat([negative_clip_prompt_embeds, t5_negative_prompt_embed], dim=-2)
            negative_pooled_prompt_embeds = torch.cat(
                [negative_pooled_prompt_embed, negative_pooled_prompt_2_embed], dim=-1
            )

        if self.text_encoder is not None:
            if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND:
                # Retrieve the original scale by scaling back the LoRA layers
                unscale_lora_layers(self.text_encoder, lora_scale)

        if self.text_encoder_2 is not None:
            if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND:
                # Retrieve the original scale by scaling back the LoRA layers
                unscale_lora_layers(self.text_encoder_2, lora_scale)

        return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
568
+
569
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.check_inputs
570
+ def check_inputs(
571
+ self,
572
+ prompt,
573
+ prompt_2,
574
+ prompt_3,
575
+ height,
576
+ width,
577
+ strength,
578
+ negative_prompt=None,
579
+ negative_prompt_2=None,
580
+ negative_prompt_3=None,
581
+ prompt_embeds=None,
582
+ negative_prompt_embeds=None,
583
+ pooled_prompt_embeds=None,
584
+ negative_pooled_prompt_embeds=None,
585
+ callback_on_step_end_tensor_inputs=None,
586
+ max_sequence_length=None,
587
+ ):
588
+ if (
589
+ height % (self.vae_scale_factor * self.patch_size) != 0
590
+ or width % (self.vae_scale_factor * self.patch_size) != 0
591
+ ):
592
+ raise ValueError(
593
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor * self.patch_size} but are {height} and {width}."
594
+ f"You can use height {height - height % (self.vae_scale_factor * self.patch_size)} and width {width - width % (self.vae_scale_factor * self.patch_size)}."
595
+ )
596
+
597
+ if strength < 0 or strength > 1:
598
+ raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
599
+
600
+ if callback_on_step_end_tensor_inputs is not None and not all(
601
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
602
+ ):
603
+ raise ValueError(
604
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
605
+ )
606
+
607
+ if prompt is not None and prompt_embeds is not None:
608
+ raise ValueError(
609
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
610
+ " only forward one of the two."
611
+ )
612
+ elif prompt_2 is not None and prompt_embeds is not None:
613
+ raise ValueError(
614
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
615
+ " only forward one of the two."
616
+ )
617
+ elif prompt_3 is not None and prompt_embeds is not None:
618
+ raise ValueError(
619
+ f"Cannot forward both `prompt_3`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
620
+ " only forward one of the two."
621
+ )
622
+ elif prompt is None and prompt_embeds is None:
623
+ raise ValueError(
624
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
625
+ )
626
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
627
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
628
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
629
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
630
+ elif prompt_3 is not None and (not isinstance(prompt_3, str) and not isinstance(prompt_3, list)):
631
+ raise ValueError(f"`prompt_3` has to be of type `str` or `list` but is {type(prompt_3)}")
632
+
633
+ if negative_prompt is not None and negative_prompt_embeds is not None:
634
+ raise ValueError(
635
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
636
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
637
+ )
638
+ elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
639
+ raise ValueError(
640
+ f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
641
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
642
+ )
643
+ elif negative_prompt_3 is not None and negative_prompt_embeds is not None:
644
+ raise ValueError(
645
+ f"Cannot forward both `negative_prompt_3`: {negative_prompt_3} and `negative_prompt_embeds`:"
646
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
647
+ )
648
+
649
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
650
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
651
+ raise ValueError(
652
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
653
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
654
+ f" {negative_prompt_embeds.shape}."
655
+ )
656
+
657
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
658
+ raise ValueError(
659
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
660
+ )
661
+
662
+ if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
663
+ raise ValueError(
664
+ "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
665
+ )
666
+
667
+ if max_sequence_length is not None and max_sequence_length > 512:
668
+ raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
669
+
670
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps
671
+ def get_timesteps(self, num_inference_steps, strength, device):
672
+ # get the original timestep using init_timestep
673
+ init_timestep = min(num_inference_steps * strength, num_inference_steps)
674
+
675
+ t_start = int(max(num_inference_steps - init_timestep, 0))
676
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
677
+ if hasattr(self.scheduler, "set_begin_index"):
678
+ self.scheduler.set_begin_index(t_start * self.scheduler.order)
679
+
680
+ return timesteps, num_inference_steps - t_start
681
+
682
    def prepare_latents(
        self,
        batch_size,
        num_channels_latents,
        height,
        width,
        dtype,
        device,
        generator,
        latents=None,
        image=None,
        timestep=None,
        is_strength_max=True,
        return_noise=False,
        return_image_latents=False,
    ):
        """Build the initial latent tensor for inpainting.

        With `is_strength_max=True` the latents are pure noise; otherwise they are the
        VAE-encoded `image` noised to `timestep` via `scheduler.scale_noise`. Optionally
        also returns the raw noise and/or the clean image latents as extra tuple entries.
        """
        # Latent spatial dims are the pixel dims divided by the VAE downsampling factor.
        shape = (
            batch_size,
            num_channels_latents,
            int(height) // self.vae_scale_factor,
            int(width) // self.vae_scale_factor,
        )
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if (image is None or timestep is None) and not is_strength_max:
            raise ValueError(
                "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise."
                "However, either the image or the noise timestep has not been provided."
            )

        if return_image_latents or (latents is None and not is_strength_max):
            image = image.to(device=device, dtype=dtype)

            # A 16-channel input is assumed to already be latents — TODO confirm this
            # matches the VAE's latent_channels for all supported checkpoints.
            if image.shape[1] == 16:
                image_latents = image
            else:
                image_latents = self._encode_vae_image(image=image, generator=generator)
            image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)

        if latents is None:
            noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
            # if strength is 1. then initialise the latents to noise, else initial to image + noise
            latents = noise if is_strength_max else self.scheduler.scale_noise(image_latents, timestep, noise)
        else:
            # Caller-provided latents are treated as the "noise" for the return tuple.
            noise = latents.to(device)
            latents = noise

        outputs = (latents,)

        if return_noise:
            outputs += (noise,)

        if return_image_latents:
            outputs += (image_latents,)

        return outputs
742
+
743
+ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
744
+ if isinstance(generator, list):
745
+ image_latents = [
746
+ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
747
+ for i in range(image.shape[0])
748
+ ]
749
+ image_latents = torch.cat(image_latents, dim=0)
750
+ else:
751
+ image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
752
+
753
+ image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
754
+
755
+ return image_latents
756
+
757
+ def prepare_mask_latents(
758
+ self,
759
+ mask,
760
+ masked_image,
761
+ batch_size,
762
+ num_images_per_prompt,
763
+ height,
764
+ width,
765
+ dtype,
766
+ device,
767
+ generator,
768
+ do_classifier_free_guidance,
769
+ ):
770
+ # resize the mask to latents shape as we concatenate the mask to the latents
771
+ # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
772
+ # and half precision
773
+ mask = torch.nn.functional.interpolate(
774
+ mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor)
775
+ )
776
+ mask = mask.to(device=device, dtype=dtype)
777
+
778
+ batch_size = batch_size * num_images_per_prompt
779
+
780
+ masked_image = masked_image.to(device=device, dtype=dtype)
781
+
782
+ if masked_image.shape[1] == 16:
783
+ masked_image_latents = masked_image
784
+ else:
785
+ masked_image_latents = retrieve_latents(self.vae.encode(masked_image), generator=generator)
786
+
787
+ masked_image_latents = (masked_image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
788
+
789
+ # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
790
+ if mask.shape[0] < batch_size:
791
+ if not batch_size % mask.shape[0] == 0:
792
+ raise ValueError(
793
+ "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
794
+ f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
795
+ " of masks that you pass is divisible by the total requested batch size."
796
+ )
797
+ mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
798
+ if masked_image_latents.shape[0] < batch_size:
799
+ if not batch_size % masked_image_latents.shape[0] == 0:
800
+ raise ValueError(
801
+ "The passed images and the required batch size don't match. Images are supposed to be duplicated"
802
+ f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
803
+ " Make sure the number of images that you pass is divisible by the total requested batch size."
804
+ )
805
+ masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1)
806
+
807
+ mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
808
+ masked_image_latents = (
809
+ torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
810
+ )
811
+
812
+ # aligning device to prevent device errors when concating it with the latent model input
813
+ masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
814
+ return mask, masked_image_latents
815
+
816
    @property
    def guidance_scale(self):
        """Classifier-free guidance scale set for the current generation call."""
        return self._guidance_scale
819
+
820
    @property
    def clip_skip(self):
        """Number of CLIP layers skipped when computing prompt embeddings (or None)."""
        return self._clip_skip
823
+
824
    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    @property
    def do_classifier_free_guidance(self):
        """Whether classifier-free guidance is active (guidance_scale > 1)."""
        return self._guidance_scale > 1
830
+
831
    @property
    def joint_attention_kwargs(self):
        """Extra kwargs forwarded to the transformer's joint attention (or None)."""
        return self._joint_attention_kwargs
834
+
835
    @property
    def num_timesteps(self):
        """Number of denoising timesteps used by the current generation call."""
        return self._num_timesteps
838
+
839
    @property
    def interrupt(self):
        """Flag that, when set, stops the denoising loop early."""
        return self._interrupt
842
+
843
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.encode_image
844
+ def encode_image(self, image: PipelineImageInput, device: torch.device) -> torch.Tensor:
845
+ """Encodes the given image into a feature representation using a pre-trained image encoder.
846
+
847
+ Args:
848
+ image (`PipelineImageInput`):
849
+ Input image to be encoded.
850
+ device: (`torch.device`):
851
+ Torch device.
852
+
853
+ Returns:
854
+ `torch.Tensor`: The encoded image feature representation.
855
+ """
856
+ if not isinstance(image, torch.Tensor):
857
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
858
+
859
+ image = image.to(device=device, dtype=self.dtype)
860
+
861
+ return self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
862
+
863
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.prepare_ip_adapter_image_embeds
864
+ def prepare_ip_adapter_image_embeds(
865
+ self,
866
+ ip_adapter_image: Optional[PipelineImageInput] = None,
867
+ ip_adapter_image_embeds: Optional[torch.Tensor] = None,
868
+ device: Optional[torch.device] = None,
869
+ num_images_per_prompt: int = 1,
870
+ do_classifier_free_guidance: bool = True,
871
+ ) -> torch.Tensor:
872
+ """Prepares image embeddings for use in the IP-Adapter.
873
+
874
+ Either `ip_adapter_image` or `ip_adapter_image_embeds` must be passed.
875
+
876
+ Args:
877
+ ip_adapter_image (`PipelineImageInput`, *optional*):
878
+ The input image to extract features from for IP-Adapter.
879
+ ip_adapter_image_embeds (`torch.Tensor`, *optional*):
880
+ Precomputed image embeddings.
881
+ device: (`torch.device`, *optional*):
882
+ Torch device.
883
+ num_images_per_prompt (`int`, defaults to 1):
884
+ Number of images that should be generated per prompt.
885
+ do_classifier_free_guidance (`bool`, defaults to True):
886
+ Whether to use classifier free guidance or not.
887
+ """
888
+ device = device or self._execution_device
889
+
890
+ if ip_adapter_image_embeds is not None:
891
+ if do_classifier_free_guidance:
892
+ single_negative_image_embeds, single_image_embeds = ip_adapter_image_embeds.chunk(2)
893
+ else:
894
+ single_image_embeds = ip_adapter_image_embeds
895
+ elif ip_adapter_image is not None:
896
+ single_image_embeds = self.encode_image(ip_adapter_image, device)
897
+ if do_classifier_free_guidance:
898
+ single_negative_image_embeds = torch.zeros_like(single_image_embeds)
899
+ else:
900
+ raise ValueError("Neither `ip_adapter_image_embeds` or `ip_adapter_image_embeds` were provided.")
901
+
902
+ image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
903
+
904
+ if do_classifier_free_guidance:
905
+ negative_image_embeds = torch.cat([single_negative_image_embeds] * num_images_per_prompt, dim=0)
906
+ image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0)
907
+
908
+ return image_embeds.to(device=device)
909
+
910
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.enable_sequential_cpu_offload
911
+ def enable_sequential_cpu_offload(self, *args, **kwargs):
912
+ if self.image_encoder is not None and "image_encoder" not in self._exclude_from_cpu_offload:
913
+ logger.warning(
914
+ "`pipe.enable_sequential_cpu_offload()` might fail for `image_encoder` if it uses "
915
+ "`torch.nn.MultiheadAttention`. You can exclude `image_encoder` from CPU offloading by calling "
916
+ "`pipe._exclude_from_cpu_offload.append('image_encoder')` before `pipe.enable_sequential_cpu_offload()`."
917
+ )
918
+
919
+ super().enable_sequential_cpu_offload(*args, **kwargs)
920
+
921
    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        prompt_2: Optional[Union[str, List[str]]] = None,
        prompt_3: Optional[Union[str, List[str]]] = None,
        image: PipelineImageInput = None,
        mask_image: PipelineImageInput = None,
        masked_image_latents: PipelineImageInput = None,
        height: int = None,
        width: int = None,
        padding_mask_crop: Optional[int] = None,
        strength: float = 0.6,
        num_inference_steps: int = 50,
        sigmas: Optional[List[float]] = None,
        guidance_scale: float = 7.0,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        negative_prompt_2: Optional[Union[str, List[str]]] = None,
        negative_prompt_3: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        pooled_prompt_embeds: Optional[torch.Tensor] = None,
        negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
        ip_adapter_image: Optional[PipelineImageInput] = None,
        ip_adapter_image_embeds: Optional[torch.Tensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        clip_skip: Optional[int] = None,
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        max_sequence_length: int = 256,
        mu: Optional[float] = None,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
                instead.
            prompt_2 (`str` or `List[str]`, *optional*):
                The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` will
                be used instead.
            prompt_3 (`str` or `List[str]`, *optional*):
                The prompt or prompts to be sent to `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` will
                be used instead.
            image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
                `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
                numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list
                of tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
                list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image
                latents as `image`, but if passing latents directly it is not encoded again.
            mask_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
                `Image`, numpy array or tensor representing an image batch to mask `image`. White pixels in the mask
                are repainted while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a
                single channel (luminance) before use. If it's a numpy array or pytorch tensor, it should contain one
                color channel (L) instead of 3, so the expected shape for pytorch tensor would be `(B, 1, H, W)`, `(B,
                H, W)`, `(1, H, W)`, `(H, W)`. And for numpy array would be for `(B, H, W, 1)`, `(B, H, W)`, `(H, W,
                1)`, or `(H, W)`.
            masked_image_latents (`torch.Tensor`, `List[torch.Tensor]`):
                `Tensor` representing an image batch to mask `image` generated by VAE. If not provided, the mask
                latents tensor will be generated by `mask_image`.
            height (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
                The height in pixels of the generated image. This is set to 1024 by default for the best results.
            width (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
                The width in pixels of the generated image. This is set to 1024 by default for the best results.
            padding_mask_crop (`int`, *optional*, defaults to `None`):
                The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to
                image and mask_image. If `padding_mask_crop` is not `None`, it will first find a rectangular region
                with the same aspect ratio of the image and contains all masked area, and then expand that area based
                on `padding_mask_crop`. The image and mask_image will then be cropped based on the expanded area before
                resizing to the original image size for inpainting. This is useful when the masked area is small while
                the image is large and contain information irrelevant for inpainting, such as background.
            strength (`float`, *optional*, defaults to 0.6):
                Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
                starting point and more noise is added the higher the `strength`. The number of denoising steps depends
                on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
                process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
                essentially ignores `image`.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            sigmas (`List[float]`, *optional*):
                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
                their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
                will be used.
            guidance_scale (`float`, *optional*, defaults to 7.0):
                Guidance scale as defined in [Classifier-Free Diffusion
                Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
                of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
                `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
                the text `prompt`, usually at the expense of lower image quality.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            negative_prompt_2 (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
                `text_encoder_2`. If not defined, `negative_prompt` is used instead
            negative_prompt_3 (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and
                `text_encoder_3`. If not defined, `negative_prompt` is used instead
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
                If not provided, pooled text embeddings will be generated from `prompt` input argument.
            negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
                input argument.
            ip_adapter_image (`PipelineImageInput`, *optional*):
                Optional image input to work with IP Adapters.
            ip_adapter_image_embeds (`torch.Tensor`, *optional*):
                Pre-generated image embeddings for IP-Adapter. Should be a tensor of shape `(batch_size, num_images,
                emb_dim)`. It should contain the negative image embedding if `do_classifier_free_guidance` is set to
                `True`. If not provided, embeddings are computed from the `ip_adapter_image` input argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generate image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] instead of
                a plain tuple.
            joint_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            clip_skip (`int`, *optional*):
                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                the output of the pre-final layer will be used for computing the prompt embeddings.
            callback_on_step_end (`Callable`, *optional*):
                A function that calls at the end of each denoising steps during the inference. The function is called
                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
                `callback_on_step_end_tensor_inputs`.
            callback_on_step_end_tensor_inputs (`List`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                `._callback_tensor_inputs` attribute of your pipeline class.
            max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`.
            mu (`float`, *optional*): `mu` value used for `dynamic_shifting`.

        Examples:

        Returns:
            [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] if `return_dict` is True, otherwise a
            `tuple`. When returning a tuple, the first element is a list with the generated images.
        """

        # Pipeline-callback objects carry their own declared tensor inputs; honor them over the plain-list argument.
        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

        height = height or self.transformer.config.sample_size * self.vae_scale_factor
        width = width or self.transformer.config.sample_size * self.vae_scale_factor

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            prompt_2,
            prompt_3,
            height,
            width,
            strength,
            negative_prompt=negative_prompt,
            negative_prompt_2=negative_prompt_2,
            negative_prompt_3=negative_prompt_3,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
            max_sequence_length=max_sequence_length,
        )

        self._guidance_scale = guidance_scale
        self._clip_skip = clip_skip
        self._joint_attention_kwargs = joint_attention_kwargs
        self._interrupt = False

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        (
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        ) = self.encode_prompt(
            prompt=prompt,
            prompt_2=prompt_2,
            prompt_3=prompt_3,
            negative_prompt=negative_prompt,
            negative_prompt_2=negative_prompt_2,
            negative_prompt_3=negative_prompt_3,
            do_classifier_free_guidance=self.do_classifier_free_guidance,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            device=device,
            clip_skip=self.clip_skip,
            num_images_per_prompt=num_images_per_prompt,
            max_sequence_length=max_sequence_length,
        )

        # Concatenate unconditional and conditional embeddings so both branches run in one forward pass.
        if self.do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
            pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)

        # 3. Prepare timesteps
        scheduler_kwargs = {}
        # When the scheduler uses dynamic shifting and no explicit `mu` is given, derive it from the latent
        # sequence length of the target resolution.
        if self.scheduler.config.get("use_dynamic_shifting", None) and mu is None:
            image_seq_len = (int(height) // self.vae_scale_factor // self.transformer.config.patch_size) * (
                int(width) // self.vae_scale_factor // self.transformer.config.patch_size
            )
            mu = calculate_shift(
                image_seq_len,
                self.scheduler.config.get("base_image_seq_len", 256),
                self.scheduler.config.get("max_image_seq_len", 4096),
                self.scheduler.config.get("base_shift", 0.5),
                self.scheduler.config.get("max_shift", 1.16),
            )
            scheduler_kwargs["mu"] = mu
        elif mu is not None:
            scheduler_kwargs["mu"] = mu
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler, num_inference_steps, device, sigmas=sigmas, **scheduler_kwargs
        )
        # Truncate the schedule according to `strength` (image-to-image style start point).
        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
        # check that number of inference steps is not < 1 - as this doesn't make sense
        if num_inference_steps < 1:
            raise ValueError(
                f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
                f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
            )
        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)

        # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
        is_strength_max = strength == 1.0

        # 4. Preprocess mask and image
        if padding_mask_crop is not None:
            crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop)
            resize_mode = "fill"
        else:
            crops_coords = None
            resize_mode = "default"

        # Keep the uncropped original for the final overlay when `padding_mask_crop` is used.
        original_image = image
        init_image = self.image_processor.preprocess(
            image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode
        )
        init_image = init_image.to(dtype=torch.float32)

        # 5. Prepare latent variables
        num_channels_latents = self.vae.config.latent_channels
        num_channels_transformer = self.transformer.config.in_channels
        # 16-channel transformers take plain latents and need the encoded image for mask blending each step.
        return_image_latents = num_channels_transformer == 16

        latents_outputs = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
            image=init_image,
            timestep=latent_timestep,
            is_strength_max=is_strength_max,
            return_noise=True,
            return_image_latents=return_image_latents,
        )

        if return_image_latents:
            latents, noise, image_latents = latents_outputs
        else:
            latents, noise = latents_outputs

        # 6. Prepare mask latent variables
        mask_condition = self.mask_processor.preprocess(
            mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
        )

        if masked_image_latents is None:
            # Zero out the masked (to-be-repainted) region before encoding.
            masked_image = init_image * (mask_condition < 0.5)
        else:
            masked_image = masked_image_latents

        mask, masked_image_latents = self.prepare_mask_latents(
            mask_condition,
            masked_image,
            batch_size,
            num_images_per_prompt,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            self.do_classifier_free_guidance,
        )

        # match the inpainting pipeline and will be updated with input + mask inpainting model later
        if num_channels_transformer == 33:
            # default case for runwayml/stable-diffusion-inpainting
            num_channels_mask = mask.shape[1]
            num_channels_masked_image = masked_image_latents.shape[1]
            if (
                num_channels_latents + num_channels_mask + num_channels_masked_image
                != self.transformer.config.in_channels
            ):
                raise ValueError(
                    f"Incorrect configuration settings! The config of `pipeline.transformer`: {self.transformer.config} expects"
                    f" {self.transformer.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
                    f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
                    f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
                    " `pipeline.transformer` or your `mask_image` or `image` input."
                )
        elif num_channels_transformer != 16:
            raise ValueError(
                f"The transformer {self.transformer.__class__} should have 16 input channels or 33 input channels, not {self.transformer.config.in_channels}."
            )

        # 7. Prepare image embeddings
        if (ip_adapter_image is not None and self.is_ip_adapter_active) or ip_adapter_image_embeds is not None:
            ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds(
                ip_adapter_image,
                ip_adapter_image_embeds,
                device,
                batch_size * num_images_per_prompt,
                self.do_classifier_free_guidance,
            )

            if self.joint_attention_kwargs is None:
                self._joint_attention_kwargs = {"ip_adapter_image_embeds": ip_adapter_image_embeds}
            else:
                self._joint_attention_kwargs.update(ip_adapter_image_embeds=ip_adapter_image_embeds)

        # 8. Denoising loop
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
        self._num_timesteps = len(timesteps)
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                timestep = t.expand(latent_model_input.shape[0])

                if num_channels_transformer == 33:
                    latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)

                noise_pred = self.transformer(
                    hidden_states=latent_model_input,
                    timestep=timestep,
                    encoder_hidden_states=prompt_embeds,
                    pooled_projections=pooled_prompt_embeds,
                    joint_attention_kwargs=self.joint_attention_kwargs,
                    return_dict=False,
                )[0]

                # perform guidance
                if self.do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents_dtype = latents.dtype
                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
                # For 16-channel transformers, re-impose the unmasked region from the (re-noised) init image.
                if num_channels_transformer == 16:
                    init_latents_proper = image_latents
                    if self.do_classifier_free_guidance:
                        init_mask, _ = mask.chunk(2)
                    else:
                        init_mask = mask

                    if i < len(timesteps) - 1:
                        noise_timestep = timesteps[i + 1]
                        init_latents_proper = self.scheduler.scale_noise(
                            init_latents_proper, torch.tensor([noise_timestep]), noise
                        )

                    latents = (1 - init_mask) * init_latents_proper + init_mask * latents

                if latents.dtype != latents_dtype:
                    if torch.backends.mps.is_available():
                        # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                        latents = latents.to(latents_dtype)

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    # Allow the callback to mutate the loop state tensors.
                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
                    negative_pooled_prompt_embeds = callback_outputs.pop(
                        "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
                    )
                    mask = callback_outputs.pop("mask", mask)
                    masked_image_latents = callback_outputs.pop("masked_image_latents", masked_image_latents)

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

                if XLA_AVAILABLE:
                    xm.mark_step()

        if not output_type == "latent":
            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
                0
            ]
        else:
            image = latents

        do_denormalize = [True] * image.shape[0]

        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

        # Paste the generated crop back onto the original image when `padding_mask_crop` was used.
        if padding_mask_crop is not None:
            image = [self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in image]

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image,)

        return StableDiffusion3PipelineOutput(images=image)