jiuntian committed · verified
Commit 896aaac · 1 Parent(s): f5b9ce3

Upload folder using huggingface_hub

README.md ADDED
---
license: bsd
---

# InteractDiffusion Diffuser Implementation

[Project Page](https://jiuntian.github.io/interactdiffusion) |
[Paper](https://arxiv.org/abs/2312.05849) |
[WebUI](https://github.com/jiuntian/sd-webui-interactdiffusion) |
[Demo](https://huggingface.co/spaces/interactdiffusion/interactdiffusion) |
[Video](https://www.youtube.com/watch?v=Uunzufq8m6Y) |
[Diffuser](https://huggingface.co/interactdiffusion/diffusers-v1-2) |
[Colab](https://colab.research.google.com/drive/1Bh9PjfTylxI2rbME5mQJtFqNTGvaghJq?usp=sharing)

## How to Use

```python
from diffusers import DiffusionPipeline
import torch

pipeline = DiffusionPipeline.from_pretrained(
    "jiuntian/interactdiffusion-xl-1024",
    trust_remote_code=True,
    variant="fp16", torch_dtype=torch.float16
)
pipeline = pipeline.to("cuda")

images = pipeline(
    prompt="a person is feeding a cat",
    interactdiffusion_subject_phrases=["person"],
    interactdiffusion_object_phrases=["cat"],
    interactdiffusion_action_phrases=["feeding"],
    interactdiffusion_subject_boxes=[[0.0332, 0.1660, 0.3359, 0.7305]],
    interactdiffusion_object_boxes=[[0.2891, 0.4766, 0.6680, 0.7930]],
    interactdiffusion_scheduled_sampling_beta=1,
    output_type="pil",
    num_inference_steps=50,
).images

images[0].save('out.jpg')
```
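
Each interaction is one (subject, action, object) triplet, so the five `interactdiffusion_*` lists are index-aligned and must have equal length. Passing several entries composes multiple interactions in one image; below is a minimal sketch with hypothetical phrases and boxes (the checkpoint and argument names are the same as above):

```python
# Two interactions in one scene; entries at index i form one
# (subject, action, object) triplet. Boxes are [xmin, ymin, xmax, ymax],
# normalized to [0, 1]. The second set of box values is made up for illustration.
images = pipeline(
    prompt="a person is feeding a cat and another person is holding an umbrella",
    interactdiffusion_subject_phrases=["person", "person"],
    interactdiffusion_object_phrases=["cat", "umbrella"],
    interactdiffusion_action_phrases=["feeding", "holding"],
    interactdiffusion_subject_boxes=[[0.0332, 0.1660, 0.3359, 0.7305], [0.55, 0.20, 0.85, 0.90]],
    interactdiffusion_object_boxes=[[0.2891, 0.4766, 0.6680, 0.7930], [0.50, 0.05, 0.95, 0.40]],
    interactdiffusion_scheduled_sampling_beta=1,
    output_type="pil",
    num_inference_steps=50,
).images
images[0].save('out_two_interactions.jpg')
```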

For more information, please check the [project homepage](https://jiuntian.github.io/interactdiffusion/).

## Citation

```bibtex
@inproceedings{hoe2023interactdiffusion,
  title={InteractDiffusion: Interaction Control in Text-to-Image Diffusion Models},
  author={Jiun Tian Hoe and Xudong Jiang and Chee Seng Chan and Yap-Peng Tan and Weipeng Hu},
  year={2024},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
}
```

## Acknowledgement

This work is developed based on the codebase of [GLIGEN](https://github.com/gligen/GLIGEN) and [LDM](https://github.com/CompVis/latent-diffusion).
model_index.json ADDED
{
  "_class_name": ["pipeline_interactdiffusion_sdxl", "StableDiffusionXLInteractDiffusionPipeline"],
  "_diffusers_version": "0.30.1",
  "_name_or_path": "jiuntian/interactdiffusion-xl-1024",
  "feature_extractor": [
    null,
    null
  ],
  "force_zeros_for_empty_prompt": true,
  "image_encoder": [
    null,
    null
  ],
  "scheduler": [
    "diffusers",
    "EulerDiscreteScheduler"
  ],
  "text_encoder": [
    "transformers",
    "CLIPTextModel"
  ],
  "text_encoder_2": [
    "transformers",
    "CLIPTextModelWithProjection"
  ],
  "tokenizer": [
    "transformers",
    "CLIPTokenizer"
  ],
  "tokenizer_2": [
    "transformers",
    "CLIPTokenizer"
  ],
  "unet": [
    "interactdiffusion_unet_2d_condition",
    "InteractDiffusionUNet2DConditionModel"
  ],
  "vae": [
    "diffusers",
    "AutoencoderKL"
  ]
}
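
Each entry in `model_index.json` is a `[library, class]` pair. For the standard components the first element names an installed library (`diffusers`, `transformers`); for `_class_name` and `unet` it instead names a Python module shipped in this repo, which is why the README passes `trust_remote_code=True` so `DiffusionPipeline.from_pretrained` can import those files. A rough sketch of the resolution rule (a hypothetical helper, not the actual diffusers internals):

```python
# Hypothetical illustration of how a [module_or_library, class_name] entry
# could be resolved; with trust_remote_code=True, diffusers performs the
# real lookup and fetches the repo's .py files for you.
import importlib

def resolve(entry):
    module_name, class_name = entry
    return getattr(importlib.import_module(module_name), class_name)

# e.g. resolve(["pipeline_interactdiffusion_sdxl",
#               "StableDiffusionXLInteractDiffusionPipeline"])
# assumes the repo's files are importable from the working directory.
```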
pipeline_interactdiffusion_sdxl.py ADDED
import inspect
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import warnings

import PIL.Image
import torch

from transformers import (
    CLIPImageProcessor,
    CLIPTextModel,
    CLIPTextModelWithProjection,
    CLIPTokenizer,
    CLIPVisionModelWithProjection,
)
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
from diffusers.loaders import (
    FromSingleFileMixin,
    IPAdapterMixin,
    StableDiffusionXLLoraLoaderMixin,
    TextualInversionLoaderMixin,
)
from diffusers.models import (
    AutoencoderKL,
    ImageProjection,
    UNet2DConditionModel
)
from diffusers.models.attention import GatedSelfAttentionDense
from diffusers.models.attention_processor import (
    AttnProcessor2_0,
    FusedAttnProcessor2_0,
    XFormersAttnProcessor,
)
from diffusers.models.lora import adjust_lora_scale_text_encoder
from diffusers.pipelines.pipeline_utils import (
    DiffusionPipeline,
    StableDiffusionMixin
)
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg, retrieve_timesteps
from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
from diffusers.schedulers import (
    KarrasDiffusionSchedulers
)
from diffusers.utils import (
    USE_PEFT_BACKEND,
    deprecate,
    is_invisible_watermark_available,
    is_torch_xla_available,
    logging,
    replace_example_docstring,
    scale_lora_layers,
    unscale_lora_layers,
)
from diffusers.utils.torch_utils import randn_tensor

from interactdiffusion_xl.interactdiffusion_unet_2d_condition import InteractDiffusionUNet2DConditionModel


if is_invisible_watermark_available():
    from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker

if is_torch_xla_available():
    import torch_xla.core.xla_model as xm  # type: ignore

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> import torch
        >>> from diffusers import DiffusionPipeline

        >>> pipe = DiffusionPipeline.from_pretrained(
        ...     "jiuntian/interactdiffusion-xl-1024", trust_remote_code=True, torch_dtype=torch.float16
        ... )
        >>> pipe = pipe.to("cuda")

        >>> prompt = "a person is feeding a cat"
        >>> subject_boxes = [[0.0332, 0.1660, 0.3359, 0.7305]]
        >>> object_boxes = [[0.2891, 0.4766, 0.6680, 0.7930]]

        >>> images = pipe(
        ...     prompt=prompt,
        ...     interactdiffusion_subject_phrases=["person"],
        ...     interactdiffusion_object_phrases=["cat"],
        ...     interactdiffusion_action_phrases=["feeding"],
        ...     interactdiffusion_subject_boxes=subject_boxes,
        ...     interactdiffusion_object_boxes=object_boxes,
        ...     interactdiffusion_scheduled_sampling_beta=1,
        ...     output_type="pil",
        ...     num_inference_steps=50,
        ... ).images

        >>> images[0].save("./interactdiffusion-xl-generation.jpg")
        ```
"""

class StableDiffusionXLInteractDiffusionPipeline(
    DiffusionPipeline,
    StableDiffusionMixin,
    FromSingleFileMixin,
    StableDiffusionXLLoraLoaderMixin,
    TextualInversionLoaderMixin,
    IPAdapterMixin,
):
    r"""
    Pipeline for interaction-grounded (subject-action-object) layout text-to-image generation using Stable Diffusion
    XL with InteractDiffusion.
    """

    model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae"
    _optional_components = [
        "tokenizer",
        "tokenizer_2",
        "text_encoder",
        "text_encoder_2",
        "image_encoder",
        "feature_extractor",
    ]
    _callback_tensor_inputs = [
        "latents",
        "prompt_embeds",
        "negative_prompt_embeds",
        "add_text_embeds",
        "add_time_ids",
        "negative_pooled_prompt_embeds",
        "negative_add_time_ids",
    ]

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        text_encoder_2: CLIPTextModelWithProjection,
        tokenizer: CLIPTokenizer,
        tokenizer_2: CLIPTokenizer,
        unet: InteractDiffusionUNet2DConditionModel,
        scheduler: KarrasDiffusionSchedulers,
        image_encoder: CLIPVisionModelWithProjection = None,
        feature_extractor: CLIPImageProcessor = None,
        force_zeros_for_empty_prompt: bool = True,
        add_watermarker: Optional[bool] = None,
    ):
        super().__init__()

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            text_encoder_2=text_encoder_2,
            tokenizer=tokenizer,
            tokenizer_2=tokenizer_2,
            unet=unet,
            scheduler=scheduler,
            image_encoder=image_encoder,
            feature_extractor=feature_extractor,
        )
        self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)

        self.default_sample_size = self.unet.config.sample_size

        add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()

        if add_watermarker:
            self.watermark = StableDiffusionXLWatermarker()
        else:
            self.watermark = None

    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
    def encode_prompt(
        self,
        prompt: str,
        prompt_2: Optional[str] = None,
        device: Optional[torch.device] = None,
        num_images_per_prompt: int = 1,
        do_classifier_free_guidance: bool = True,
        negative_prompt: Optional[str] = None,
        negative_prompt_2: Optional[str] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        pooled_prompt_embeds: Optional[torch.Tensor] = None,
        negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
        lora_scale: Optional[float] = None,
        clip_skip: Optional[int] = None,
    ):
        device = device or self._execution_device

        # set lora scale so that monkey patched LoRA
        # function of text encoder can correctly access it
        if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
            self._lora_scale = lora_scale

            # dynamically adjust the LoRA scale
            if self.text_encoder is not None:
                if not USE_PEFT_BACKEND:
                    adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
                else:
                    scale_lora_layers(self.text_encoder, lora_scale)

            if self.text_encoder_2 is not None:
                if not USE_PEFT_BACKEND:
                    adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
                else:
                    scale_lora_layers(self.text_encoder_2, lora_scale)

        prompt = [prompt] if isinstance(prompt, str) else prompt

        if prompt is not None:
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        # Define tokenizers and text encoders
        tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
        text_encoders = (
            [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
        )

        if prompt_embeds is None:
            prompt_2 = prompt_2 or prompt
            prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2

            # textual inversion: process multi-vector tokens if necessary
            prompt_embeds_list = []
            prompts = [prompt, prompt_2]
            for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
                if isinstance(self, TextualInversionLoaderMixin):
                    prompt = self.maybe_convert_prompt(prompt, tokenizer)

                text_inputs = tokenizer(
                    prompt,
                    padding="max_length",
                    max_length=tokenizer.model_max_length,
                    truncation=True,
                    return_tensors="pt",
                )

                text_input_ids = text_inputs.input_ids
                untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

                if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                    text_input_ids, untruncated_ids
                ):
                    removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
                    logger.warning(
                        "The following part of your input was truncated because CLIP can only handle sequences up to"
                        f" {tokenizer.model_max_length} tokens: {removed_text}"
                    )

                prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)

                # We are only ALWAYS interested in the pooled output of the final text encoder
                pooled_prompt_embeds = prompt_embeds[0]
                if clip_skip is None:
                    prompt_embeds = prompt_embeds.hidden_states[-2]
                else:
                    # "2" because SDXL always indexes from the penultimate layer.
                    prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]

                prompt_embeds_list.append(prompt_embeds)

            prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)

        # get unconditional embeddings for classifier free guidance
        zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
        if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
            negative_prompt_embeds = torch.zeros_like(prompt_embeds)
            negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
        elif do_classifier_free_guidance and negative_prompt_embeds is None:
            negative_prompt = negative_prompt or ""
            negative_prompt_2 = negative_prompt_2 or negative_prompt

            # normalize str to list
            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
            negative_prompt_2 = (
                batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
            )

            uncond_tokens: List[str]
            if prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = [negative_prompt, negative_prompt_2]

            negative_prompt_embeds_list = []
            for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
                if isinstance(self, TextualInversionLoaderMixin):
                    negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)

                max_length = prompt_embeds.shape[1]
                uncond_input = tokenizer(
                    negative_prompt,
                    padding="max_length",
                    max_length=max_length,
                    truncation=True,
                    return_tensors="pt",
                )

                negative_prompt_embeds = text_encoder(
                    uncond_input.input_ids.to(device),
                    output_hidden_states=True,
                )
                # We are only ALWAYS interested in the pooled output of the final text encoder
                negative_pooled_prompt_embeds = negative_prompt_embeds[0]
                negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]

                negative_prompt_embeds_list.append(negative_prompt_embeds)

            negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)

        if self.text_encoder_2 is not None:
            prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
        else:
            prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)

        bs_embed, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

        if do_classifier_free_guidance:
            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = negative_prompt_embeds.shape[1]

            if self.text_encoder_2 is not None:
                negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
            else:
                negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)

            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

        pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
            bs_embed * num_images_per_prompt, -1
        )
        if do_classifier_free_guidance:
            negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
                bs_embed * num_images_per_prompt, -1
            )

        if self.text_encoder is not None:
            if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
                # Retrieve the original scale by scaling back the LoRA layers
                unscale_lora_layers(self.text_encoder, lora_scale)

        if self.text_encoder_2 is not None:
            if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
                # Retrieve the original scale by scaling back the LoRA layers
                unscale_lora_layers(self.text_encoder_2, lora_scale)

        return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds

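    # Note: `encode_prompt` returns per-token hidden states from both text
    # encoders concatenated along the feature dimension, together with the
    # pooled output of the second (projected) encoder, which SDXL uses for
    # its `add_text_embeds` micro-conditioning.
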
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
    def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
        dtype = next(self.image_encoder.parameters()).dtype

        if not isinstance(image, torch.Tensor):
            image = self.feature_extractor(image, return_tensors="pt").pixel_values

        image = image.to(device=device, dtype=dtype)
        if output_hidden_states:
            image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
            uncond_image_enc_hidden_states = self.image_encoder(
                torch.zeros_like(image), output_hidden_states=True
            ).hidden_states[-2]
            uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
                num_images_per_prompt, dim=0
            )
            return image_enc_hidden_states, uncond_image_enc_hidden_states
        else:
            image_embeds = self.image_encoder(image).image_embeds
            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
            uncond_image_embeds = torch.zeros_like(image_embeds)

            return image_embeds, uncond_image_embeds

    def encode_prompt_gligen(
        self,
        prompt: str,
        prompt_2: Optional[str] = None,
        device: Optional[torch.device] = None,
        num_images_per_prompt: int = 1,
        gligen_embeds: Optional[torch.Tensor] = None,
        lora_scale: Optional[float] = None,
        clip_skip: Optional[int] = None,
    ):
        device = device or self._execution_device

        # set lora scale so that monkey patched LoRA
        # function of text encoder can correctly access it
        if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
            self._lora_scale = lora_scale

            # dynamically adjust the LoRA scale
            if self.text_encoder is not None:
                if not USE_PEFT_BACKEND:
                    adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
                else:
                    scale_lora_layers(self.text_encoder, lora_scale)

            if self.text_encoder_2 is not None:
                if not USE_PEFT_BACKEND:
                    adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
                else:
                    scale_lora_layers(self.text_encoder_2, lora_scale)

        prompt = [prompt] if isinstance(prompt, str) else prompt

        if prompt is not None:
            batch_size = len(prompt)
        else:
            # `prompt_embeds` does not exist in this method; fall back to the provided embeddings
            batch_size = gligen_embeds.shape[0]

        # Define tokenizers and text encoders
        tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
        text_encoders = (
            [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
        )

        if gligen_embeds is None:
            prompt_2 = prompt_2 or prompt
            prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2

            # textual inversion: process multi-vector tokens if necessary
            gligen_embeds_list = []
            prompts = [prompt, prompt_2]
            for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
                if isinstance(self, TextualInversionLoaderMixin):
                    prompt = self.maybe_convert_prompt(prompt, tokenizer)

                text_inputs = tokenizer(
                    prompt,
                    padding="max_length",
                    max_length=tokenizer.model_max_length,
                    truncation=True,
                    return_tensors="pt",
                )

                text_input_ids = text_inputs.input_ids
                untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

                if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                    text_input_ids, untruncated_ids
                ):
                    removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
                    logger.warning(
                        "The following part of your input was truncated because CLIP can only handle sequences up to"
                        f" {tokenizer.model_max_length} tokens: {removed_text}"
                    )

                prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)

                if isinstance(text_encoder, CLIPTextModel):
                    gligen_embeds_list.append(prompt_embeds.pooler_output)
                elif isinstance(text_encoder, CLIPTextModelWithProjection):
                    gligen_embeds_list.append(prompt_embeds.text_embeds)

            gligen_embeds = torch.concat(gligen_embeds_list, dim=-1)

        if self.text_encoder is not None:
            if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
                # Retrieve the original scale by scaling back the LoRA layers
                unscale_lora_layers(self.text_encoder, lora_scale)

        if self.text_encoder_2 is not None:
            if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
                # Retrieve the original scale by scaling back the LoRA layers
                unscale_lora_layers(self.text_encoder_2, lora_scale)

        return gligen_embeds

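    # Unlike `encode_prompt`, which keeps per-token hidden states,
    # `encode_prompt_gligen` reduces each grounding phrase to a single pooled
    # vector: the `pooler_output` of the first text encoder concatenated with
    # the projected `text_embeds` of the second. In `__call__` these pooled
    # phrase embeddings are batched, together with the boxes and mask, into
    # the inputs of the InteractDiffusionInteractionProjection.
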
    # Copied from SDXL
    def prepare_ip_adapter_image_embeds(
        self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
    ):
        image_embeds = []
        if do_classifier_free_guidance:
            negative_image_embeds = []
        if ip_adapter_image_embeds is None:
            if not isinstance(ip_adapter_image, list):
                ip_adapter_image = [ip_adapter_image]

            if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
                raise ValueError(
                    f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
                )

            for single_ip_adapter_image, image_proj_layer in zip(
                ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
            ):
                output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
                single_image_embeds, single_negative_image_embeds = self.encode_image(
                    single_ip_adapter_image, device, 1, output_hidden_state
                )

                image_embeds.append(single_image_embeds[None, :])
                if do_classifier_free_guidance:
                    negative_image_embeds.append(single_negative_image_embeds[None, :])
        else:
            for single_image_embeds in ip_adapter_image_embeds:
                if do_classifier_free_guidance:
                    single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
                    negative_image_embeds.append(single_negative_image_embeds)
                image_embeds.append(single_image_embeds)

        ip_adapter_image_embeds = []
        for i, single_image_embeds in enumerate(image_embeds):
            single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
            if do_classifier_free_guidance:
                single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
                single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)

            single_image_embeds = single_image_embeds.to(device=device)
            ip_adapter_image_embeds.append(single_image_embeds)

        return ip_adapter_image_embeds

    # Copied from SDXL
    def prepare_extra_step_kwargs(self, generator, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]

        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

    # Copied from SDXL and StableDiffusionGLIGENPipeline
    def check_inputs(
        self,
        prompt,
        prompt_2,
        height,
        width,
        callback_steps,
        interactdiffusion_subject_phrases,
        interactdiffusion_subject_boxes,
        interactdiffusion_object_phrases,
        interactdiffusion_object_boxes,
        interactdiffusion_action_phrases,
        negative_prompt=None,
        negative_prompt_2=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
        pooled_prompt_embeds=None,
        negative_pooled_prompt_embeds=None,
        ip_adapter_image=None,
        ip_adapter_image_embeds=None,
        callback_on_step_end_tensor_inputs=None,
    ):
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

        if callback_on_step_end_tensor_inputs is not None and not all(
            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
        ):
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
            )

        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt_2 is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
        elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
            raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")

        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )
        elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if prompt_embeds.shape != negative_prompt_embeds.shape:
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}."
                )

        if prompt_embeds is not None and pooled_prompt_embeds is None:
            raise ValueError(
                "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
            )

        if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
            raise ValueError(
                "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
            )

        if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
            raise ValueError(
                "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
            )

        if ip_adapter_image_embeds is not None:
            if not isinstance(ip_adapter_image_embeds, list):
                raise ValueError(
                    f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
                )
            elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
                raise ValueError(
                    f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
                )

        # The five interaction lists are index-aligned: entry i of each forms one (subject, action, object) triplet.
        # Raise when the lengths disagree (the original check compared for equality and never raised the error).
        if not (
            len(interactdiffusion_subject_phrases)
            == len(interactdiffusion_subject_boxes)
            == len(interactdiffusion_object_phrases)
            == len(interactdiffusion_object_boxes)
            == len(interactdiffusion_action_phrases)
        ):
            raise ValueError(
                "length of `interactdiffusion_subject_phrases`, `interactdiffusion_subject_boxes`, `interactdiffusion_object_phrases`, "
                "`interactdiffusion_object_boxes`, and `interactdiffusion_action_phrases` has to be the same, but"
                f" got: `interactdiffusion_subject_phrases` {len(interactdiffusion_subject_phrases)},"
                f" `interactdiffusion_subject_boxes` {len(interactdiffusion_subject_boxes)},"
                f" `interactdiffusion_object_phrases` {len(interactdiffusion_object_phrases)},"
                f" `interactdiffusion_object_boxes` {len(interactdiffusion_object_boxes)},"
                f" `interactdiffusion_action_phrases` {len(interactdiffusion_action_phrases)}"
            )

    # Copied from SDXL
    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
        shape = (
            batch_size,
            num_channels_latents,
            int(height) // self.vae_scale_factor,
            int(width) // self.vae_scale_factor,
        )
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            latents = latents.to(device)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma
        return latents

    # Copied from SDGligenPipeline
    def enable_fuser(self, enabled=True):
        for module in self.unet.modules():
            if type(module) is GatedSelfAttentionDense:
                module.enabled = enabled

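    # `enable_fuser` toggles every GatedSelfAttentionDense (grounding) layer
    # in the UNet. As in the GLIGEN pipelines this supports scheduled
    # sampling: the grounding layers stay enabled for roughly the first
    # `interactdiffusion_scheduled_sampling_beta` fraction of the denoising
    # steps and are disabled for the remainder.
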
    # Copied from SDGligenPipeline
    def draw_inpaint_mask_from_boxes(self, boxes, size):
        inpaint_mask = torch.ones(size[0], size[1])
        for box in boxes:
            x0, x1 = box[0] * size[0], box[2] * size[0]
            y0, y1 = box[1] * size[1], box[3] * size[1]
            inpaint_mask[int(y0) : int(y1), int(x0) : int(x1)] = 0
        return inpaint_mask

    # Copied from SDGligenPipeline
    def crop(self, im, new_width, new_height):
        width, height = im.size
        left = (width - new_width) / 2
        top = (height - new_height) / 2
        right = (width + new_width) / 2
        bottom = (height + new_height) / 2
        return im.crop((left, top, right, bottom))

    # Copied from SDGligenPipeline
    def target_size_center_crop(self, im, new_hw):
        width, height = im.size
        if width != height:
            im = self.crop(im, min(height, width), min(height, width))
        return im.resize((new_hw, new_hw), PIL.Image.LANCZOS)

    # Copied from SDXL
    def _get_add_time_ids(
        self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None
    ):
        add_time_ids = list(original_size + crops_coords_top_left + target_size)

        passed_add_embed_dim = (
            self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
        )
        expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features

        if expected_add_embed_dim != passed_add_embed_dim:
            raise ValueError(
                f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
            )

        add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
        return add_time_ids

    # Copied from SDXL
    def upcast_vae(self):
        dtype = self.vae.dtype
        self.vae.to(dtype=torch.float32)
        use_torch_2_0_or_xformers = isinstance(
            self.vae.decoder.mid_block.attentions[0].processor,
            (
                AttnProcessor2_0,
                XFormersAttnProcessor,
                FusedAttnProcessor2_0,
            ),
        )
        # if xformers or torch_2_0 is used attention block does not need
        # to be in float32 which can save lots of memory
        if use_torch_2_0_or_xformers:
            self.vae.post_quant_conv.to(dtype)
            self.vae.decoder.conv_in.to(dtype)
            self.vae.decoder.mid_block.to(dtype)

    # Copied from SDXL
    def get_guidance_scale_embedding(
        self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32
    ) -> torch.Tensor:
        """
        See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298

        Args:
            w (`torch.Tensor`):
                Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings.
            embedding_dim (`int`, *optional*, defaults to 512):
                Dimension of the embeddings to generate.
            dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
                Data type of the generated embeddings.

        Returns:
            `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`.
        """
        assert len(w.shape) == 1
        w = w * 1000.0

        half_dim = embedding_dim // 2
        emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
        emb = w.to(dtype)[:, None] * emb[None, :]
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
        if embedding_dim % 2 == 1:  # zero pad
            emb = torch.nn.functional.pad(emb, (0, 1))
        assert emb.shape == (w.shape[0], embedding_dim)
        return emb

    @property
    def guidance_scale(self):
        return self._guidance_scale

    @property
    def guidance_rescale(self):
        return self._guidance_rescale

    @property
    def clip_skip(self):
        return self._clip_skip

    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    @property
    def do_classifier_free_guidance(self):
        return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None

    @property
    def cross_attention_kwargs(self):
        return self._cross_attention_kwargs

    @property
    def denoising_end(self):
        return self._denoising_end

    @property
    def num_timesteps(self):
        return self._num_timesteps

    @property
    def interrupt(self):
        return self._interrupt

818
+ @torch.no_grad()
819
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
820
+ def __call__(
821
+ self,
822
+ prompt: Union[str, List[str]] = None,
823
+ prompt_2: Optional[Union[str, List[str]]] = None,
824
+ height: Optional[int] = None,
825
+ width: Optional[int] = None,
826
+ num_inference_steps: int = 50,
827
+ timesteps: List[int] = None,
828
+ sigmas: List[float] = None,
829
+ denoising_end: Optional[float] = None,
830
+ guidance_scale: float = 5.0,
831
+ interactdiffusion_scheduled_sampling_beta: float = 1.0,
832
+ interactdiffusion_subject_phrases: List[str] = None,
833
+ interactdiffusion_subject_boxes: List[List[float]] = None,
834
+ interactdiffusion_object_phrases: List[str] = None,
835
+ interactdiffusion_object_boxes: List[List[float]] = None,
836
+ interactdiffusion_action_phrases: List[str] = None,
837
+ negative_prompt: Optional[Union[str, List[str]]] = None,
838
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
839
+ num_images_per_prompt: Optional[int] = 1,
840
+ eta: float = 0.0,
841
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
842
+ latents: Optional[torch.Tensor] = None,
843
+ prompt_embeds: Optional[torch.Tensor] = None,
844
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
845
+ pooled_prompt_embeds: Optional[torch.Tensor] = None,
846
+ negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
847
+ ip_adapter_image: Optional[PipelineImageInput] = None,
848
+ ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
849
+ output_type: Optional[str] = "pil",
850
+ return_dict: bool = True,
851
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
852
+ guidance_rescale: float = 0.0,
853
+ original_size: Optional[Tuple[int, int]] = None,
854
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
855
+ target_size: Optional[Tuple[int, int]] = None,
856
+ negative_original_size: Optional[Tuple[int, int]] = None,
857
+ negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
858
+ negative_target_size: Optional[Tuple[int, int]] = None,
859
+ clip_skip: Optional[int] = None,
860
+ callback_on_step_end: Optional[
861
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
862
+ ] = None,
863
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
864
+ **kwargs,
865
+ ):
866
+ r"""
867
+ Function invoked when calling the pipeline for generation.
868
+
869
+ Args:
870
+ prompt (`str` or `List[str]`, *optional*):
871
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
872
+ instead.
873
+ prompt_2 (`str` or `List[str]`, *optional*):
874
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
875
+ used in both text-encoders
876
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
877
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
878
+ Anything below 512 pixels won't work well for
879
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
880
+ and checkpoints that are not specifically fine-tuned on low resolutions.
881
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
882
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
883
+ Anything below 512 pixels won't work well for
884
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
885
+ and checkpoints that are not specifically fine-tuned on low resolutions.
886
+ num_inference_steps (`int`, *optional*, defaults to 50):
887
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
888
+ expense of slower inference.
889
+ timesteps (`List[int]`, *optional*):
890
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
891
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
892
+ passed will be used. Must be in descending order.
893
+ sigmas (`List[float]`, *optional*):
894
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
895
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
896
+ will be used.
897
+ denoising_end (`float`, *optional*):
898
+ When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
899
+ completed before it is intentionally prematurely terminated. As a result, the returned sample will
900
+ still retain a substantial amount of noise as determined by the discrete timesteps selected by the
901
+ scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
902
+ "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
903
+ Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
904
+ guidance_scale (`float`, *optional*, defaults to 5.0):
905
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
906
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
907
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
908
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
909
+ usually at the expense of lower image quality.
910
+ interactdiffusion_subject_phrases (`List[str]`):
911
+ The phrases to guide what to include in each of the subject defined by the corresponding
912
+ `interactdiffusion_subject_boxes`. There should only be one phrase per subject bounding box.
913
+ interactdiffusion_subject_boxes (`List[List[float]]`):
914
+ The bounding boxes that identify rectangular regions of the image that are going to be filled with the
915
+ subject described by the corresponding `interactdiffusion_subject_phrases`. Each rectangular box is
916
+ defined as a `List[float]` of 4 elements `[xmin, ymin, xmax, ymax]` where each value is between [0,1].
917
+ interactdiffusion_object_phrases (`List[str]`):
918
+ The phrases to guide what to include in each of the object defined by the corresponding
919
+ `interactdiffusion_object_boxes`. There should only be one phrase per object bounding box.
920
+ interactdiffusion_object_boxes (`List[List[float]]`):
921
+ The bounding boxes that identify rectangular regions of the image that are going to be filled with the
922
+ object described by the corresponding `interactdiffusion_object_phrases`. Each rectangular box is
923
+ defined as a `List[float]` of 4 elements `[xmin, ymin, xmax, ymax]` where each value is between [0,1].
924
+ interactdiffusion_action_phrases (`List[str]`):
925
+ The phrases to guide what to include in each of the interaction defined between subject and object bounding boxes.
926
+ There should only be one phrase per subject-object pair.
927
+ interactdiffusion_scheduled_sampling_beta (`float`, defaults to 1.0):
928
+ Scheduled Sampling factor from [GLIGEN: Open-Set Grounded Text-to-Image
929
+ Generation](https://arxiv.org/pdf/2301.07093.pdf). Scheduled Sampling factor is only varied for
930
+ scheduled sampling during inference for improved quality and controllability.
931
+ negative_prompt (`str` or `List[str]`, *optional*):
932
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
933
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
934
+ less than `1`).
935
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
936
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
937
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
938
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
939
+ The number of images to generate per prompt.
940
+ eta (`float`, *optional*, defaults to 0.0):
941
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
942
+ [`schedulers.DDIMScheduler`], will be ignored for others.
943
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
944
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
945
+ to make generation deterministic.
946
+ latents (`torch.Tensor`, *optional*):
947
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
948
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
949
+ tensor will ge generated by sampling using the supplied random `generator`.
950
+ prompt_embeds (`torch.Tensor`, *optional*):
951
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
952
+ provided, text embeddings will be generated from `prompt` input argument.
953
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
954
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
955
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
956
+ argument.
957
+ pooled_prompt_embeds (`torch.Tensor`, *optional*):
958
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
959
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
960
+ negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
961
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
962
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
963
+ input argument.
964
+ ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
965
+ ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
966
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
967
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
968
+ contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
969
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
970
+ output_type (`str`, *optional*, defaults to `"pil"`):
971
+ The output format of the generate image. Choose between
972
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
973
+ return_dict (`bool`, *optional*, defaults to `True`):
974
+ Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
975
+ of a plain tuple.
976
+ cross_attention_kwargs (`dict`, *optional*):
977
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
978
+ `self.processor` in
979
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
980
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
981
+ Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
982
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
983
+ [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
984
+ Guidance rescale factor should fix overexposure when using zero terminal SNR.
985
+ original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
986
+ If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
987
+ `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
988
+ explained in section 2.2 of
989
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
990
+ crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
991
+ `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
992
+ `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
993
+ `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
994
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
995
+ target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
996
+ For most cases, `target_size` should be set to the desired height and width of the generated image. If
997
+ not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
998
+ section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
999
+ negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
1000
+ To negatively condition the generation process based on a specific image resolution. Part of SDXL's
1001
+ micro-conditioning as explained in section 2.2 of
1002
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
1003
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
1004
+ negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
1005
+ To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's
1006
+ micro-conditioning as explained in section 2.2 of
1007
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
1008
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
1009
+ negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
1010
+ To negatively condition the generation process based on a target image resolution. It should be as same
1011
+ as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
1012
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
1013
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
1014
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
1015
+ A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
1016
+ each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
1017
+ DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
1018
+ list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
1019
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
1020
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
1021
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
1022
+ `._callback_tensor_inputs` attribute of your pipeline class.
1023
+
1024
+ Examples:
1025
+
1026
+ Returns:
1027
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
1028
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
1029
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
1030
+ """
+
+ callback = kwargs.pop("callback", None)
+ callback_steps = kwargs.pop("callback_steps", None)
+
+ if callback is not None:
+ deprecate(
+ "callback",
+ "1.0.0",
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+ )
+ if callback_steps is not None:
+ deprecate(
+ "callback_steps",
+ "1.0.0",
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+ )
+
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+ # 0. Default height and width to unet
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+
+ original_size = original_size or (height, width)
+ target_size = target_size or (height, width)
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ prompt_2,
+ height,
+ width,
+ callback_steps,
+ interactdiffusion_subject_phrases,
+ interactdiffusion_subject_boxes,
+ interactdiffusion_object_phrases,
+ interactdiffusion_object_boxes,
+ interactdiffusion_action_phrases,
+ negative_prompt,
+ negative_prompt_2,
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ ip_adapter_image,
+ ip_adapter_image_embeds,
+ callback_on_step_end_tensor_inputs,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._guidance_rescale = guidance_rescale
+ self._clip_skip = clip_skip
+ self._cross_attention_kwargs = cross_attention_kwargs
+ self._denoising_end = denoising_end
+ self._interrupt = False
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ # 3. Encode input prompt
+ lora_scale = (
+ self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
+ )
+
+ (
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ negative_prompt=negative_prompt,
+ negative_prompt_2=negative_prompt_2,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ lora_scale=lora_scale,
+ clip_skip=self.clip_skip,
+ )
+
+ # 4. Prepare timesteps
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler, num_inference_steps, device, timesteps, sigmas
+ )
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+ # 5.1 Prepare InteractDiffusion variables
+ max_objs = 30
+ if len(interactdiffusion_action_phrases) > max_objs:
+ warnings.warn(
+ f"More than {max_objs} objects found. Only the first {max_objs} objects will be processed.",
+ FutureWarning,
+ )
+ interactdiffusion_subject_phrases = interactdiffusion_subject_phrases[:max_objs]
+ interactdiffusion_subject_boxes = interactdiffusion_subject_boxes[:max_objs]
+ interactdiffusion_object_phrases = interactdiffusion_object_phrases[:max_objs]
+ interactdiffusion_object_boxes = interactdiffusion_object_boxes[:max_objs]
+ interactdiffusion_action_phrases = interactdiffusion_action_phrases[:max_objs]
+ # prepare batched input to the InteractDiffusionInteractionProjection (boxes, phrases, mask)
+ # obtain its text features for phrases
+ (
+ interactdiffusion_embeds
+ ) = self.encode_prompt_gligen(
+ prompt=interactdiffusion_subject_phrases+interactdiffusion_object_phrases+interactdiffusion_action_phrases,
+ device=device,
+ num_images_per_prompt=1,
+ # TODO: decide whether to follow the prompt encoding configuration for LoRA and CLIP skip
+ lora_scale=lora_scale,
+ clip_skip=self.clip_skip,
+ )
+
+ n_objs = min(len(interactdiffusion_subject_boxes), max_objs)
+ # Each entity described in the phrases is denoted with a bounding box;
+ # we represent the location information as (xmin, ymin, xmax, ymax)
+ # boxes = torch.zeros(max_objs, 4, device=device, dtype=self.text_encoder.dtype)
+ # boxes[:n_objs] = torch.tensor(gligen_boxes)
+ subject_boxes = torch.zeros(max_objs, 4, device=device, dtype=self.text_encoder.dtype)
+ object_boxes = torch.zeros(max_objs, 4, device=device, dtype=self.text_encoder.dtype)
+ subject_boxes[:n_objs] = torch.tensor(interactdiffusion_subject_boxes[:n_objs])
+ object_boxes[:n_objs] = torch.tensor(interactdiffusion_object_boxes[:n_objs])
+
+ text_embeddings = torch.zeros(
+ max_objs*3, self.unet.config.cross_attention_dim, device=device, dtype=self.text_encoder.dtype
+ )
+ text_embeddings[:n_objs*3] = interactdiffusion_embeds
+
+ subject_text_embeddings = torch.zeros(max_objs, self.unet.config.cross_attention_dim, device=device, dtype=self.text_encoder.dtype)
+ subject_text_embeddings[:n_objs] = text_embeddings[:n_objs*1]
+ object_text_embeddings = torch.zeros(max_objs, self.unet.config.cross_attention_dim, device=device, dtype=self.text_encoder.dtype)
+ object_text_embeddings[:n_objs] = text_embeddings[n_objs*1:n_objs*2]
+ action_text_embeddings = torch.zeros(max_objs, self.unet.config.cross_attention_dim, device=device, dtype=self.text_encoder.dtype)
+ action_text_embeddings[:n_objs] = text_embeddings[n_objs*2:n_objs*3]
+ # Generate a mask for each slot that holds an entity described by the phrases
+ masks = torch.zeros(max_objs, device=device, dtype=self.text_encoder.dtype)
+ masks[:n_objs] = 1
+ repeat_batch = batch_size * num_images_per_prompt
+ # boxes = boxes.unsqueeze(0).expand(repeat_batch, -1, -1).clone()
+ # text_embeddings = text_embeddings.unsqueeze(0).expand(repeat_batch, -1, -1).clone()
+ subject_boxes = subject_boxes.unsqueeze(0).expand(repeat_batch, -1, -1).clone()
+ object_boxes = object_boxes.unsqueeze(0).expand(repeat_batch, -1, -1).clone()
+ subject_text_embeddings = subject_text_embeddings.unsqueeze(0).expand(repeat_batch, -1, -1).clone()
+ object_text_embeddings = object_text_embeddings.unsqueeze(0).expand(repeat_batch, -1, -1).clone()
+ action_text_embeddings = action_text_embeddings.unsqueeze(0).expand(repeat_batch, -1, -1).clone()
+ masks = masks.unsqueeze(0).expand(repeat_batch, -1).clone()
+ if self.do_classifier_free_guidance:
+ repeat_batch = repeat_batch * 2
+ # boxes = torch.cat([boxes] * 2)
+ # text_embeddings = torch.cat([text_embeddings] * 2)
+ subject_boxes = torch.cat([subject_boxes] * 2)
+ object_boxes = torch.cat([object_boxes] * 2)
+ subject_text_embeddings = torch.cat([subject_text_embeddings] * 2)
+ object_text_embeddings = torch.cat([object_text_embeddings] * 2)
+ action_text_embeddings = torch.cat([action_text_embeddings] * 2)
+ masks = torch.cat([masks] * 2)
+ masks[: repeat_batch // 2] = 0
+ if self.cross_attention_kwargs is None:
+ self._cross_attention_kwargs = {}
+ # self.cross_attention_kwargs["gligen"] = {"boxes": boxes, "positive_embeddings": text_embeddings, "masks": masks}
+ self.cross_attention_kwargs['gligen'] = {
+ 'subject_boxes': subject_boxes,
+ 'object_boxes': object_boxes,
+ 'subject_positive_embeddings': subject_text_embeddings,
+ 'object_positive_embeddings': object_text_embeddings,
+ 'action_positive_embeddings': action_text_embeddings,
+ 'masks': masks
+ }
+
+ num_grounding_steps = int(interactdiffusion_scheduled_sampling_beta * len(timesteps))
+ self.enable_fuser(True)
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 7. Prepare added time ids & embeddings
+ add_text_embeds = pooled_prompt_embeds
+ if self.text_encoder_2 is None:
+ text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
+ else:
+ text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
+
+ add_time_ids = self._get_add_time_ids(
+ original_size,
+ crops_coords_top_left,
+ target_size,
+ dtype=prompt_embeds.dtype,
+ text_encoder_projection_dim=text_encoder_projection_dim,
+ )
+ if negative_original_size is not None and negative_target_size is not None:
+ negative_add_time_ids = self._get_add_time_ids(
+ negative_original_size,
+ negative_crops_coords_top_left,
+ negative_target_size,
+ dtype=prompt_embeds.dtype,
+ text_encoder_projection_dim=text_encoder_projection_dim,
+ )
+ else:
+ negative_add_time_ids = add_time_ids
+
+ if self.do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
+ add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
+
+ prompt_embeds = prompt_embeds.to(device)
+ add_text_embeds = add_text_embeds.to(device)
+ add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
+
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+ image_embeds = self.prepare_ip_adapter_image_embeds(
+ ip_adapter_image,
+ ip_adapter_image_embeds,
+ device,
+ batch_size * num_images_per_prompt,
+ self.do_classifier_free_guidance,
+ )
+
+ # 8. Denoising loop
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+
+ # 8.1 Apply denoising_end
+ if (
+ self.denoising_end is not None
+ and isinstance(self.denoising_end, float)
+ and self.denoising_end > 0
+ and self.denoising_end < 1
+ ):
+ discrete_timestep_cutoff = int(
+ round(
+ self.scheduler.config.num_train_timesteps
+ - (self.denoising_end * self.scheduler.config.num_train_timesteps)
+ )
+ )
+ num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
+ timesteps = timesteps[:num_inference_steps]
+
+ # 9. Optionally get Guidance Scale Embedding
+ timestep_cond = None
+ if self.unet.config.time_cond_proj_dim is not None:
+ guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
+ timestep_cond = self.get_guidance_scale_embedding(
+ guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
+ ).to(device=device, dtype=latents.dtype)
+
+ self._num_timesteps = len(timesteps)
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ if i == num_grounding_steps:
+ self.enable_fuser(False)
+
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+ added_cond_kwargs["image_embeds"] = image_embeds
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ timestep_cond=timestep_cond,
+ cross_attention_kwargs=self.cross_attention_kwargs,
+ added_cond_kwargs=added_cond_kwargs,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if self.do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
+ # Based on Section 3.4 of https://arxiv.org/pdf/2305.08891.pdf
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents_dtype = latents.dtype
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+ if latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (e.g. Apple MPS) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ latents = latents.to(latents_dtype)
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+ add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
+ negative_pooled_prompt_embeds = callback_outputs.pop(
+ "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
+ )
+ add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
+ negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ if not output_type == "latent":
+ # make sure the VAE is in float32 mode, as it overflows in float16
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
+
+ if needs_upcasting:
+ self.upcast_vae()
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
+ elif latents.dtype != self.vae.dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (e.g. Apple MPS) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ self.vae = self.vae.to(latents.dtype)
+
+ # unscale/denormalize the latents
+ # denormalize with the mean and std if available and not None
+ has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
+ has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
+ if has_latents_mean and has_latents_std:
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
+ )
+ latents_std = (
+ torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
+ )
+ latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
+ else:
+ latents = latents / self.vae.config.scaling_factor
+
+ image = self.vae.decode(latents, return_dict=False)[0]
+
+ # cast back to fp16 if needed
+ if needs_upcasting:
+ self.vae.to(dtype=torch.float16)
+ else:
+ image = latents
+
+ if not output_type == "latent":
+ # apply watermark if available
+ if self.watermark is not None:
+ image = self.watermark.apply_watermark(image)
+
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return StableDiffusionXLPipelineOutput(images=image)
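
A note on the grounding schedule implemented above: `interactdiffusion_scheduled_sampling_beta` decides for how many of the denoising steps the interaction fuser stays enabled, after which `enable_fuser(False)` switches the UNet back to vanilla SDXL behavior. A minimal sketch of that arithmetic (plain Python, not part of the diff):

```python
# Sketch of the schedule computed by `num_grounding_steps = int(beta * len(timesteps))`
# in __call__ above: with beta=0.6 and 50 steps, the interaction grounding applies
# to the first 30 steps only; beta=1 keeps it on for every step.
def grounding_schedule(beta: float, num_steps: int) -> list:
    num_grounding_steps = int(beta * num_steps)
    return [i < num_grounding_steps for i in range(num_steps)]

assert sum(grounding_schedule(0.6, 50)) == 30
assert all(grounding_schedule(1.0, 50))
```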
scheduler/scheduler_config.json ADDED
@@ -0,0 +1,23 @@
+ {
+ "_class_name": "EulerDiscreteScheduler",
+ "_diffusers_version": "0.30.1",
+ "beta_end": 0.012,
+ "beta_schedule": "scaled_linear",
+ "beta_start": 0.00085,
+ "clip_sample": false,
+ "final_sigmas_type": "zero",
+ "interpolation_type": "linear",
+ "num_train_timesteps": 1000,
+ "prediction_type": "epsilon",
+ "rescale_betas_zero_snr": false,
+ "sample_max_value": 1.0,
+ "set_alpha_to_one": false,
+ "sigma_max": null,
+ "sigma_min": null,
+ "skip_prk_steps": true,
+ "steps_offset": 1,
+ "timestep_spacing": "leading",
+ "timestep_type": "discrete",
+ "trained_betas": null,
+ "use_karras_sigmas": false
+ }
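
This is the stock SDXL EulerDiscreteScheduler configuration. A minimal sketch (assuming the repo id `jiuntian/interactdiffusion-xl-1024` from this upload) of swapping it at load time, using the standard diffusers pattern of rebuilding a scheduler from the shipped config:

```python
import torch
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler

# Hedged sketch: any scheduler exposing from_config() can replace the shipped
# EulerDiscreteScheduler, since it reads the same beta schedule from this config.
pipeline = DiffusionPipeline.from_pretrained(
    "jiuntian/interactdiffusion-xl-1024", trust_remote_code=True, torch_dtype=torch.float16
)
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
```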
text_encoder/config.json ADDED
@@ -0,0 +1,25 @@
+ {
+ "_name_or_path": "/home/user/jiuntian/.cache/huggingface/hub/models--stabilityai--stable-diffusion-xl-base-1.0/snapshots/462165984030d82259a11f4367a4eed129e94a7b/text_encoder",
+ "architectures": [
+ "CLIPTextModel"
+ ],
+ "attention_dropout": 0.0,
+ "bos_token_id": 0,
+ "dropout": 0.0,
+ "eos_token_id": 2,
+ "hidden_act": "quick_gelu",
+ "hidden_size": 768,
+ "initializer_factor": 1.0,
+ "initializer_range": 0.02,
+ "intermediate_size": 3072,
+ "layer_norm_eps": 1e-05,
+ "max_position_embeddings": 77,
+ "model_type": "clip_text_model",
+ "num_attention_heads": 12,
+ "num_hidden_layers": 12,
+ "pad_token_id": 1,
+ "projection_dim": 768,
+ "torch_dtype": "float16",
+ "transformers_version": "4.44.2",
+ "vocab_size": 49408
+ }
text_encoder/model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:660c6f5b1abae9dc498ac2d21e1347d2abdb0cf6c0c0c8576cd796491d9a6cdd
+ size 246144152
text_encoder_2/config.json ADDED
@@ -0,0 +1,25 @@
+ {
+ "_name_or_path": "/home/user/jiuntian/.cache/huggingface/hub/models--stabilityai--stable-diffusion-xl-base-1.0/snapshots/462165984030d82259a11f4367a4eed129e94a7b/text_encoder_2",
+ "architectures": [
+ "CLIPTextModelWithProjection"
+ ],
+ "attention_dropout": 0.0,
+ "bos_token_id": 0,
+ "dropout": 0.0,
+ "eos_token_id": 2,
+ "hidden_act": "gelu",
+ "hidden_size": 1280,
+ "initializer_factor": 1.0,
+ "initializer_range": 0.02,
+ "intermediate_size": 5120,
+ "layer_norm_eps": 1e-05,
+ "max_position_embeddings": 77,
+ "model_type": "clip_text_model",
+ "num_attention_heads": 20,
+ "num_hidden_layers": 32,
+ "pad_token_id": 1,
+ "projection_dim": 1280,
+ "torch_dtype": "float16",
+ "transformers_version": "4.44.2",
+ "vocab_size": 49408
+ }
text_encoder_2/model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ec310df2af79c318e24d20511b601a591ca8cd4f1fce1d8dff822a356bcdb1f4
+ size 1389382176
tokenizer/merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer/special_tokens_map.json ADDED
@@ -0,0 +1,24 @@
+ {
+ "bos_token": {
+ "content": "<|startoftext|>",
+ "lstrip": false,
+ "normalized": true,
+ "rstrip": false,
+ "single_word": false
+ },
+ "eos_token": {
+ "content": "<|endoftext|>",
+ "lstrip": false,
+ "normalized": true,
+ "rstrip": false,
+ "single_word": false
+ },
+ "pad_token": "<|endoftext|>",
+ "unk_token": {
+ "content": "<|endoftext|>",
+ "lstrip": false,
+ "normalized": true,
+ "rstrip": false,
+ "single_word": false
+ }
+ }
tokenizer/tokenizer_config.json ADDED
@@ -0,0 +1,30 @@
+ {
+ "add_prefix_space": false,
+ "added_tokens_decoder": {
+ "49406": {
+ "content": "<|startoftext|>",
+ "lstrip": false,
+ "normalized": true,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "49407": {
+ "content": "<|endoftext|>",
+ "lstrip": false,
+ "normalized": true,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ }
+ },
+ "bos_token": "<|startoftext|>",
+ "clean_up_tokenization_spaces": true,
+ "do_lower_case": true,
+ "eos_token": "<|endoftext|>",
+ "errors": "replace",
+ "model_max_length": 77,
+ "pad_token": "<|endoftext|>",
+ "tokenizer_class": "CLIPTokenizer",
+ "unk_token": "<|endoftext|>"
+ }
tokenizer/vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_2/merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_2/special_tokens_map.json ADDED
@@ -0,0 +1,24 @@
+ {
+ "bos_token": {
+ "content": "<|startoftext|>",
+ "lstrip": false,
+ "normalized": true,
+ "rstrip": false,
+ "single_word": false
+ },
+ "eos_token": {
+ "content": "<|endoftext|>",
+ "lstrip": false,
+ "normalized": true,
+ "rstrip": false,
+ "single_word": false
+ },
+ "pad_token": "!",
+ "unk_token": {
+ "content": "<|endoftext|>",
+ "lstrip": false,
+ "normalized": true,
+ "rstrip": false,
+ "single_word": false
+ }
+ }
tokenizer_2/tokenizer_config.json ADDED
@@ -0,0 +1,38 @@
+ {
+ "add_prefix_space": false,
+ "added_tokens_decoder": {
+ "0": {
+ "content": "!",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "49406": {
+ "content": "<|startoftext|>",
+ "lstrip": false,
+ "normalized": true,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "49407": {
+ "content": "<|endoftext|>",
+ "lstrip": false,
+ "normalized": true,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ }
+ },
+ "bos_token": "<|startoftext|>",
+ "clean_up_tokenization_spaces": true,
+ "do_lower_case": true,
+ "eos_token": "<|endoftext|>",
+ "errors": "replace",
+ "model_max_length": 77,
+ "pad_token": "!",
+ "tokenizer_class": "CLIPTokenizer",
+ "unk_token": "<|endoftext|>"
+ }
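
One detail worth flagging here: unlike `tokenizer`, `tokenizer_2` pads with `"!"` (token id 0), following the OpenCLIP convention used by SDXL's second text encoder. A quick sketch to confirm (repo id assumed from this upload):

```python
from transformers import CLIPTokenizer

tokenizer_2 = CLIPTokenizer.from_pretrained(
    "jiuntian/interactdiffusion-xl-1024", subfolder="tokenizer_2"
)
ids = tokenizer_2("a cat", padding="max_length", max_length=8).input_ids
assert tokenizer_2.pad_token == "!"
assert ids[-1] == 0  # padding uses id 0 ("!"), not 49407 (<|endoftext|>)
```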
tokenizer_2/vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
unet/config.json ADDED
@@ -0,0 +1,73 @@
+ {
+ "_class_name": "UNet2DConditionModel",
+ "_diffusers_version": "0.30.1",
+ "_name_or_path": "logs/gligen_sdxl_bs32/checkpoint-256000",
+ "act_fn": "silu",
+ "addition_embed_type": "text_time",
+ "addition_embed_type_num_heads": 64,
+ "addition_time_embed_dim": 256,
+ "attention_head_dim": [
+ 5,
+ 10,
+ 20
+ ],
+ "attention_type": "gated",
+ "block_out_channels": [
+ 320,
+ 640,
+ 1280
+ ],
+ "center_input_sample": false,
+ "class_embed_type": null,
+ "class_embeddings_concat": false,
+ "conv_in_kernel": 3,
+ "conv_out_kernel": 3,
+ "cross_attention_dim": 2048,
+ "cross_attention_norm": null,
+ "down_block_types": [
+ "DownBlock2D",
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D"
+ ],
+ "downsample_padding": 1,
+ "dropout": 0.0,
+ "dual_cross_attention": false,
+ "encoder_hid_dim": null,
+ "encoder_hid_dim_type": null,
+ "flip_sin_to_cos": true,
+ "freq_shift": 0,
+ "in_channels": 4,
+ "layers_per_block": 2,
+ "mid_block_only_cross_attention": null,
+ "mid_block_scale_factor": 1,
+ "mid_block_type": "UNetMidBlock2DCrossAttn",
+ "norm_eps": 1e-05,
+ "norm_num_groups": 32,
+ "num_attention_heads": null,
+ "num_class_embeds": null,
+ "only_cross_attention": false,
+ "out_channels": 4,
+ "projection_class_embeddings_input_dim": 2816,
+ "resnet_out_scale_factor": 1.0,
+ "resnet_skip_time_act": false,
+ "resnet_time_scale_shift": "default",
+ "reverse_transformer_layers_per_block": null,
+ "sample_size": 128,
+ "time_cond_proj_dim": null,
+ "time_embedding_act_fn": null,
+ "time_embedding_dim": null,
+ "time_embedding_type": "positional",
+ "timestep_post_act": null,
+ "transformer_layers_per_block": [
+ 1,
+ 2,
+ 10
+ ],
+ "up_block_types": [
+ "CrossAttnUpBlock2D",
+ "CrossAttnUpBlock2D",
+ "UpBlock2D"
+ ],
+ "upcast_attention": null,
+ "use_linear_projection": true
+ }
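
Two values in this config are worth spelling out. `"attention_type": "gated"` is what makes diffusers build the gated self-attention (GLIGEN-style fuser) layers that consume the `gligen` cross-attention kwargs assembled in the pipeline above. And `"projection_class_embeddings_input_dim": 2816` follows from SDXL's added conditioning; a sketch of the arithmetic:

```python
# Sketch: SDXL's "text_time" addition embedding concatenates the pooled prompt
# embedding from text_encoder_2 (projection_dim = 1280) with six micro-conditioning
# scalars (original_size, crops_coords_top_left, target_size -> 2 values each),
# each Fourier-embedded to addition_time_embed_dim = 256.
pooled_dim = 1280
num_micro_cond_values = 6
addition_time_embed_dim = 256
assert pooled_dim + num_micro_cond_values * addition_time_embed_dim == 2816
```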
unet/diffusion_pytorch_model-00001-of-00002.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8d6133670cdd8efaa06cda036435e70101f9ebb389bbaab27e273442791a30b9
+ size 9995914112
unet/diffusion_pytorch_model-00002-of-00002.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:77c266382f7abd1678b402fbcd8e236ea8c50a60babb39960f57468222377553
+ size 7524087696
unet/diffusion_pytorch_model.safetensors.index.json ADDED
The diff for this file is too large to render. See raw diff
 
unet/interactdiffusion_unet_2d_condition.py ADDED
@@ -0,0 +1,230 @@
+ from typing import Optional, Tuple, Union
+ from diffusers.models import UNet2DConditionModel
+ from diffusers.models.embeddings import get_fourier_embeds_from_boundingbox
+ import torch
+ import torch.nn as nn
+
+ class AbsolutePositionalEmbedding(nn.Module):
+ def __init__(self, dim, max_seq_len):
+ super().__init__()
+ self.emb = nn.Embedding(max_seq_len, dim)
+ self.init_()
+
+ def init_(self):
+ nn.init.normal_(self.emb.weight, std=0.02)
+
+ def forward(self, x):
+ n = torch.arange(x.shape[1], device=x.device)
+ return self.emb(n)[None, :, :]
+
+
+ class InteractDiffusionInteractionProjection(nn.Module):
+ def __init__(self, in_dim, out_dim, fourier_freqs=8):
+ super().__init__()
+ self.in_dim = in_dim
+ self.out_dim = out_dim
+
+ self.fourier_embedder_dim = fourier_freqs
+ self.position_dim = fourier_freqs * 2 * 4 # 2: sin/cos, 4: xyxy
+ self.interaction_embedding = AbsolutePositionalEmbedding(dim=out_dim, max_seq_len=32)
+ self.position_embedding = AbsolutePositionalEmbedding(dim=out_dim, max_seq_len=3)
+
+ if isinstance(out_dim, tuple):
+ out_dim = out_dim[0]
+
+ self.linears = nn.Sequential(
+ nn.Linear(self.in_dim + self.position_dim, 512),
+ nn.SiLU(),
+ nn.Linear(512, 512),
+ nn.SiLU(),
+ nn.Linear(512, out_dim),
+ )
+
+ self.linear_action = nn.Sequential(
+ nn.Linear(self.in_dim + self.position_dim, 512),
+ nn.SiLU(),
+ nn.Linear(512, 512),
+ nn.SiLU(),
+ nn.Linear(512, out_dim),
+ )
+
+ self.null_positive_feature = torch.nn.Parameter(torch.zeros([self.in_dim]))
+ self.null_action_feature = torch.nn.Parameter(torch.zeros([self.in_dim]))
+ self.null_position_feature = torch.nn.Parameter(torch.zeros([self.position_dim]))
+
+ def get_between_box(self, bbox1, bbox2):
+ """Between-set operation: the box between Box A and Box B (from Prof. Jiang's idea)."""
+ all_x = torch.cat([bbox1[:, :, 0::2], bbox2[:, :, 0::2]], dim=-1)
+ all_y = torch.cat([bbox1[:, :, 1::2], bbox2[:, :, 1::2]], dim=-1)
+ all_x, _ = all_x.sort()
+ all_y, _ = all_y.sort()
+ return torch.stack([all_x[:, :, 1], all_y[:, :, 1], all_x[:, :, 2], all_y[:, :, 2]], 2)
+
+ def forward(
+ self,
+ subject_boxes, object_boxes,
+ masks,
+ subject_positive_embeddings, object_positive_embeddings, action_positive_embeddings
+ ):
+ masks = masks.unsqueeze(-1)
+
+ # embedding position (it may include padding as placeholder)
+ action_boxes = self.get_between_box(subject_boxes, object_boxes)
+ subject_xyxy_embedding = get_fourier_embeds_from_boundingbox(self.fourier_embedder_dim, subject_boxes) # B*N*4 --> B*N*C
+ object_xyxy_embedding = get_fourier_embeds_from_boundingbox(self.fourier_embedder_dim, object_boxes) # B*N*4 --> B*N*C
+ action_xyxy_embedding = get_fourier_embeds_from_boundingbox(self.fourier_embedder_dim, action_boxes) # B*N*4 --> B*N*C
+
+ # learnable null embedding
+ positive_null = self.null_positive_feature.view(1, 1, -1)
+ xyxy_null = self.null_position_feature.view(1, 1, -1)
+ action_null = self.null_action_feature.view(1, 1, -1)
+
+ # replace padding with learnable null embedding
+ subject_positive_embeddings = subject_positive_embeddings * masks + (1 - masks) * positive_null
+ object_positive_embeddings = object_positive_embeddings * masks + (1 - masks) * positive_null
+
+ subject_xyxy_embedding = subject_xyxy_embedding * masks + (1 - masks) * xyxy_null
+ object_xyxy_embedding = object_xyxy_embedding * masks + (1 - masks) * xyxy_null
+ action_xyxy_embedding = action_xyxy_embedding * masks + (1 - masks) * xyxy_null
+
+ action_positive_embeddings = action_positive_embeddings * masks + (1 - masks) * action_null
+
+ # project the input embeddings
+ objs_subject = self.linears(torch.cat([subject_positive_embeddings, subject_xyxy_embedding], dim=-1))
+ objs_object = self.linears(torch.cat([object_positive_embeddings, object_xyxy_embedding], dim=-1))
+ objs_action = self.linear_action(torch.cat([action_positive_embeddings, action_xyxy_embedding], dim=-1))
+
+ # impose instance embedding (which interaction each token belongs to)
+ objs_subject = objs_subject + self.interaction_embedding(objs_subject)
+ objs_object = objs_object + self.interaction_embedding(objs_object)
+ objs_action = objs_action + self.interaction_embedding(objs_action)
+
+ # impose role embedding (subject / object / action)
+ objs_subject = objs_subject + self.position_embedding.emb(torch.tensor(0).to(objs_subject.device))
+ objs_object = objs_object + self.position_embedding.emb(torch.tensor(1).to(objs_object.device))
+ objs_action = objs_action + self.position_embedding.emb(torch.tensor(2).to(objs_action.device))
+
+ objs = torch.cat([objs_subject, objs_action, objs_object], dim=1)
+
+ return objs
+
+
+ class InteractDiffusionUNet2DConditionModel(UNet2DConditionModel):
+ def __init__(self,
+ sample_size: Optional[int] = None,
+ in_channels: int = 4,
+ out_channels: int = 4,
+ center_input_sample: bool = False,
+ flip_sin_to_cos: bool = True,
+ freq_shift: int = 0,
+ down_block_types: Tuple[str] = (
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "DownBlock2D",
+ ),
+ mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
+ up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
+ layers_per_block: Union[int, Tuple[int]] = 2,
+ downsample_padding: int = 1,
+ mid_block_scale_factor: float = 1,
+ dropout: float = 0.0,
+ act_fn: str = "silu",
+ norm_num_groups: Optional[int] = 32,
+ norm_eps: float = 1e-5,
+ cross_attention_dim: Union[int, Tuple[int]] = 1280,
+ transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
+ reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
+ encoder_hid_dim: Optional[int] = None,
+ encoder_hid_dim_type: Optional[str] = None,
+ attention_head_dim: Union[int, Tuple[int]] = 8,
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
+ dual_cross_attention: bool = False,
+ use_linear_projection: bool = False,
+ class_embed_type: Optional[str] = None,
+ addition_embed_type: Optional[str] = None,
+ addition_time_embed_dim: Optional[int] = None,
+ num_class_embeds: Optional[int] = None,
+ upcast_attention: bool = False,
+ resnet_time_scale_shift: str = "default",
+ resnet_skip_time_act: bool = False,
+ resnet_out_scale_factor: float = 1.0,
+ time_embedding_type: str = "positional",
+ time_embedding_dim: Optional[int] = None,
+ time_embedding_act_fn: Optional[str] = None,
+ timestep_post_act: Optional[str] = None,
+ time_cond_proj_dim: Optional[int] = None,
+ conv_in_kernel: int = 3,
+ conv_out_kernel: int = 3,
+ projection_class_embeddings_input_dim: Optional[int] = None,
+ attention_type: str = "default",
+ class_embeddings_concat: bool = False,
+ mid_block_only_cross_attention: Optional[bool] = None,
+ cross_attention_norm: Optional[str] = None,
+ addition_embed_type_num_heads: int = 64,
+ ):
+ super(InteractDiffusionUNet2DConditionModel, self).__init__(
+ sample_size=sample_size,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ center_input_sample=center_input_sample,
+ flip_sin_to_cos=flip_sin_to_cos,
+ freq_shift=freq_shift,
+ down_block_types=down_block_types,
+ mid_block_type=mid_block_type,
+ up_block_types=up_block_types,
+ only_cross_attention=only_cross_attention,
+ block_out_channels=block_out_channels,
+ layers_per_block=layers_per_block,
+ downsample_padding=downsample_padding,
+ mid_block_scale_factor=mid_block_scale_factor,
+ dropout=dropout,
+ act_fn=act_fn,
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ cross_attention_dim=cross_attention_dim,
+ transformer_layers_per_block=transformer_layers_per_block,
+ reverse_transformer_layers_per_block=reverse_transformer_layers_per_block,
+ encoder_hid_dim=encoder_hid_dim,
+ encoder_hid_dim_type=encoder_hid_dim_type,
+ attention_head_dim=attention_head_dim,
+ num_attention_heads=num_attention_heads,
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ class_embed_type=class_embed_type,
+ addition_embed_type=addition_embed_type,
+ addition_time_embed_dim=addition_time_embed_dim,
+ num_class_embeds=num_class_embeds,
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ resnet_skip_time_act=resnet_skip_time_act,
+ resnet_out_scale_factor=resnet_out_scale_factor,
+ time_embedding_type=time_embedding_type,
+ time_embedding_dim=time_embedding_dim,
+ time_embedding_act_fn=time_embedding_act_fn,
+ timestep_post_act=timestep_post_act,
+ time_cond_proj_dim=time_cond_proj_dim,
+ conv_in_kernel=conv_in_kernel,
+ conv_out_kernel=conv_out_kernel,
+ projection_class_embeddings_input_dim=projection_class_embeddings_input_dim,
+ attention_type=attention_type,
+ class_embeddings_concat=class_embeddings_concat,
+ mid_block_only_cross_attention=mid_block_only_cross_attention,
+ cross_attention_norm=cross_attention_norm,
+ addition_embed_type_num_heads=addition_embed_type_num_heads
+ )
+
+ # load position_net
+ positive_len = 768
+ if isinstance(self.config.cross_attention_dim, int):
+ positive_len = self.config.cross_attention_dim
+ elif isinstance(self.config.cross_attention_dim, (tuple, list)):
+ positive_len = self.config.cross_attention_dim[0]
+
+ self.position_net = InteractDiffusionInteractionProjection(
+ in_dim=positive_len, out_dim=self.config.cross_attention_dim
+ )
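
To make `get_between_box` concrete: it collects the four x (and y) coordinates of the subject and object boxes, sorts them, and keeps the middle two, which yields the gap between disjoint boxes or the overlap region of intersecting ones. A standalone sketch of the same operation:

```python
import torch

# Boxes are B x N x 4 in xyxy form, matching the module above.
subject = torch.tensor([[[0.1, 0.1, 0.3, 0.3]]])
obj = torch.tensor([[[0.5, 0.5, 0.9, 0.9]]])
all_x = torch.cat([subject[:, :, 0::2], obj[:, :, 0::2]], dim=-1).sort().values
all_y = torch.cat([subject[:, :, 1::2], obj[:, :, 1::2]], dim=-1).sort().values
between = torch.stack([all_x[:, :, 1], all_y[:, :, 1], all_x[:, :, 2], all_y[:, :, 2]], 2)
# between -> tensor([[[0.3, 0.3, 0.5, 0.5]]]): the "action" box between the two
```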
vae/config.json ADDED
@@ -0,0 +1,38 @@
+ {
+ "_class_name": "AutoencoderKL",
+ "_diffusers_version": "0.30.1",
+ "_name_or_path": "madebyollin/sdxl-vae-fp16-fix",
+ "act_fn": "silu",
+ "block_out_channels": [
+ 128,
+ 256,
+ 512,
+ 512
+ ],
+ "down_block_types": [
+ "DownEncoderBlock2D",
+ "DownEncoderBlock2D",
+ "DownEncoderBlock2D",
+ "DownEncoderBlock2D"
+ ],
+ "force_upcast": false,
+ "in_channels": 3,
+ "latent_channels": 4,
+ "latents_mean": null,
+ "latents_std": null,
+ "layers_per_block": 2,
+ "mid_block_add_attention": true,
+ "norm_num_groups": 32,
+ "out_channels": 3,
+ "sample_size": 512,
+ "scaling_factor": 0.13025,
+ "shift_factor": null,
+ "up_block_types": [
+ "UpDecoderBlock2D",
+ "UpDecoderBlock2D",
+ "UpDecoderBlock2D",
+ "UpDecoderBlock2D"
+ ],
+ "use_post_quant_conv": true,
+ "use_quant_conv": true
+ }
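
This VAE is `madebyollin/sdxl-vae-fp16-fix`, so `force_upcast` is `false` and the pipeline's `needs_upcasting` branch above stays inactive even when running in fp16. A minimal standalone decode sketch (repo id assumed from this upload):

```python
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("jiuntian/interactdiffusion-xl-1024", subfolder="vae")
latents = torch.randn(1, 4, 128, 128)  # 128 latent -> 1024 px at 8x downsampling
with torch.no_grad():
    image = vae.decode(latents / vae.config.scaling_factor).sample  # (1, 3, 1024, 1024)
```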
vae/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6353737672c94b96174cb590f711eac6edf2fcce5b6e91aa9d73c5adc589ee48
+ size 167335342