zjuJish committed
Commit c6d94d9 · verified · 1 Parent(s): 725ecbc

Upload layer_diff_dataset/StableDiffusionXLInpaintPipeline.py with huggingface_hub

layer_diff_dataset/StableDiffusionXLInpaintPipeline.py ADDED
@@ -0,0 +1,1465 @@
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import inspect
+ import os
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+ import numpy as np
+ import PIL
+ import torch
+ from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
+
+ from ...image_processor import PipelineImageInput, VaeImageProcessor
+ from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+ from ...models import AutoencoderKL, UNet2DConditionModel
+ from ...models.attention_processor import (
+     AttnProcessor2_0,
+     LoRAAttnProcessor2_0,
+     LoRAXFormersAttnProcessor,
+     XFormersAttnProcessor,
+ )
+ from ...models.lora import adjust_lora_scale_text_encoder
+ from ...schedulers import KarrasDiffusionSchedulers
+ from ...utils import (
+     deprecate,
+     is_accelerate_available,
+     is_accelerate_version,
+     is_invisible_watermark_available,
+     logging,
+     replace_example_docstring,
+ )
+ from ...utils.torch_utils import randn_tensor
+ from ..pipeline_utils import DiffusionPipeline
+ from . import StableDiffusionXLPipelineOutput
+
+
+ if is_invisible_watermark_available():
+     from .watermark import StableDiffusionXLWatermarker
+
+
+ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+ EXAMPLE_DOC_STRING = """
+     Examples:
+         ```py
+         >>> import torch
+         >>> from diffusers import StableDiffusionXLInpaintPipeline
+         >>> from diffusers.utils import load_image
+
+         >>> pipe = StableDiffusionXLInpaintPipeline.from_pretrained(
+         ...     "stabilityai/stable-diffusion-xl-base-1.0",
+         ...     torch_dtype=torch.float16,
+         ...     variant="fp16",
+         ...     use_safetensors=True,
+         ... )
+         >>> pipe.to("cuda")
+
+         >>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
+         >>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
+
+         >>> init_image = load_image(img_url).convert("RGB")
+         >>> mask_image = load_image(mask_url).convert("RGB")
+
+         >>> prompt = "A majestic tiger sitting on a bench"
+         >>> image = pipe(
+         ...     prompt=prompt, image=init_image, mask_image=mask_image, num_inference_steps=50, strength=0.80
+         ... ).images[0]
+         ```
+ """
+
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
+ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
+     """
+     Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
+     Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
+     """
+     std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
+     std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
+     # rescale the results from guidance (fixes overexposure)
+     noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
+     # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
+     noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
+     return noise_cfg
+
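+ # Rough intuition for rescale_noise_cfg above: strong classifier-free guidance inflates
+ # the std of the predicted noise, so the prediction is rescaled toward the std of the
+ # text-conditional branch. With guidance_rescale=0.7 (the value suggested in the paper
+ # cited above), the output is 0.7 * rescaled + 0.3 * raw CFG prediction.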
+
+ def mask_pil_to_torch(mask, height, width):
+     # preprocess mask
+     if isinstance(mask, (PIL.Image.Image, np.ndarray)):
+         mask = [mask]
+
+     if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image):
+         mask = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in mask]
+         mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0)
+         mask = mask.astype(np.float32) / 255.0
+     elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):
+         mask = np.concatenate([m[None, None, :] for m in mask], axis=0)
+
+     mask = torch.from_numpy(mask)
+     return mask
+
+
+ def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool = False):
+     """
+     Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be
+     converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the
+     ``image`` and ``1`` for the ``mask``.
+
+     The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be
+     binarized (``mask > 0.5``) and cast to ``torch.float32`` too.
+
+     Args:
+         image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint.
+             It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width``
+             ``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``.
+         mask (_type_): The mask to apply to the image, i.e. regions to inpaint.
+             It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width``
+             ``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``.
+
+     Raises:
+         ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask
+         should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions.
+         TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not
+             (or the other way around).
+
+     Returns:
+         tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4
+             dimensions: ``batch x channels x height x width``.
+     """
+
+     # TODO(Yiyi) - need to clean this up later
+     deprecation_message = "The prepare_mask_and_masked_image method is deprecated and will be removed in a future version. Please use VaeImageProcessor.preprocess instead"
+     deprecate(
+         "prepare_mask_and_masked_image",
+         "0.30.0",
+         deprecation_message,
+     )
+     if image is None:
+         raise ValueError("`image` input cannot be undefined.")
+
+     if mask is None:
+         raise ValueError("`mask_image` input cannot be undefined.")
+
+     if isinstance(image, torch.Tensor):
+         if not isinstance(mask, torch.Tensor):
+             mask = mask_pil_to_torch(mask, height, width)
+
+         if image.ndim == 3:
+             image = image.unsqueeze(0)
+
+         # Batch and add channel dim for single mask
+         if mask.ndim == 2:
+             mask = mask.unsqueeze(0).unsqueeze(0)
+
+         # Batch single mask or add channel dim
+         if mask.ndim == 3:
+             # Single batched mask, no channel dim or single mask not batched but channel dim
+             if mask.shape[0] == 1:
+                 mask = mask.unsqueeze(0)
+
+             # Batched masks no channel dim
+             else:
+                 mask = mask.unsqueeze(1)
+
+         assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions"
+         # assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions"
+         assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size"
+
+         # Check image is in [-1, 1]
+         # if image.min() < -1 or image.max() > 1:
+         #     raise ValueError("Image should be in [-1, 1] range")
+
+         # Check mask is in [0, 1]
+         if mask.min() < 0 or mask.max() > 1:
+             raise ValueError("Mask should be in [0, 1] range")
+
+         # Binarize mask
+         mask[mask < 0.5] = 0
+         mask[mask >= 0.5] = 1
+
+         # Image as float32
+         image = image.to(dtype=torch.float32)
+     elif isinstance(mask, torch.Tensor):
+         raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)}) is not")
+     else:
+         # preprocess image
+         if isinstance(image, (PIL.Image.Image, np.ndarray)):
+             image = [image]
+         if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
+             # resize all images w.r.t. the passed height and width
+             image = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in image]
+             image = [np.array(i.convert("RGB"))[None, :] for i in image]
+             image = np.concatenate(image, axis=0)
+         elif isinstance(image, list) and isinstance(image[0], np.ndarray):
+             image = np.concatenate([i[None, :] for i in image], axis=0)
+
+         image = image.transpose(0, 3, 1, 2)
+         image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
+
+         mask = mask_pil_to_torch(mask, height, width)
+         mask[mask < 0.5] = 0
+         mask[mask >= 0.5] = 1
+
+     if image.shape[1] == 4:
+         # images are in latent space and thus can't be masked; set masked_image to None.
+         # we assume that the checkpoint is not an inpainting checkpoint.
+         # TODO(Yiyi) - need to clean this up later
+         masked_image = None
+     else:
+         masked_image = image * (mask < 0.5)
+
+     # n.b. ensure backwards compatibility as old function does not return image
+     if return_image:
+         return mask, masked_image, image
+
+     return mask, masked_image
+
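+ # Shape sketch for prepare_mask_and_masked_image (assuming one 512x512 RGB image and a
+ # matching single-channel mask): it returns mask of shape (1, 1, 512, 512) with values
+ # in {0., 1.} and masked_image of shape (1, 3, 512, 512) in [-1, 1], where pixels with
+ # mask >= 0.5 are zeroed out.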
+
+ class StableDiffusionXLInpaintPipeline(
+     DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
+ ):
+     r"""
+     Pipeline for text-guided image inpainting using Stable Diffusion XL.
+
+     This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+     library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+     In addition the pipeline inherits the following loading methods:
+         - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
+         - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]
+
+     as well as the following saving methods:
+         - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
+
+     Args:
+         vae ([`AutoencoderKL`]):
+             Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+         text_encoder ([`CLIPTextModel`]):
+             Frozen text-encoder. Stable Diffusion XL uses the text portion of
+             [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+             the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+         text_encoder_2 ([`CLIPTextModelWithProjection`]):
+             Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of
+             [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
+             specifically the
+             [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
+             variant.
+         tokenizer (`CLIPTokenizer`):
+             Tokenizer of class
+             [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+         tokenizer_2 (`CLIPTokenizer`):
+             Second Tokenizer of class
+             [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+         unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+         scheduler ([`SchedulerMixin`]):
+             A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+             [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+         requires_aesthetics_score (`bool`, *optional*, defaults to `False`):
+             Whether the `unet` requires an `aesthetic_score` condition to be passed during inference. Also see the
+             config of `stabilityai/stable-diffusion-xl-refiner-1-0`.
+         force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `True`):
+             Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of
+             `stabilityai/stable-diffusion-xl-base-1-0`.
+         add_watermarker (`bool`, *optional*):
+             Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to
+             watermark output images. If not defined, it will default to `True` if the package is installed, otherwise
+             no watermarker will be used.
+     """
+
+     model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae"
+
+     _optional_components = ["tokenizer", "text_encoder"]
+
+     def __init__(
+         self,
+         vae: AutoencoderKL,
+         text_encoder: CLIPTextModel,
+         text_encoder_2: CLIPTextModelWithProjection,
+         tokenizer: CLIPTokenizer,
+         tokenizer_2: CLIPTokenizer,
+         unet: UNet2DConditionModel,
+         scheduler: KarrasDiffusionSchedulers,
+         requires_aesthetics_score: bool = False,
+         force_zeros_for_empty_prompt: bool = True,
+         add_watermarker: Optional[bool] = None,
+     ):
+         super().__init__()
+
+         self.register_modules(
+             vae=vae,
+             text_encoder=text_encoder,
+             text_encoder_2=text_encoder_2,
+             tokenizer=tokenizer,
+             tokenizer_2=tokenizer_2,
+             unet=unet,
+             scheduler=scheduler,
+         )
+         self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
+         self.register_to_config(requires_aesthetics_score=requires_aesthetics_score)
+         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+         self.mask_processor = VaeImageProcessor(
+             vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
+         )
+
+         add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
+
+         if add_watermarker:
+             self.watermark = StableDiffusionXLWatermarker()
+         else:
+             self.watermark = None
+
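+     # For reference: the stock SDXL VAE config has 4 entries in `block_out_channels`,
+     # so vae_scale_factor is 2 ** (4 - 1) = 8 and a 1024x1024 input maps to 128x128 latents.
+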
+     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
+     def enable_vae_slicing(self):
+         r"""
+         Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+         compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+         """
+         self.vae.enable_slicing()
+
+     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
+     def disable_vae_slicing(self):
+         r"""
+         Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+         computing decoding in one step.
+         """
+         self.vae.disable_slicing()
+
+     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling
+     def enable_vae_tiling(self):
+         r"""
+         Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+         compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to
+         allow processing larger images.
+         """
+         self.vae.enable_tiling()
+
+     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
+     def disable_vae_tiling(self):
+         r"""
+         Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+         computing decoding in one step.
+         """
+         self.vae.disable_tiling()
+
+     # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
+     def encode_prompt(
+         self,
+         prompt: str,
+         prompt_2: Optional[str] = None,
+         device: Optional[torch.device] = None,
+         num_images_per_prompt: int = 1,
+         do_classifier_free_guidance: bool = True,
+         negative_prompt: Optional[str] = None,
+         negative_prompt_2: Optional[str] = None,
+         prompt_embeds: Optional[torch.FloatTensor] = None,
+         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+         pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+         negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+         lora_scale: Optional[float] = None,
+     ):
+         r"""
+         Encodes the prompt into text encoder hidden states.
+
+         Args:
+             prompt (`str` or `List[str]`, *optional*):
+                 prompt to be encoded
+             prompt_2 (`str` or `List[str]`, *optional*):
+                 The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+                 used in both text-encoders
+             device: (`torch.device`):
+                 torch device
+             num_images_per_prompt (`int`):
+                 number of images that should be generated per prompt
+             do_classifier_free_guidance (`bool`):
+                 whether to use classifier free guidance or not
+             negative_prompt (`str` or `List[str]`, *optional*):
+                 The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                 `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+                 less than `1`).
+             negative_prompt_2 (`str` or `List[str]`, *optional*):
+                 The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+                 `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
+             prompt_embeds (`torch.FloatTensor`, *optional*):
+                 Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+                 provided, text embeddings will be generated from `prompt` input argument.
+             negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+                 Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                 weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+                 argument.
+             pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+                 Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+                 If not provided, pooled text embeddings will be generated from `prompt` input argument.
+             negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+                 Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                 weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+                 input argument.
+             lora_scale (`float`, *optional*):
+                 A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+         """
+         device = device or self._execution_device
+
+         # set lora scale so that monkey patched LoRA
+         # function of text encoder can correctly access it
+         if lora_scale is not None and isinstance(self, LoraLoaderMixin):
+             self._lora_scale = lora_scale
+
+             # dynamically adjust the LoRA scale
+             adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+             adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
+
+         if prompt is not None and isinstance(prompt, str):
+             batch_size = 1
+         elif prompt is not None and isinstance(prompt, list):
+             batch_size = len(prompt)
+         else:
+             batch_size = prompt_embeds.shape[0]
+
+         # Define tokenizers and text encoders
+         tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
+         text_encoders = (
+             [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
+         )
+
+         if prompt_embeds is None:
+             prompt_2 = prompt_2 or prompt
+             # textual inversion: process multi-vector tokens if necessary
+             prompt_embeds_list = []
+             prompts = [prompt, prompt_2]
+             for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
+                 if isinstance(self, TextualInversionLoaderMixin):
+                     prompt = self.maybe_convert_prompt(prompt, tokenizer)
+
+                 text_inputs = tokenizer(
+                     prompt,
+                     padding="max_length",
+                     max_length=tokenizer.model_max_length,
+                     truncation=True,
+                     return_tensors="pt",
+                 )
+
+                 text_input_ids = text_inputs.input_ids
+                 untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+                 if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+                     text_input_ids, untruncated_ids
+                 ):
+                     removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
+                     logger.warning(
+                         "The following part of your input was truncated because CLIP can only handle sequences up to"
+                         f" {tokenizer.model_max_length} tokens: {removed_text}"
+                     )
+
+                 prompt_embeds = text_encoder(
+                     text_input_ids.to(device),
+                     output_hidden_states=True,
+                 )
+
+                 # Only the pooled output of the final text encoder is used
+                 pooled_prompt_embeds = prompt_embeds[0]
+                 prompt_embeds = prompt_embeds.hidden_states[-2]
+
+                 prompt_embeds_list.append(prompt_embeds)
+
+             prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+
+         # get unconditional embeddings for classifier free guidance
+         zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
+         if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
+             negative_prompt_embeds = torch.zeros_like(prompt_embeds)
+             negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
+         elif do_classifier_free_guidance and negative_prompt_embeds is None:
+             negative_prompt = negative_prompt or ""
+             negative_prompt_2 = negative_prompt_2 or negative_prompt
+
+             uncond_tokens: List[str]
+             if prompt is not None and type(prompt) is not type(negative_prompt):
+                 raise TypeError(
+                     f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
+                     f" {type(prompt)}."
+                 )
+             elif isinstance(negative_prompt, str):
+                 uncond_tokens = [negative_prompt, negative_prompt_2]
+             elif batch_size != len(negative_prompt):
+                 raise ValueError(
+                     f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+                     f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+                     " the batch size of `prompt`."
+                 )
+             else:
+                 uncond_tokens = [negative_prompt, negative_prompt_2]
+
+             negative_prompt_embeds_list = []
+             for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
+                 if isinstance(self, TextualInversionLoaderMixin):
+                     negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)
+
+                 max_length = prompt_embeds.shape[1]
+                 uncond_input = tokenizer(
+                     negative_prompt,
+                     padding="max_length",
+                     max_length=max_length,
+                     truncation=True,
+                     return_tensors="pt",
+                 )
+
+                 negative_prompt_embeds = text_encoder(
+                     uncond_input.input_ids.to(device),
+                     output_hidden_states=True,
+                 )
+                 # Only the pooled output of the final text encoder is used
+                 negative_pooled_prompt_embeds = negative_prompt_embeds[0]
+                 negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
+
+                 negative_prompt_embeds_list.append(negative_prompt_embeds)
+
+             negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
+
+         prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
+         bs_embed, seq_len, _ = prompt_embeds.shape
+         # duplicate text embeddings for each generation per prompt, using mps friendly method
+         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+         prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+         if do_classifier_free_guidance:
+             # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+             seq_len = negative_prompt_embeds.shape[1]
+             negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
+             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+         pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+             bs_embed * num_images_per_prompt, -1
+         )
+         if do_classifier_free_guidance:
+             negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+                 bs_embed * num_images_per_prompt, -1
+             )
+
+         return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
+
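+     # Dimension sketch for encode_prompt (assuming the stock SDXL encoders): the two
+     # per-token hidden states are concatenated on the last dim (768 + 1280 = 2048), and
+     # the pooled embedding comes from text_encoder_2 alone (projection_dim = 1280).
+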
+     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+     def prepare_extra_step_kwargs(self, generator, eta):
+         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+         # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+         # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+         # and should be between [0, 1]
+
+         accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+         extra_step_kwargs = {}
+         if accepts_eta:
+             extra_step_kwargs["eta"] = eta
+
+         # check if the scheduler accepts generator
+         accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+         if accepts_generator:
+             extra_step_kwargs["generator"] = generator
+         return extra_step_kwargs
+
+     def check_inputs(
+         self,
+         prompt,
+         prompt_2,
+         height,
+         width,
+         strength,
+         callback_steps,
+         negative_prompt=None,
+         negative_prompt_2=None,
+         prompt_embeds=None,
+         negative_prompt_embeds=None,
+     ):
+         if strength < 0 or strength > 1:
+             raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}")
+
+         if height % 8 != 0 or width % 8 != 0:
+             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+         if (callback_steps is None) or (
+             callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+         ):
+             raise ValueError(
+                 f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+                 f" {type(callback_steps)}."
+             )
+
+         if prompt is not None and prompt_embeds is not None:
+             raise ValueError(
+                 f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+                 " only forward one of the two."
+             )
+         elif prompt_2 is not None and prompt_embeds is not None:
+             raise ValueError(
+                 f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+                 " only forward one of the two."
+             )
+         elif prompt is None and prompt_embeds is None:
+             raise ValueError(
+                 "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+             )
+         elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+             raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+         elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
+             raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
+
+         if negative_prompt is not None and negative_prompt_embeds is not None:
+             raise ValueError(
+                 f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+                 f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+             )
+         elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
+             raise ValueError(
+                 f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
+                 f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+             )
+
+         if prompt_embeds is not None and negative_prompt_embeds is not None:
+             if prompt_embeds.shape != negative_prompt_embeds.shape:
+                 raise ValueError(
+                     "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+                     f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+                     f" {negative_prompt_embeds.shape}."
+                 )
+
+     def prepare_latents(
+         self,
+         batch_size,
+         num_channels_latents,
+         height,
+         width,
+         dtype,
+         device,
+         generator,
+         latents=None,
+         image=None,
+         timestep=None,
+         is_strength_max=True,
+         add_noise=True,
+         return_noise=False,
+         return_image_latents=False,
+     ):
+         shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+         if isinstance(generator, list) and len(generator) != batch_size:
+             raise ValueError(
+                 f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                 f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+             )
+
+         if (image is None or timestep is None) and not is_strength_max:
+             raise ValueError(
+                 "Since strength < 1, initial latents are to be initialised as a combination of Image + Noise. "
+                 "However, either the image or the noise timestep has not been provided."
+             )
+
+         if image.shape[1] == 4:
+             image_latents = image.to(device=device, dtype=dtype)
+         elif return_image_latents or (latents is None and not is_strength_max):
+             image = image.to(device=device, dtype=dtype)
+             image_latents = self._encode_vae_image(image=image, generator=generator)
+
+         if latents is None and add_noise:
+             noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+             # if strength is 1. then initialise the latents to noise, else initialise to image + noise
+             latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep)
+             # if pure noise then scale the initial latents by the scheduler's init sigma
+             latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents
+         elif add_noise:
+             noise = latents.to(device)
+             latents = noise * self.scheduler.init_noise_sigma
+         else:
+             noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+             latents = image_latents.to(device)
+
+         outputs = (latents,)
+
+         if return_noise:
+             outputs += (noise,)
+
+         if return_image_latents:
+             outputs += (image_latents,)
+
+         return outputs
+
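+     # Behaviour of prepare_latents: with strength == 1.0 the latents start from pure
+     # noise scaled by init_noise_sigma; with strength < 1.0 they start from the encoded
+     # image noised to `timestep`, so part of the original image survives denoising.
+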
+     def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
+         dtype = image.dtype
+         if self.vae.config.force_upcast:
+             image = image.float()
+             self.vae.to(dtype=torch.float32)
+
+         if isinstance(generator, list):
+             image_latents = [
+                 self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i])
+                 for i in range(image.shape[0])
+             ]
+             image_latents = torch.cat(image_latents, dim=0)
+         else:
+             image_latents = self.vae.encode(image).latent_dist.sample(generator=generator)
+
+         if self.vae.config.force_upcast:
+             self.vae.to(dtype)
+
+         image_latents = image_latents.to(dtype)
+         image_latents = self.vae.config.scaling_factor * image_latents
+
+         return image_latents
+
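+     # Note: for the stock SDXL VAE, `scaling_factor` is 0.13025; `force_upcast` round-trips
+     # the encode through float32 because this VAE is numerically unstable in float16.
+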
+     def prepare_mask_latents(
+         self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
+     ):
+         # resize the mask to latents shape as we concatenate the mask to the latents
+         # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
+         # and half precision
+         mask = torch.nn.functional.interpolate(
+             mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor)
+         )
+         mask = mask.to(device=device, dtype=dtype)
+
+         # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
+         if mask.shape[0] < batch_size:
+             if not batch_size % mask.shape[0] == 0:
+                 raise ValueError(
+                     "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
+                     f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
+                     " of masks that you pass is divisible by the total requested batch size."
+                 )
+             mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
+
+         mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
+
+         if masked_image is not None and masked_image.shape[1] == 4:
+             masked_image_latents = masked_image
+         else:
+             masked_image_latents = None
+
+         if masked_image is not None:
+             if masked_image_latents is None:
+                 masked_image = masked_image.to(device=device, dtype=dtype)
+                 masked_image_latents = self._encode_vae_image(masked_image, generator=generator)
+
+             if masked_image_latents.shape[0] < batch_size:
+                 if not batch_size % masked_image_latents.shape[0] == 0:
+                     raise ValueError(
+                         "The passed images and the required batch size don't match. Images are supposed to be duplicated"
+                         f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
+                         " Make sure the number of images that you pass is divisible by the total requested batch size."
+                     )
+                 masked_image_latents = masked_image_latents.repeat(
+                     batch_size // masked_image_latents.shape[0], 1, 1, 1
+                 )
+
+             masked_image_latents = (
+                 torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
+             )
+
+             # aligning device to prevent device errors when concatenating it with the latent model input
+             masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
+
+         return mask, masked_image_latents
+
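+     # The mask is interpolated (nearest by default) down to the latent resolution,
+     # i.e. (height // 8, width // 8) for the stock VAE, so it can be concatenated
+     # channel-wise with the latents for the inpainting UNet.
+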
+     # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.get_timesteps
+     def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None):
+         # get the original timestep using init_timestep
+         if denoising_start is None:
+             init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+             t_start = max(num_inference_steps - init_timestep, 0)
+         else:
+             t_start = 0
+
+         timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+
+         # Strength is irrelevant if we directly request a timestep to start at;
+         # that is, strength is determined by the denoising_start instead.
+         if denoising_start is not None:
+             discrete_timestep_cutoff = int(
+                 round(
+                     self.scheduler.config.num_train_timesteps
+                     - (denoising_start * self.scheduler.config.num_train_timesteps)
+                 )
+             )
+             timesteps = list(filter(lambda ts: ts < discrete_timestep_cutoff, timesteps))
+             return torch.tensor(timesteps), len(timesteps)
+
+         return timesteps, num_inference_steps - t_start
+
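+     # Worked example: num_inference_steps=50 and strength=0.8 give init_timestep=40 and
+     # t_start=10, so the final 40 scheduler timesteps are run. With denoising_start=0.8,
+     # only timesteps below 0.2 * num_train_timesteps are kept instead.
+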
+     # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids
+     def _get_add_time_ids(
+         self,
+         original_size,
+         crops_coords_top_left,
+         target_size,
+         aesthetic_score,
+         negative_aesthetic_score,
+         negative_original_size,
+         negative_crops_coords_top_left,
+         negative_target_size,
+         dtype,
+     ):
+         if self.config.requires_aesthetics_score:
+             add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,))
+             add_neg_time_ids = list(
+                 negative_original_size + negative_crops_coords_top_left + (negative_aesthetic_score,)
+             )
+         else:
+             add_time_ids = list(original_size + crops_coords_top_left + target_size)
+             add_neg_time_ids = list(negative_original_size + negative_crops_coords_top_left + negative_target_size)
+
+         passed_add_embed_dim = (
+             self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim
+         )
+         expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
+
+         if (
+             expected_add_embed_dim > passed_add_embed_dim
+             and (expected_add_embed_dim - passed_add_embed_dim) == self.unet.config.addition_time_embed_dim
+         ):
+             raise ValueError(
+                 f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} are correctly used by the model."
+             )
+         elif (
+             expected_add_embed_dim < passed_add_embed_dim
+             and (passed_add_embed_dim - expected_add_embed_dim) == self.unet.config.addition_time_embed_dim
+         ):
+             raise ValueError(
+                 f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model."
+             )
+         elif expected_add_embed_dim != passed_add_embed_dim:
+             raise ValueError(
+                 f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
+             )
+
+         add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
+         add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype)
+
+         return add_time_ids, add_neg_time_ids
+
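+     # Example (assuming the base SDXL UNet with addition_time_embed_dim=256): without
+     # aesthetic scores, add_time_ids is the 6-tuple (orig_h, orig_w, crop_top, crop_left,
+     # target_h, target_w), so passed_add_embed_dim = 6 * 256 + projection_dim.
+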
+     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
+     def upcast_vae(self):
+         dtype = self.vae.dtype
+         self.vae.to(dtype=torch.float32)
+         use_torch_2_0_or_xformers = isinstance(
+             self.vae.decoder.mid_block.attentions[0].processor,
+             (
+                 AttnProcessor2_0,
+                 XFormersAttnProcessor,
+                 LoRAXFormersAttnProcessor,
+                 LoRAAttnProcessor2_0,
+             ),
+         )
+         # if xformers or torch_2_0 is used attention block does not need
+         # to be in float32 which can save lots of memory
+         if use_torch_2_0_or_xformers:
+             self.vae.post_quant_conv.to(dtype)
+             self.vae.decoder.conv_in.to(dtype)
+             self.vae.decoder.mid_block.to(dtype)
+
869
+ @torch.no_grad()
870
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
871
+ def __call__(
872
+ self,
873
+ prompt: Union[str, List[str]] = None,
874
+ prompt_2: Optional[Union[str, List[str]]] = None,
875
+ image: PipelineImageInput = None,
876
+ base_image: PipelineImageInput = None,
877
+ mask_image: PipelineImageInput = None,
878
+ masked_image_latents: torch.FloatTensor = None,
879
+ height: Optional[int] = None,
880
+ width: Optional[int] = None,
881
+ strength: float = 0.9999,
882
+ num_inference_steps: int = 50,
883
+ denoising_start: Optional[float] = None,
884
+ denoising_end: Optional[float] = None,
885
+ guidance_scale: float = 7.5,
886
+ negative_prompt: Optional[Union[str, List[str]]] = None,
887
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
888
+ num_images_per_prompt: Optional[int] = 1,
889
+ eta: float = 0.0,
890
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
891
+ latents: Optional[torch.FloatTensor] = None,
892
+ prompt_embeds: Optional[torch.FloatTensor] = None,
893
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
894
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
895
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
896
+ output_type: Optional[str] = "pil",
897
+ return_dict: bool = True,
898
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
899
+ callback_steps: int = 1,
900
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
901
+ guidance_rescale: float = 0.0,
902
+ original_size: Tuple[int, int] = None,
903
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
904
+ target_size: Tuple[int, int] = None,
905
+ negative_original_size: Optional[Tuple[int, int]] = None,
906
+ negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
907
+ negative_target_size: Optional[Tuple[int, int]] = None,
908
+ aesthetic_score: float = 6.0,
909
+ negative_aesthetic_score: float = 2.5,
910
+ ):
911
+ r"""
912
+ Function invoked when calling the pipeline for generation.
913
+
914
+ Args:
915
+ prompt (`str` or `List[str]`, *optional*):
916
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
917
+ instead.
918
+ prompt_2 (`str` or `List[str]`, *optional*):
919
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
920
+ used in both text-encoders
921
+ image (`PIL.Image.Image`):
922
+ `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will
923
+ be masked out with `mask_image` and repainted according to `prompt`.
924
+ mask_image (`PIL.Image.Image`):
925
+ `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
926
+ repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
927
+ to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
928
+ instead of 3, so the expected shape would be `(B, H, W, 1)`.
929
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
930
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
931
+ Anything below 512 pixels won't work well for
932
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
933
+ and checkpoints that are not specifically fine-tuned on low resolutions.
934
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
935
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
936
+ Anything below 512 pixels won't work well for
937
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
938
+ and checkpoints that are not specifically fine-tuned on low resolutions.
939
+ strength (`float`, *optional*, defaults to 0.9999):
940
+ Conceptually, indicates how much to transform the masked portion of the reference `image`. Must be
941
+ between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the
942
+ `strength`. The number of denoising steps depends on the amount of noise initially added. When
943
+ `strength` is 1, added noise will be maximum and the denoising process will run for the full number of
944
+ iterations specified in `num_inference_steps`. A value of 1, therefore, essentially ignores the masked
945
+ portion of the reference `image`. Note that in the case of `denoising_start` being declared as an
946
+ integer, the value of `strength` will be ignored.
947
+ num_inference_steps (`int`, *optional*, defaults to 50):
948
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
949
+ expense of slower inference.
950
+ denoising_start (`float`, *optional*):
951
+ When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be
952
+ bypassed before it is initiated. Consequently, the initial part of the denoising process is skipped and
953
+ it is assumed that the passed `image` is a partly denoised image. Note that when this is specified,
954
+ strength will be ignored. The `denoising_start` parameter is particularly beneficial when this pipeline
955
+ is integrated into a "Mixture of Denoisers" multi-pipeline setup, as detailed in [**Refining the Image
956
+ Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output).
957
+ denoising_end (`float`, *optional*):
958
+ When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
959
+ completed before it is intentionally prematurely terminated. As a result, the returned sample will
960
+ still retain a substantial amount of noise (ca. final 20% of timesteps still needed) and should be
961
+ denoised by a successor pipeline that has `denoising_start` set to 0.8 so that it only denoises the
962
+ final 20% of the scheduler. The denoising_end parameter should ideally be utilized when this pipeline
963
+ forms a part of a "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
964
+ Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output).
965
+ guidance_scale (`float`, *optional*, defaults to 7.5):
966
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
967
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
968
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
969
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
970
+ usually at the expense of lower image quality.
971
+ negative_prompt (`str` or `List[str]`, *optional*):
972
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
973
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
974
+ less than `1`).
975
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
976
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
977
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
978
+ prompt_embeds (`torch.FloatTensor`, *optional*):
979
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
980
+ provided, text embeddings will be generated from `prompt` input argument.
981
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
982
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
983
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
984
+ argument.
985
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
986
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
987
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
988
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
989
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
990
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
991
+ input argument.
992
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
993
+ The number of images to generate per prompt.
994
+ eta (`float`, *optional*, defaults to 0.0):
995
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
996
+ [`schedulers.DDIMScheduler`], will be ignored for others.
997
+ generator (`torch.Generator`, *optional*):
998
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
999
+ to make generation deterministic.
1000
+ latents (`torch.FloatTensor`, *optional*):
1001
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
1002
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
1003
+ tensor will ge generated by sampling using the supplied random `generator`.
1004
+ output_type (`str`, *optional*, defaults to `"pil"`):
1005
+ The output format of the generate image. Choose between
1006
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
1007
+ return_dict (`bool`, *optional*, defaults to `True`):
1008
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
1009
+ plain tuple.
1010
+ callback (`Callable`, *optional*):
1011
+ A function that will be called every `callback_steps` steps during inference. The function will be
1012
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
1013
+ callback_steps (`int`, *optional*, defaults to 1):
1014
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
1015
+ called at every step.
1016
+ cross_attention_kwargs (`dict`, *optional*):
1017
+                 A kwargs dictionary that, if specified, is passed along to the `AttentionProcessor` as defined under
+                 `self.processor` in
+                 [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+             original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+                 If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
+                 `original_size` defaults to `(width, height)` if not specified. Part of SDXL's micro-conditioning as
+                 explained in section 2.2 of
+                 [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+             crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
+                 `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
+                 `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
+                 `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
+                 [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+             target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+                 For most cases, `target_size` should be set to the desired height and width of the generated image. If
+                 not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in
+                 section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+             negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+                 To negatively condition the generation process based on a specific image resolution. Part of SDXL's
+                 micro-conditioning as explained in section 2.2 of
+                 [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+                 information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+             negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
+                 To negatively condition the generation process based on specific crop coordinates. Part of SDXL's
+                 micro-conditioning as explained in section 2.2 of
+                 [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+                 information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+             negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+                 To negatively condition the generation process based on a target image resolution. It should be the
+                 same as `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
+                 [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+                 information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+             aesthetic_score (`float`, *optional*, defaults to 6.0):
+                 Used to simulate an aesthetic score of the generated image by influencing the positive text condition.
+                 Part of SDXL's micro-conditioning as explained in section 2.2 of
+                 [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+             negative_aesthetic_score (`float`, *optional*, defaults to 2.5):
+                 Part of SDXL's micro-conditioning as explained in section 2.2 of
+                 [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). Can be used to
+                 simulate an aesthetic score of the generated image by influencing the negative text condition.
+ 
+         Examples:
+ 
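+             A minimal usage sketch. This is illustrative only: the checkpoint name and
+             file paths are placeholders, and `base_image` is specific to this variant
+             of the pipeline (it seeds the initial latents, while `image` feeds the
+             masked-image conditioning):
+ 
+             ```py
+             >>> import torch
+             >>> from PIL import Image
+ 
+             >>> pipe = StableDiffusionXLInpaintPipeline.from_pretrained(
+             ...     "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+             ... ).to("cuda")
+ 
+             >>> image = Image.open("foreground.png").convert("RGB")
+             >>> base_image = Image.open("background.png").convert("RGB")
+             >>> mask_image = Image.open("mask.png").convert("L")
+ 
+             >>> result = pipe(
+             ...     prompt="a cat sitting on a park bench",
+             ...     image=image,
+             ...     base_image=base_image,
+             ...     mask_image=mask_image,
+             ...     strength=0.99,
+             ...     num_inference_steps=50,
+             ... ).images[0]
+             ```
+ 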
+         Returns:
+             [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`:
+                 [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
+                 `tuple`. When returning a tuple, the first element is a list with the generated images.
+         """
+         # 0. Default height and width to unet
+         height = height or self.unet.config.sample_size * self.vae_scale_factor
+         width = width or self.unet.config.sample_size * self.vae_scale_factor
+ 
+         # 1. Check inputs
+         self.check_inputs(
+             prompt,
+             prompt_2,
+             height,
+             width,
+             strength,
+             callback_steps,
+             negative_prompt,
+             negative_prompt_2,
+             prompt_embeds,
+             negative_prompt_embeds,
+         )
+ 
+         # 2. Define call parameters
+         if prompt is not None and isinstance(prompt, str):
+             batch_size = 1
+         elif prompt is not None and isinstance(prompt, list):
+             batch_size = len(prompt)
+         else:
+             batch_size = prompt_embeds.shape[0]
+ 
+         device = self._execution_device
+         # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+         # corresponds to doing no classifier-free guidance.
+         do_classifier_free_guidance = guidance_scale > 1.0
+ 
+         # 3. Encode input prompt
+         text_encoder_lora_scale = (
+             cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+         )
+ 
+         (
+             prompt_embeds,
+             negative_prompt_embeds,
+             pooled_prompt_embeds,
+             negative_pooled_prompt_embeds,
+         ) = self.encode_prompt(
+             prompt=prompt,
+             prompt_2=prompt_2,
+             device=device,
+             num_images_per_prompt=num_images_per_prompt,
+             do_classifier_free_guidance=do_classifier_free_guidance,
+             negative_prompt=negative_prompt,
+             negative_prompt_2=negative_prompt_2,
+             prompt_embeds=prompt_embeds,
+             negative_prompt_embeds=negative_prompt_embeds,
+             pooled_prompt_embeds=pooled_prompt_embeds,
+             negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+             lora_scale=text_encoder_lora_scale,
+         )
+ 
+         # 4. set timesteps
+         def denoising_value_valid(dnv):
+             return isinstance(dnv, float) and 0 < dnv < 1
+ 
+         self.scheduler.set_timesteps(num_inference_steps, device=device)
+         timesteps, num_inference_steps = self.get_timesteps(
+             num_inference_steps, strength, device, denoising_start=denoising_start if denoising_value_valid(denoising_start) else None
+         )
+         # check that number of inference steps is not < 1 - as this doesn't make sense
+         if num_inference_steps < 1:
+             raise ValueError(
+                 f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
+                 f" steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
+             )
+         # at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
+         latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+         # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
+         is_strength_max = strength == 1.0
+ 
+         # 5. Preprocess mask and image
+         init_image = self.image_processor.preprocess(image, height=height, width=width)
+         init_image = init_image.to(dtype=torch.float32)
+         init_base_image = self.image_processor.preprocess(base_image, height=height, width=width)
+         init_base_image = init_base_image.to(dtype=torch.float32)
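+         # NOTE: unlike the stock SDXL inpaint pipeline, the initial latents are built
+         # from `base_image` (passed to `prepare_latents` below), while `image` is only
+         # used to build the masked-image conditioning.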
+ 
+         mask = self.mask_processor.preprocess(mask_image, height=height, width=width)
+ 
+         if masked_image_latents is not None:
+             masked_image = masked_image_latents
+         elif init_image.shape[1] == 4:
+             # if the images are already in latent space, we can't mask them
+             masked_image = None
+         else:
+             masked_image = init_image * (mask < 0.5)
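+         # NOTE: `masked_image` keeps only pixels where mask < 0.5; the masked-out
+         # regions (mask >= 0.5) are zeroed and will be regenerated by the UNet.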
+ 
+         # 6. Prepare latent variables
+         num_channels_latents = self.vae.config.latent_channels
+         num_channels_unet = self.unet.config.in_channels
+         return_image_latents = num_channels_unet == 4
+ 
+         add_noise = denoising_start is None
+         latents_outputs = self.prepare_latents(
+             batch_size * num_images_per_prompt,
+             num_channels_latents,
+             height,
+             width,
+             prompt_embeds.dtype,
+             device,
+             generator,
+             latents,
+             image=init_base_image,
+             timestep=latent_timestep,
+             is_strength_max=is_strength_max,
+             add_noise=add_noise,
+             return_noise=True,
+             return_image_latents=return_image_latents,
+         )
+ 
+         if return_image_latents:
+             latents, noise, image_latents = latents_outputs
+         else:
+             latents, noise = latents_outputs
+ 
+         # 7. Prepare mask latent variables
+         mask, masked_image_latents = self.prepare_mask_latents(
+             mask,
+             masked_image,
+             batch_size * num_images_per_prompt,
+             height,
+             width,
+             prompt_embeds.dtype,
+             device,
+             generator,
+             do_classifier_free_guidance,
+         )
+ 
+         # 8. Check that sizes of mask, masked image and latents match
+         if num_channels_unet == 9:
+             # default case for runwayml/stable-diffusion-inpainting
+             num_channels_mask = mask.shape[1]
+             num_channels_masked_image = masked_image_latents.shape[1]
+             if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:
+                 raise ValueError(
+                     f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
+                     f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
+                     f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
+                     f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
+                     " `pipeline.unet` or your `mask_image` or `image` input."
+                 )
+         elif num_channels_unet != 4:
+             raise ValueError(
+                 f"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}."
+             )
+         # 8.1 Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+ 
+         # 9. Recompute the effective height and width from the latent shape
+         height, width = latents.shape[-2:]
+         height = height * self.vae_scale_factor
+         width = width * self.vae_scale_factor
+ 
+         original_size = original_size or (height, width)
+         target_size = target_size or (height, width)
+ 
+         # 10. Prepare added time ids & embeddings
+         if negative_original_size is None:
+             negative_original_size = original_size
+         if negative_target_size is None:
+             negative_target_size = target_size
+ 
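+         # `_get_add_time_ids` packs SDXL's micro-conditioning (original/target sizes,
+         # crop coordinates, and the aesthetic scores) into the `add_time_ids` vector
+         # the UNet consumes alongside the pooled text embeddings.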
+         add_text_embeds = pooled_prompt_embeds
+         add_time_ids, add_neg_time_ids = self._get_add_time_ids(
+             original_size,
+             crops_coords_top_left,
+             target_size,
+             aesthetic_score,
+             negative_aesthetic_score,
+             negative_original_size,
+             negative_crops_coords_top_left,
+             negative_target_size,
+             dtype=prompt_embeds.dtype,
+         )
+         add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)
+ 
+         if do_classifier_free_guidance:
+             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+             add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
+             add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1)
+             add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0)
+ 
+         prompt_embeds = prompt_embeds.to(device)
+         add_text_embeds = add_text_embeds.to(device)
+         add_time_ids = add_time_ids.to(device)
+ 
+         # 11. Denoising loop
+         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ 
+         if (
+             denoising_end is not None
+             and denoising_start is not None
+             and denoising_value_valid(denoising_end)
+             and denoising_value_valid(denoising_start)
+             and denoising_start >= denoising_end
+         ):
+             raise ValueError(
+                 f"`denoising_start`: {denoising_start} cannot be larger than or equal to `denoising_end`:"
+                 f" {denoising_end} when using type float."
+             )
+         elif denoising_end is not None and denoising_value_valid(denoising_end):
+             discrete_timestep_cutoff = int(
+                 round(
+                     self.scheduler.config.num_train_timesteps
+                     - (denoising_end * self.scheduler.config.num_train_timesteps)
+                 )
+             )
+             num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
+             timesteps = timesteps[:num_inference_steps]
+ 
+         with self.progress_bar(total=num_inference_steps) as progress_bar:
+             for i, t in enumerate(timesteps):
+                 # expand the latents if we are doing classifier free guidance
+                 latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ 
+                 # concat latents, mask, masked_image_latents in the channel dimension
+                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+ 
+                 if num_channels_unet == 9:
+                     latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
+ 
+                 # predict the noise residual
+                 added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+                 noise_pred = self.unet(
+                     latent_model_input,
+                     t,
+                     encoder_hidden_states=prompt_embeds,
+                     cross_attention_kwargs=cross_attention_kwargs,
+                     added_cond_kwargs=added_cond_kwargs,
+                     return_dict=False,
+                 )[0]
+ 
+                 # perform guidance
+                 if do_classifier_free_guidance:
+                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+ 
+                 if do_classifier_free_guidance and guidance_rescale > 0.0:
+                     # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+                     noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
+ 
+                 # compute the previous noisy sample x_t -> x_t-1
+                 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+ 
+                 if num_channels_unet == 4:
+                     init_latents_proper = image_latents[:1]
+                     init_mask = mask[:1]
+ 
+                     if i < len(timesteps) - 1:
+                         noise_timestep = timesteps[i + 1]
+                         init_latents_proper = self.scheduler.add_noise(
+                             init_latents_proper, noise, torch.tensor([noise_timestep])
+                         )
+ 
+                     latents = (1 - init_mask) * init_latents_proper + init_mask * latents
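+                     # NOTE: the blend above pins the unmasked regions to the original
+                     # image latents (re-noised to the next timestep), so only the
+                     # masked area is actually denoised toward the prompt.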
1325
+
1326
+ # call the callback, if provided
1327
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1328
+ progress_bar.update()
1329
+ if callback is not None and i % callback_steps == 0:
1330
+ callback(i, t, latents)
1331
+
+         if output_type != "latent":
+             # make sure the VAE is in float32 mode, as it overflows in float16
+             needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
+ 
+             if needs_upcasting:
+                 self.upcast_vae()
+                 latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
+ 
+             image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+ 
+             # cast back to fp16 if needed
+             if needs_upcasting:
+                 self.vae.to(dtype=torch.float16)
+         else:
+             return StableDiffusionXLPipelineOutput(images=latents)
+ 
+         # apply watermark if available
+         if self.watermark is not None:
+             image = self.watermark.apply_watermark(image)
+ 
+         image = self.image_processor.postprocess(image, output_type=output_type)
+ 
+         # Offload all models
+         self.maybe_free_model_hooks()
+ 
+         if not return_dict:
+             return (image,)
+ 
+         return StableDiffusionXLPipelineOutput(images=image)
+ 
+     # Override to properly handle the loading and unloading of the additional text encoder.
+     # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.load_lora_weights
+     def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
+         # We could have accessed the unet config from `lora_state_dict()` too. We pass
+         # it here explicitly to be able to tell that it's coming from an SDXL
+         # pipeline.
+ 
+         # Remove any existing hooks.
+         if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
+             from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
+         else:
+             raise ImportError("Offloading requires `accelerate v0.17.0` or higher.")
+ 
+         is_model_cpu_offload = False
+         is_sequential_cpu_offload = False
+         recursive = False
+         for _, component in self.components.items():
+             if isinstance(component, torch.nn.Module):
+                 if hasattr(component, "_hf_hook"):
+                     is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
+                     is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
+                     logger.info(
+                         "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
+                     )
+                     recursive = is_sequential_cpu_offload
+                     remove_hook_from_module(component, recurse=recursive)
+ 
+         state_dict, network_alphas = self.lora_state_dict(
+             pretrained_model_name_or_path_or_dict,
+             unet_config=self.unet.config,
+             **kwargs,
+         )
+         self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet)
+ 
+         text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
+         if len(text_encoder_state_dict) > 0:
+             self.load_lora_into_text_encoder(
+                 text_encoder_state_dict,
+                 network_alphas=network_alphas,
+                 text_encoder=self.text_encoder,
+                 prefix="text_encoder",
+                 lora_scale=self.lora_scale,
+             )
+ 
+         text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
+         if len(text_encoder_2_state_dict) > 0:
+             self.load_lora_into_text_encoder(
+                 text_encoder_2_state_dict,
+                 network_alphas=network_alphas,
+                 text_encoder=self.text_encoder_2,
+                 prefix="text_encoder_2",
+                 lora_scale=self.lora_scale,
+             )
+ 
+         # Offload back.
+         if is_model_cpu_offload:
+             self.enable_model_cpu_offload()
+         elif is_sequential_cpu_offload:
+             self.enable_sequential_cpu_offload()
+ 
+     @classmethod
+     # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.save_lora_weights
+     def save_lora_weights(
+         cls,
+         save_directory: Union[str, os.PathLike],
+         unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+         text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+         text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+         is_main_process: bool = True,
+         weight_name: str = None,
+         save_function: Callable = None,
+         safe_serialization: bool = True,
+     ):
+         state_dict = {}
+ 
+         def pack_weights(layers, prefix):
+             layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
+             layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
+             return layers_state_dict
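+         # e.g. pack_weights(unet_lora_layers, "unet") prefixes every key with "unet.",
+         # letting the UNet and both text encoders share one flat LoRA state dict.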
+ 
+         if not (unet_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):
+             raise ValueError(
+                 "You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers` or `text_encoder_2_lora_layers`."
+             )
+ 
+         if unet_lora_layers:
+             state_dict.update(pack_weights(unet_lora_layers, "unet"))
+ 
+         if text_encoder_lora_layers and text_encoder_2_lora_layers:
+             state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
+             state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
+ 
+         cls.write_lora_layers(
+             state_dict=state_dict,
+             save_directory=save_directory,
+             is_main_process=is_main_process,
+             weight_name=weight_name,
+             save_function=save_function,
+             safe_serialization=safe_serialization,
+         )
+ 
+     # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._remove_text_encoder_monkey_patch
+     def _remove_text_encoder_monkey_patch(self):
+         self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder)
+         self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder_2)