tri86uit committed
Commit 8d4cbac · verified · 1 Parent(s): b5ae46d

Upload 15 files

.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+imgs/test_cases/rc_car/02.jpg filter=lfs diff=lfs merge=lfs -text
+imgs/test_cases/rc_car/03.jpg filter=lfs diff=lfs merge=lfs -text
+imgs/test_cases/rc_car/04.jpg filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,12 +1,20 @@
 ---
-title: SynCD Base FLUX
-emoji: 🌖
-colorFrom: red
-colorTo: indigo
+title: SynCD
+emoji: 🖼
+colorFrom: purple
+colorTo: red
 sdk: gradio
-sdk_version: 6.3.0
+sdk_version: 5.17.1
 app_file: app.py
 pinned: false
+license: mit
+tags:
+- dwpose
+- pose
+- Text-to-Image
+- Image-to-Image
+- language models
+- LLMs
+short_description: Image generator/customization/personalization
 ---
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
app.py ADDED
@@ -0,0 +1,287 @@
1
+ import os
2
+ import random
3
+
4
+ import gradio as gr
5
+ import numpy as np
6
+ import spaces
7
+ import torch
8
+ from einops import rearrange
9
+ from huggingface_hub import login
10
+ from peft import LoraConfig
11
+ from PIL import Image
12
+ from pipelines.flux_pipeline.pipeline import SynCDFluxPipeline
13
+ from pipelines.flux_pipeline.transformer import FluxTransformer2DModelWithMasking
14
+
15
+ HF_TOKEN = os.getenv('HF_TOKEN')
16
+ login(token=HF_TOKEN)
17
+ torch_dtype = torch.bfloat16
18
+ transformer = FluxTransformer2DModelWithMasking.from_pretrained(
19
+ 'black-forest-labs/FLUX.1-dev',
20
+ subfolder='transformer',
21
+ torch_dtype=torch_dtype
22
+ )
23
+ pipeline = SynCDFluxPipeline.from_pretrained('black-forest-labs/FLUX.1-dev', transformer=transformer, torch_dtype=torch_dtype)
24
+ for name, attn_proc in pipeline.transformer.attn_processors.items():
25
+ attn_proc.name = name
26
+
27
+ target_modules=[
28
+ "to_k",
29
+ "to_q",
30
+ "to_v",
31
+ "add_k_proj",
32
+ "add_q_proj",
33
+ "add_v_proj",
34
+ "to_out.0",
35
+ "to_add_out",
36
+ "ff.net.0.proj",
37
+ "ff.net.2",
38
+ "ff_context.net.0.proj",
39
+ "ff_context.net.2",
40
+ "proj_mlp",
41
+ "proj_out",
42
+ ]
43
+ lora_rank = 32
44
+ lora_config = LoraConfig(
45
+ r=lora_rank,
46
+ lora_alpha=lora_rank,
47
+ init_lora_weights="gaussian",
48
+ target_modules=target_modules,
49
+ )
50
+ pipeline.transformer.add_adapter(lora_config)
51
+ finetuned_path = torch.load('models/pytorch_model.bin', map_location='cpu')
52
+ transformer_dict = {}
53
+ for key,value in finetuned_path.items():
54
+ if 'transformer.base_model.model.' in key:
55
+ transformer_dict[key.replace('transformer.base_model.model.', '')] = value
56
+ pipeline.transformer.load_state_dict(transformer_dict, strict=False)
57
+ pipeline.to('cuda')
58
+ pipeline.enable_vae_slicing()
59
+ pipeline.enable_vae_tiling()
60
+
61
+ @torch.no_grad()
62
+ def decode(latents, pipeline):
63
+ latents = latents / pipeline.vae.config.scaling_factor
64
+ image = pipeline.vae.decode(latents, return_dict=False)[0]
65
+ return image
66
+
67
+
68
+ @torch.no_grad()
69
+ def encode_target_images(images, pipeline):
70
+ latents = pipeline.vae.encode(images).latent_dist.sample()
71
+ latents = latents * pipeline.vae.config.scaling_factor
72
+ return latents
73
+
74
+
75
+ @spaces.GPU(duration=120)
76
+ def generate_image(text, img1, img2, img3, guidance_scale, inference_steps, seed, enable_cpu_offload=False, neg_prompt="", true_cfg=1.0, image_cfg=0.0):
77
+ if neg_prompt == "":
78
+ neg_prompt = "3d render, cartoon, low resolution, illustration, blurry, unrealistic"
79
+ if enable_cpu_offload:
80
+ pipeline.enable_sequential_cpu_offload()
81
+ input_images = [img1, img2, img3]
82
+ # Delete None
83
+ input_images = [img for img in input_images if img is not None]
84
+ if len(input_images) == 0:
85
+ return "Please upload at least one image"
86
+ numref = len(input_images) + 1
87
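+ # Resize each reference image to 512x512 and normalize to [-1, 1] in NCHW layout for the VAE encoder.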
+ images = torch.cat([2. * torch.from_numpy(np.array(Image.open(img).convert('RGB').resize((512, 512)))).permute(2, 0, 1).unsqueeze(0).to(torch_dtype)/255. -1. for img in input_images])
88
+ images = images.to(pipeline.device)
89
+ latents = encode_target_images(images, pipeline)
90
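+ # Slot 0 holds the view to be generated (zeroed latents, mask = 1); the remaining slots hold the encoded reference images (mask = 0), which the pipeline re-injects at every denoising step.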
+ latents = torch.cat([torch.zeros_like(latents[:1]), latents], dim=0)
91
+ masklatent = torch.zeros_like(latents)
92
+ masklatent[:1] = 1.
93
+ latents = rearrange(latents, "(b n) c h w -> b c h (n w)", n=numref)
94
+ masklatent = rearrange(masklatent, "(b n) c h w -> b c h (n w)", n=numref)
95
+ B, C, H, W = latents.shape
96
+ latents = pipeline._pack_latents(latents, B, C, H, W)
97
+ masklatent = pipeline._pack_latents(masklatent.expand(-1, C, -1, -1) ,B, C, H, W)
98
+ output = pipeline(
99
+ text,
100
+ latents_ref=latents,
101
+ latents_mask=masklatent,
102
+ guidance_scale=guidance_scale,
103
+ num_inference_steps=inference_steps,
104
+ height=512,
105
+ width=numref * 512,
106
+ generator = torch.Generator(device="cuda").manual_seed(seed),
107
+ joint_attention_kwargs={'shared_attn': True, 'num': numref},
108
+ return_dict=False,
109
+ negative_prompt=neg_prompt,
110
+ true_cfg_scale=true_cfg,
111
+ image_cfg_scale=image_cfg,
112
+ )[0][0]
113
+ output = rearrange(output, "b c h (n w) -> (b n) c h w", n=numref)[::numref]
114
+ img = Image.fromarray( (( torch.clip(output[0].float(), -1., 1.).permute(1,2,0).cpu().numpy()*0.5+0.5)*255).astype(np.uint8) )
115
+ return img
116
+
117
+
118
+
119
+ def get_example():
120
+ case = [
121
+ [
122
+ "An action figure on top of a mountain. Sunset in the background. Realistic shot.",
123
+ "./imgs/test_cases/action_figure/0.jpg",
124
+ "./imgs/test_cases/action_figure/1.jpg",
125
+ "./imgs/test_cases/action_figure/2.jpg",
126
+ 3.5,
127
+ 42,
128
+ False,
129
+ "",
130
+ 1.0,
131
+ 0.0,
132
+ ],
133
+ [
134
+ "A penguin plushie wearing pink sunglasses is lounging on a beach. Realistic shot.",
135
+ "./imgs/test_cases/penguin/0.jpg",
136
+ "./imgs/test_cases/penguin/1.jpg",
137
+ "./imgs/test_cases/penguin/2.jpg",
138
+ 3.5,
139
+ 42,
140
+ False,
141
+ "",
142
+ 1.0,
143
+ 0.0,
144
+ ],
145
+ [
146
+ "A toy on a beach. Waves in the background. Realistic shot.",
147
+ "./imgs/test_cases/rc_car/02.jpg",
148
+ "./imgs/test_cases/rc_car/03.jpg",
149
+ "./imgs/test_cases/rc_car/04.jpg",
150
+ 3.5,
151
+ 42,
152
+ False,
153
+ "",
154
+ 1.0,
155
+ 0.0,
156
+ ],
157
+ ]
158
+ return case
159
+
160
+ def run_for_examples(text, img1, img2, img3, guidance_scale, seed, enable_cpu_offload=False, neg_prompt="", true_cfg=1.0, image_cfg=0.0):
161
+ inference_steps = 30
162
+
163
+ return generate_image(
164
+ text, img1, img2, img3, guidance_scale, inference_steps, seed, enable_cpu_offload, neg_prompt, true_cfg, image_cfg
165
+ )
166
+
167
+ description = """
168
+ The Synthetic Customization Dataset (SynCD) consists of multiple images of the same object in different contexts. We achieve this by promoting a consistent object identity during generation, using either explicit 3D object assets or, more implicitly, masked shared attention across different views. Given this training data, we train a new encoder-based customization model that can generate novel compositions of a reference object from text prompts. You can download our dataset [here](https://huggingface.co/datasets/nupurkmr9/syncd).
169
+
170
+ Our model supports multiple input images of the same object as references. You can upload up to 3 images; results are typically better with 3 reference images than with 1.
171
+
172
+ **HF Spaces often encounter errors due to quota limitations, so we recommend running it locally.**
173
+ """
174
+
175
+ article = """
176
+ ---
177
+ **Citation**
178
+ <br>
179
+ If you find this repository useful, please consider giving a star ⭐ and a citation
180
+ ```
181
+ @article{kumari2025syncd,
182
+ title={Generating Multi-Image Synthetic Data for Text-to-Image Customization},
183
+ author={Kumari, Nupur and Yin, Xi and Zhu, Jun-Yan and Misra, Ishan and Azadi, Samaneh},
184
+ journal={ArXiv},
185
+ year={2025}
186
+ }
187
+ ```
188
+ **Contact**
189
+ <br>
190
+ If you have any questions, please feel free to open an issue or reach out to us directly via email.
191
+
192
+ **Acknowledgement**
193
+ <br>
194
+ This Space was adapted from the [OmniGen](https://huggingface.co/spaces/Shitao/OmniGen) Space.
195
+ """
196
+
197
+
198
+ # Gradio
199
+ with gr.Blocks() as demo:
200
+ gr.Markdown("# SynCD: Generating Multi-Image Synthetic Data for Text-to-Image Customization [[paper](https://arxiv.org/abs/2502.01720)] [[code](https://github.com/nupurkmr9/syncd)]")
201
+ gr.Markdown(description)
202
+ with gr.Row():
203
+ with gr.Column():
204
+ # text prompt
205
+ prompt_input = gr.Textbox(
206
+ label="Enter your prompt, more descriptive prompt will lead to better results", placeholder="Type your prompt here..."
207
+ )
208
+
209
+ with gr.Row(equal_height=True):
210
+ # input images
211
+ image_input_1 = gr.Image(label="img1", type="filepath")
212
+ image_input_2 = gr.Image(label="img2", type="filepath")
213
+ image_input_3 = gr.Image(label="img3", type="filepath")
214
+
215
+ guidance_scale_input = gr.Slider(
216
+ label="Guidance Scale", minimum=1.0, maximum=5.0, value=3.5, step=0.1
217
+ )
218
+
219
+ num_inference_steps = gr.Slider(
220
+ label="Inference Steps", minimum=1, maximum=100, value=30, step=1
221
+ )
222
+
223
+ seed_input = gr.Slider(
224
+ label="Seed", minimum=0, maximum=2147483647, value=42, step=1
225
+ )
226
+
227
+ enable_cpu_offload = gr.Checkbox(
228
+ label="Enable CPU Offload", info="Enable CPU Offload to avoid memory issues", value=False,
229
+ )
230
+
231
+ with gr.Accordion("Advanced Options (True CFG, true_cfg_scale=1 means use fake CFG, >1 means use true CFG", open=False): # noqa E501
232
+ neg_prompt = gr.Textbox(
233
+ label="Negative Prompt",
234
+ value="")
235
+ true_cfg = gr.Slider(1.0, 10.0, 1.5, step=0.1, label="True CFG scale (recommended: 1.5)")
236
+ image_cfg = gr.Slider(0.0, 10.0, 0.0, step=0.1, label="Image CFG scale: improves image alignment at the cost of longer runtime and lower text alignment (recommended: 1.0)")
237
+
238
+ # generate
239
+ generate_button = gr.Button("Generate Image")
240
+
241
+
242
+ with gr.Column():
243
+ # output image
244
+ output_image = gr.Image(label="Output Image")
245
+
246
+ # click
247
+ generate_button.click(
248
+ generate_image,
249
+ inputs=[
250
+ prompt_input,
251
+ image_input_1,
252
+ image_input_2,
253
+ image_input_3,
254
+ guidance_scale_input,
255
+ num_inference_steps,
256
+ seed_input,
257
+ enable_cpu_offload,
258
+ neg_prompt,
259
+ true_cfg,
260
+ image_cfg,
261
+ ],
262
+ outputs=output_image,
263
+ )
264
+
265
+ gr.Examples(
266
+ examples=get_example(),
267
+ fn=run_for_examples,
268
+ inputs=[
269
+ prompt_input,
270
+ image_input_1,
271
+ image_input_2,
272
+ image_input_3,
273
+ guidance_scale_input,
274
+ seed_input,
275
+ enable_cpu_offload,
276
+ neg_prompt,
277
+ true_cfg,
278
+ image_cfg,
279
+ ],
280
+ outputs=output_image,
281
+ )
282
+
283
+ gr.Markdown(article)
284
+
285
+ # launch
286
+ demo.launch(ssr_mode=False)
287
+
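Since the note above recommends running the demo locally, here is a minimal, hedged sketch of calling `generate_image` directly with the rc_car example assets shipped in this commit. It assumes the model-loading code at the top of `app.py` has already been executed in the same session (for example in a notebook, or in a copy of the script before `demo.launch()`); the output filename is illustrative only.

```python
# Hypothetical direct call to generate_image(), reusing the bundled rc_car example images;
# assumes the pipeline/LoRA setup from the top of app.py has already run in this session.
img = generate_image(
    text="A toy on a beach. Waves in the background. Realistic shot.",
    img1="./imgs/test_cases/rc_car/02.jpg",
    img2="./imgs/test_cases/rc_car/03.jpg",
    img3="./imgs/test_cases/rc_car/04.jpg",
    guidance_scale=3.5,
    inference_steps=30,
    seed=42,
)
img.save("rc_car_on_beach.jpg")  # generate_image returns a PIL.Image
```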
imgs/test_cases/action_figure/0.jpg ADDED
imgs/test_cases/action_figure/1.jpg ADDED
imgs/test_cases/action_figure/2.jpg ADDED
imgs/test_cases/penguin/0.jpg ADDED
imgs/test_cases/penguin/1.jpg ADDED
imgs/test_cases/penguin/2.jpg ADDED
imgs/test_cases/rc_car/02.jpg ADDED

Git LFS Details

  • SHA256: 5306db6234529b9d6d41357724e901791e8f20024d9acee1582b250c599e035b
  • Pointer size: 131 Bytes
  • Size of remote file: 238 kB
imgs/test_cases/rc_car/03.jpg ADDED

Git LFS Details

  • SHA256: ed3f563a4ec793e34be5ce3810d8c8da3e42c5d4f2fa3e3f4159f6b05ab4307a
  • Pointer size: 131 Bytes
  • Size of remote file: 312 kB
imgs/test_cases/rc_car/04.jpg ADDED

Git LFS Details

  • SHA256: 46376e08104f686383f04c15e2a5d8341c02afefd2e544bd043d13d7cbc26bc8
  • Pointer size: 131 Bytes
  • Size of remote file: 407 kB
models/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0de7be527b2bf604f679a8c4a0545af4a371e6559aff8bfa28f2a47510872da9
3
+ size 134
pipelines/flux_pipeline/pipeline.py ADDED
@@ -0,0 +1,470 @@
1
+ # Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ from typing import Any, Callable, Dict, List, Optional, Union
17
+
18
+ import numpy as np
19
+ import torch
20
+ from diffusers import FluxPipeline
21
+ from diffusers.models.autoencoders import AutoencoderKL
22
+ from diffusers.models.transformers import FluxTransformer2DModel
23
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
24
+ from diffusers.utils import is_torch_xla_available
25
+ from transformers import (
26
+ CLIPImageProcessor,
27
+ CLIPTextModel,
28
+ CLIPTokenizer,
29
+ CLIPVisionModelWithProjection,
30
+ T5EncoderModel,
31
+ T5TokenizerFast,
32
+ )
33
+
34
+ if is_torch_xla_available():
35
+ import torch_xla.core.xla_model as xm
36
+
37
+ XLA_AVAILABLE = True
38
+ else:
39
+ XLA_AVAILABLE = False
40
+
41
+
42
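+ # Linearly interpolate the flow-matching timestep shift (mu) between base_shift and max_shift as a function of the image token sequence length.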
+ def calculate_shift(
43
+ image_seq_len,
44
+ base_seq_len: int = 256,
45
+ max_seq_len: int = 4096,
46
+ base_shift: float = 0.5,
47
+ max_shift: float = 1.16,
48
+ ):
49
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
50
+ b = base_shift - m * base_seq_len
51
+ mu = image_seq_len * m + b
52
+ return mu
53
+
54
+
55
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
56
+ def retrieve_timesteps(
57
+ scheduler,
58
+ num_inference_steps: Optional[int] = None,
59
+ device: Optional[Union[str, torch.device]] = None,
60
+ timesteps: Optional[List[int]] = None,
61
+ sigmas: Optional[List[float]] = None,
62
+ **kwargs,):
63
+ if timesteps is not None and sigmas is not None:
64
+ raise ValueError(
65
+ "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values"
66
+ )
67
+ if timesteps is not None:
68
+ accepts_timesteps = "timesteps" in set(
69
+ inspect.signature(scheduler.set_timesteps).parameters.keys()
70
+ )
71
+ if not accepts_timesteps:
72
+ raise ValueError(
73
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
74
+ f" timestep schedules. Please check whether you are using the correct scheduler."
75
+ )
76
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
77
+ timesteps = scheduler.timesteps
78
+ num_inference_steps = len(timesteps)
79
+ elif sigmas is not None:
80
+ accept_sigmas = "sigmas" in set(
81
+ inspect.signature(scheduler.set_timesteps).parameters.keys()
82
+ )
83
+ if not accept_sigmas:
84
+ raise ValueError(
85
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
86
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
87
+ )
88
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
89
+ timesteps = scheduler.timesteps
90
+ num_inference_steps = len(timesteps)
91
+ else:
92
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
93
+ timesteps = scheduler.timesteps
94
+ return timesteps, num_inference_steps
95
+
96
+
97
+ def normalized_guidance_image(neg_noise_pred, noise_pred, image_noise_pred, true_cfg_scale, image_cfg_scale):
98
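+ # Rescale the image and text guidance directions so neither exceeds the smaller of their two norms, then add both to the image-conditioned prediction.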
+ diff_img = image_noise_pred - neg_noise_pred
99
+ diff_txt = noise_pred - image_noise_pred
100
+
101
+ diff_norm_txt = diff_txt.norm(p=2, dim=[-1, -2], keepdim=True)
102
+ diff_norm_img = diff_img.norm(p=2, dim=[-1, -2], keepdim=True)
103
+ min_norm = torch.minimum(diff_norm_img, diff_norm_txt)
104
+ diff_txt = diff_txt * torch.minimum(torch.ones_like(diff_txt), min_norm / diff_norm_txt)
105
+ diff_img = diff_img * torch.minimum(torch.ones_like(diff_txt), min_norm / diff_norm_img)
106
+ pred_guided = image_noise_pred + image_cfg_scale * diff_img + true_cfg_scale * diff_txt
107
+ return pred_guided
108
+
109
+
110
+ class SynCDFluxPipeline(FluxPipeline):
111
+
112
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
113
+ _optional_components = []
114
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
115
+
116
+ def __init__(
117
+ self,
118
+ scheduler: FlowMatchEulerDiscreteScheduler,
119
+ vae: AutoencoderKL,
120
+ text_encoder: CLIPTextModel,
121
+ tokenizer: CLIPTokenizer,
122
+ text_encoder_2: T5EncoderModel,
123
+ tokenizer_2: T5TokenizerFast,
124
+ transformer: FluxTransformer2DModel,
125
+ image_encoder: CLIPVisionModelWithProjection = None,
126
+ feature_extractor: CLIPImageProcessor = None,
127
+ ###
128
+ num=2,
129
+ ):
130
+ super().__init__(
131
+ vae=vae,
132
+ text_encoder=text_encoder,
133
+ text_encoder_2=text_encoder_2,
134
+ tokenizer=tokenizer,
135
+ tokenizer_2=tokenizer_2,
136
+ transformer=transformer,
137
+ scheduler=scheduler,
138
+ image_encoder=image_encoder,
139
+ feature_extractor=feature_extractor
140
+ )
141
+ self.default_sample_size = 64
142
+ self.num = num
143
+
144
+ @torch.no_grad()
145
+ def __call__(
146
+ self,
147
+ prompt: Union[str, List[str]] = None,
148
+ prompt_2: Optional[Union[str, List[str]]] = None,
149
+ negative_prompt: Union[str, List[str]] = None,
150
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
151
+ true_cfg_scale: float = 1.0,
152
+ height: Optional[int] = None,
153
+ width: Optional[int] = None,
154
+ num_inference_steps: int = 28,
155
+ sigmas: Optional[List[float]] = None,
156
+ guidance_scale: float = 3.5,
157
+ num_images_per_prompt: Optional[int] = 1,
158
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
159
+ latents: Optional[torch.FloatTensor] = None,
160
+ prompt_embeds: Optional[torch.FloatTensor] = None,
161
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
162
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
163
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
164
+ output_type: Optional[str] = "pil",
165
+ return_dict: bool = True,
166
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
167
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
168
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
169
+ max_sequence_length: int = 512,
170
+ #####
171
+ latents_ref: Optional[torch.Tensor] = None,
172
+ latents_mask: Optional[torch.Tensor] = None,
173
+ return_latents: bool = False,
174
+ image_cfg_scale: float = 0.0,
175
+ ):
176
+ r"""
177
+ Function invoked when calling the pipeline for generation.
178
+
179
+ Args:
180
+ prompt (`str` or `List[str]`, *optional*):
181
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
182
+ instead.
183
+ prompt_2 (`str` or `List[str]`, *optional*):
184
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
185
+ will be used instead.
186
+ negative_prompt (`str` or `List[str]`, *optional*):
187
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
188
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
189
+ not greater than `1`).
190
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
191
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
192
+ `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
193
+ true_cfg_scale (`float`, *optional*, defaults to 1.0):
194
+ When > 1.0 and a provided `negative_prompt`, enables true classifier-free guidance.
195
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
196
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
197
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
198
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
199
+ num_inference_steps (`int`, *optional*, defaults to 50):
200
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
201
+ expense of slower inference.
202
+ sigmas (`List[float]`, *optional*):
203
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
204
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
205
+ will be used.
206
+ guidance_scale (`float`, *optional*, defaults to 3.5):
207
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
208
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
209
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
210
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
211
+ usually at the expense of lower image quality.
212
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
213
+ The number of images to generate per prompt.
214
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
215
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
216
+ to make generation deterministic.
217
+ latents (`torch.FloatTensor`, *optional*):
218
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
219
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
220
+ tensor will be generated by sampling using the supplied random `generator`.
221
+ prompt_embeds (`torch.FloatTensor`, *optional*):
222
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
223
+ provided, text embeddings will be generated from `prompt` input argument.
224
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
225
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
226
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
227
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
228
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
229
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
230
+ argument.
231
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
232
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
233
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
234
+ input argument.
235
+ output_type (`str`, *optional*, defaults to `"pil"`):
236
+ The output format of the generated image. Choose between
237
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
238
+ return_dict (`bool`, *optional*, defaults to `True`):
239
+ Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
240
+ joint_attention_kwargs (`dict`, *optional*):
241
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
242
+ `self.processor` in
243
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
244
+ callback_on_step_end (`Callable`, *optional*):
245
+ A function that calls at the end of each denoising steps during the inference. The function is called
246
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
247
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
248
+ `callback_on_step_end_tensor_inputs`.
249
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
250
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
251
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
252
+ `._callback_tensor_inputs` attribute of your pipeline class.
253
+ max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
254
+
255
+ Examples:
256
+
257
+ Returns:
258
+ [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
259
+ is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
260
+ images.
261
+ """
262
+
263
+ height = height or self.default_sample_size * self.vae_scale_factor
264
+ width = width or self.default_sample_size * self.vae_scale_factor
265
+
266
+ # 1. Check inputs. Raise error if not correct
267
+ self.check_inputs(
268
+ prompt,
269
+ prompt_2,
270
+ height,
271
+ width,
272
+ negative_prompt=negative_prompt,
273
+ negative_prompt_2=negative_prompt_2,
274
+ prompt_embeds=prompt_embeds,
275
+ negative_prompt_embeds=negative_prompt_embeds,
276
+ pooled_prompt_embeds=pooled_prompt_embeds,
277
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
278
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
279
+ max_sequence_length=max_sequence_length,
280
+ )
281
+
282
+ self._guidance_scale = guidance_scale
283
+ self._joint_attention_kwargs = joint_attention_kwargs
284
+ self._current_timestep = None
285
+ self._interrupt = False
286
+
287
+ # 2. Define call parameters
288
+ if prompt is not None and isinstance(prompt, str):
289
+ batch_size = 1
290
+ elif prompt is not None and isinstance(prompt, list):
291
+ batch_size = len(prompt)
292
+ else:
293
+ batch_size = prompt_embeds.shape[0]
294
+
295
+ device = self._execution_device
296
+
297
+ lora_scale = (
298
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
299
+ )
300
+ has_neg_prompt = negative_prompt is not None or (
301
+ negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None
302
+ )
303
+ do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
304
+ (
305
+ prompt_embeds,
306
+ pooled_prompt_embeds,
307
+ text_ids,
308
+ ) = self.encode_prompt(
309
+ prompt=prompt,
310
+ prompt_2=prompt_2,
311
+ prompt_embeds=prompt_embeds,
312
+ pooled_prompt_embeds=pooled_prompt_embeds,
313
+ device=device,
314
+ num_images_per_prompt=num_images_per_prompt,
315
+ max_sequence_length=max_sequence_length,
316
+ lora_scale=lora_scale,
317
+ )
318
+ if do_true_cfg:
319
+ (
320
+ negative_prompt_embeds,
321
+ negative_pooled_prompt_embeds,
322
+ _,
323
+ ) = self.encode_prompt(
324
+ prompt=negative_prompt,
325
+ prompt_2=negative_prompt_2,
326
+ prompt_embeds=negative_prompt_embeds,
327
+ pooled_prompt_embeds=negative_pooled_prompt_embeds,
328
+ device=device,
329
+ num_images_per_prompt=num_images_per_prompt,
330
+ max_sequence_length=max_sequence_length,
331
+ lora_scale=lora_scale,
332
+ )
333
+
334
+ # 4. Prepare latent variables
335
+ num_channels_latents = self.transformer.config.in_channels // 4
336
+ latents, latent_image_ids = self.prepare_latents(
337
+ batch_size * num_images_per_prompt,
338
+ num_channels_latents,
339
+ height,
340
+ width,
341
+ prompt_embeds.dtype,
342
+ device,
343
+ generator,
344
+ latents,
345
+ )
346
+
347
+ # 5. Prepare timesteps
348
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
349
+ image_seq_len = latents.shape[1]
350
+ mu = calculate_shift(
351
+ image_seq_len,
352
+ self.scheduler.config.get("base_image_seq_len", 256),
353
+ self.scheduler.config.get("max_image_seq_len", 4096),
354
+ self.scheduler.config.get("base_shift", 0.5),
355
+ self.scheduler.config.get("max_shift", 1.15),
356
+ )
357
+ timesteps, num_inference_steps = retrieve_timesteps(
358
+ self.scheduler,
359
+ num_inference_steps,
360
+ device,
361
+ sigmas=sigmas,
362
+ mu=mu,
363
+ )
364
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
365
+ self._num_timesteps = len(timesteps)
366
+
367
+ # handle guidance
368
+ if self.transformer.config.guidance_embeds:
369
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
370
+ guidance = guidance.expand(latents.shape[0])
371
+ else:
372
+ guidance = None
373
+
374
+ if self.joint_attention_kwargs is None:
375
+ self._joint_attention_kwargs = {}
376
+
377
+ # 6. Denoising loop
378
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
379
+ for i, t in enumerate(timesteps):
380
+ if self.interrupt:
381
+ continue
382
+
383
+ self._current_timestep = t
384
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
385
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
386
+ self.joint_attention_kwargs.update({'timestep': t/1000})
387
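+ # With shared attention, the reference slots are reset to their clean latents at every step, so only the masked (generated) slot is actually denoised while attending to the references.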
+ if self.joint_attention_kwargs.get('shared_attn', False) and latents_ref is not None and latents_mask is not None:
388
+ latents = (1 - latents_mask) * latents_ref + latents_mask * latents
389
+
390
+ noise_pred = self.transformer(
391
+ hidden_states=latents,
392
+ timestep=timestep / 1000,
393
+ guidance=guidance,
394
+ pooled_projections=pooled_prompt_embeds,
395
+ encoder_hidden_states=prompt_embeds,
396
+ txt_ids=text_ids,
397
+ img_ids=latent_image_ids,
398
+ joint_attention_kwargs=self.joint_attention_kwargs,
399
+ return_dict=False,
400
+ )[0]
401
+
402
+ if do_true_cfg and i>=1:
403
+ neg_noise_pred = self.transformer(
404
+ hidden_states=latents,
405
+ timestep=timestep / 1000,
406
+ guidance=guidance,
407
+ pooled_projections=negative_pooled_prompt_embeds,
408
+ encoder_hidden_states=negative_prompt_embeds,
409
+ txt_ids=text_ids,
410
+ img_ids=latent_image_ids,
411
+ joint_attention_kwargs={**self.joint_attention_kwargs, 'neg_mode': True},
412
+ return_dict=False,
413
+ )[0]
414
+
415
+ if image_cfg_scale > 0:
416
+ image_noise_pred = self.transformer(
417
+ hidden_states=latents,
418
+ timestep=timestep / 1000,
419
+ guidance=guidance,
420
+ pooled_projections=negative_pooled_prompt_embeds,
421
+ encoder_hidden_states=negative_prompt_embeds,
422
+ txt_ids=text_ids,
423
+ img_ids=latent_image_ids,
424
+ joint_attention_kwargs=self.joint_attention_kwargs,
425
+ return_dict=False,
426
+ )[0]
427
+
428
+ if image_cfg_scale == 0:
429
+ noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
430
+ else:
431
+ noise_pred = normalized_guidance_image(neg_noise_pred, noise_pred, image_noise_pred, true_cfg_scale, image_cfg_scale)
432
+
433
+ # compute the previous noisy sample x_t -> x_t-1
434
+ latents_dtype = latents.dtype
435
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
436
+
437
+ if latents.dtype != latents_dtype:
438
+ if torch.backends.mps.is_available():
439
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
440
+ latents = latents.to(latents_dtype)
441
+
442
+ if callback_on_step_end is not None:
443
+ callback_kwargs = {}
444
+ for k in callback_on_step_end_tensor_inputs:
445
+ callback_kwargs[k] = locals()[k]
446
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
447
+
448
+ latents = callback_outputs.pop("latents", latents)
449
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
450
+
451
+ # call the callback, if provided
452
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
453
+ progress_bar.update()
454
+
455
+ if XLA_AVAILABLE:
456
+ xm.mark_step()
457
+
458
+ self._current_timestep = None
459
+
460
+ if output_type == "latent":
461
+ image = latents
462
+ else:
463
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
464
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
465
+ image = self.vae.decode(latents, return_dict=False)
466
+
467
+ # Offload all models
468
+ self.maybe_free_model_hooks()
469
+
470
+ return (image,)
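To make the custom arguments easier to see in isolation, below is a hedged sketch of a direct `SynCDFluxPipeline` call with a single reference image, mirroring how `app.py` packs `latents_ref` and `latents_mask`. Here `pipeline`, `torch_dtype`, and a 512×512 RGB PIL image `ref_img` are assumed to be set up beforehand (as in `app.py`); this is an illustration, not a verified standalone script.

```python
# Sketch: one reference view + one generated view, laid out side by side (width = 2 * 512).
import numpy as np
import torch
from einops import rearrange

numref = 2  # slot 0 is generated, slot 1 is the reference image

ref = 2. * torch.from_numpy(np.array(ref_img)).permute(2, 0, 1).unsqueeze(0).to(torch_dtype) / 255. - 1.
with torch.no_grad():
    ref = pipeline.vae.encode(ref.to(pipeline.device)).latent_dist.sample() * pipeline.vae.config.scaling_factor

latents = torch.cat([torch.zeros_like(ref), ref], dim=0)   # zero out the slot to be generated
mask = torch.zeros_like(latents)
mask[:1] = 1.                                              # denoise only slot 0
latents = rearrange(latents, "(b n) c h w -> b c h (n w)", n=numref)
mask = rearrange(mask, "(b n) c h w -> b c h (n w)", n=numref)
B, C, H, W = latents.shape

out = pipeline(
    "A toy on a beach. Realistic shot.",
    latents_ref=pipeline._pack_latents(latents, B, C, H, W),
    latents_mask=pipeline._pack_latents(mask.expand(-1, C, -1, -1), B, C, H, W),
    height=512, width=numref * 512,
    guidance_scale=3.5, num_inference_steps=30,
    joint_attention_kwargs={"shared_attn": True, "num": numref},
    return_dict=False,
)[0][0]  # decoded tensor roughly in [-1, 1]; the generated view is the first 512-pixel-wide block
```

`app.py` above shows how this tensor is split back into per-view images and converted to a PIL image.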
pipelines/flux_pipeline/transformer.py ADDED
@@ -0,0 +1,795 @@
1
+ # https://github.com/bghira/SimpleTuner/blob/d0b5f37913a80aabdb0cac893937072dfa3e6a4b/helpers/models/flux/transformer.py#L404
2
+ # Copyright 2024 Stability AI, The HuggingFace Team, The InstantX Team, and Terminus Research Group. All rights reserved.
3
+ #
4
+ # Originally licensed under the Apache License, Version 2.0 (the "License");
5
+ # Updated to "Affero GENERAL PUBLIC LICENSE Version 3, 19 November 2007" via extensive updates to attn_mask usage.
6
+
7
+ import math
8
+ from contextlib import contextmanager
9
+ from typing import Any, Dict, Optional, Tuple, Union
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
15
+ from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
16
+ from diffusers.models.attention import FeedForward
17
+ from diffusers.models.attention_processor import Attention, AttentionProcessor
18
+ from diffusers.models.embeddings import (
19
+ CombinedTimestepGuidanceTextProjEmbeddings,
20
+ CombinedTimestepTextProjEmbeddings,
21
+ FluxPosEmbed,
22
+ )
23
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
24
+ from diffusers.models.modeling_utils import ModelMixin
25
+ from diffusers.models.normalization import (
26
+ AdaLayerNormContinuous,
27
+ AdaLayerNormZero,
28
+ AdaLayerNormZeroSingle,
29
+ )
30
+ from diffusers.utils import (
31
+ USE_PEFT_BACKEND,
32
+ is_torch_version,
33
+ logging,
34
+ scale_lora_layers,
35
+ unscale_lora_layers,
36
+ )
37
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
38
+ from einops import rearrange
39
+ from peft.tuners.lora.layer import LoraLayer
40
+
41
+ # Import flex_attention for optimized attention with fixed masks
42
+ try:
43
+ from torch.nn.attention.flex_attention import flex_attention, create_block_mask
44
+ FLEX_ATTENTION_AVAILABLE = True
45
+ except ImportError:
46
+ FLEX_ATTENTION_AVAILABLE = False
47
+
48
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
49
+
50
+ flex_attention_func = None
51
+ block_mask = None
52
+
53
+ class FluxAttnProcessor2_0:
54
+ """Attention processor used typically in processing the SD3-like self-attention projections."""
55
+
56
+ def __init__(self):
57
+ if not hasattr(F, "scaled_dot_product_attention"):
58
+ raise ImportError(
59
+ "FluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
60
+ )
61
+ self.name = None
62
+
63
+ def __call__(
64
+ self,
65
+ attn: Attention,
66
+ hidden_states: torch.FloatTensor,
67
+ encoder_hidden_states: torch.FloatTensor = None,
68
+ attention_mask: Optional[torch.FloatTensor] = None,
69
+ image_rotary_emb: Optional[torch.Tensor] = None,
70
+ shared_attn: bool = False, num=2,
71
+ scale: float = 1.0,
72
+ timestep: float = 0,
73
+ neg_mode: bool = False,
74
+ ) -> torch.FloatTensor:
75
+
76
+ batch_size, _, _ = (
77
+ hidden_states.shape
78
+ if encoder_hidden_states is None
79
+ else encoder_hidden_states.shape
80
+ )
81
+ end_of_hidden_states = hidden_states.shape[1]
82
+ text_seq = 512
83
+ mask = None
84
+ query = attn.to_q(hidden_states)
85
+ key = attn.to_k(hidden_states)
86
+ value = attn.to_v(hidden_states)
87
+
88
+ inner_dim = key.shape[-1]
89
+ head_dim = inner_dim // attn.heads
90
+
91
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
92
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
93
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
94
+
95
+ if attn.norm_q is not None:
96
+ query = attn.norm_q(query)
97
+ if attn.norm_k is not None:
98
+ key = attn.norm_k(key)
99
+
100
+ # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
101
+ if encoder_hidden_states is not None:
102
+ # `context` projections.
103
+ encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
104
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
105
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
106
+
107
+ encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
108
+ encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
109
+ encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
110
+
111
+ if attn.norm_added_q is not None:
112
+ encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
113
+ if attn.norm_added_k is not None:
114
+ encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
115
+
116
+ # attention
117
+ query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
118
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
119
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
120
+
121
+ if image_rotary_emb is not None:
122
+ from diffusers.models.embeddings import apply_rotary_emb
123
+ query = apply_rotary_emb(query, image_rotary_emb).to(hidden_states.dtype)
124
+ key = apply_rotary_emb(key, image_rotary_emb).to(hidden_states.dtype)
125
+
126
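+ # Positive pass: no mask, so all views attend to one another across the side-by-side layout (shared attention). Negative-prompt pass (neg_mode): each view attends only to itself and the text tokens.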
+ if neg_mode and FLEX_ATTENTION_AVAILABLE:
127
+ # Apply flex_attention with the block mask
128
+ global block_mask
129
+ need_new_mask = block_mask is None
130
+
131
+ if need_new_mask:
132
+ res = int(math.sqrt((end_of_hidden_states-(text_seq if encoder_hidden_states is None else 0)) // num))
133
+ seq_len = query.shape[2]
134
+
135
+ def block_diagonal_mask(b, h, q_idx, kv_idx):
136
+ text_offset = 512
137
+ # Text tokens (first 512) can attend to everything
138
+ # Use tensor operations instead of if statements
139
+ is_text = (q_idx < text_offset) | (kv_idx < text_offset)
140
+
141
+ # For spatial tokens, compute which block they belong to
142
+ q_spatial = q_idx - text_offset
143
+ kv_spatial = kv_idx - text_offset
144
+
145
+ # Determine block indices
146
+ q_block = (q_spatial // res) % num
147
+ kv_block = (kv_spatial // res) % num
148
+
149
+ # Only attend within the same block
150
+ same_block = (q_block == kv_block)
151
+
152
+ # Return: text can attend to everything OR same block
153
+ return is_text | same_block
154
+
155
+ # Create block mask for efficiency
156
+ block_mask = create_block_mask(block_diagonal_mask, B=1, H=None,
157
+ Q_LEN=seq_len, KV_LEN=seq_len, device=query.device)
158
+
159
+ hidden_states = flex_attention(query, key, value, block_mask=block_mask)
160
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
161
+ elif neg_mode:
162
+ # Fallback to original implementation if flex_attention is not available
163
+ res = int(math.sqrt((end_of_hidden_states-(text_seq if encoder_hidden_states is None else 0)) // num))
164
+ hw = res*res
165
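+ # Dense boolean mask: text tokens (the first 512) may attend to everything; image tokens attend to the text and only to tokens from their own view.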
+ mask_ = torch.zeros(1, res, num*res, res, num*res).to(query.device)
166
+ for i in range(num):
167
+ mask_[:, :, i*res:(i+1)*res, :, i*res:(i+1)*res] = 1
168
+ mask_ = rearrange(mask_, "b h w h1 w1 -> b (h w) (h1 w1)")
169
+ mask = torch.ones(1, num*hw + 512, num*hw + 512, device=query.device, dtype=query.dtype)
170
+ mask[:, 512:, 512:] = mask_
171
+ mask = mask.bool()
172
+ mask = rearrange(mask.unsqueeze(0).expand(attn.heads, -1, -1, -1), "nh b ... -> b nh ...")
173
+ hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False, attn_mask=mask)
174
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
175
+ else:
176
+ # No masking needed
177
+ hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False, attn_mask=None)
178
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
179
+
180
+ hidden_states = hidden_states.to(query.dtype)
181
+
182
+ if encoder_hidden_states is not None:
183
+ encoder_hidden_states, hidden_states = (
184
+ hidden_states[:, : encoder_hidden_states.shape[1]],
185
+ hidden_states[:, encoder_hidden_states.shape[1]:],
186
+ )
187
+ hidden_states = hidden_states[:, :end_of_hidden_states]
188
+
189
+ # linear proj
190
+ hidden_states = attn.to_out[0](hidden_states)
191
+ # dropout
192
+ hidden_states = attn.to_out[1](hidden_states)
193
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
194
+ return hidden_states, encoder_hidden_states
195
+ else:
196
+ return hidden_states[:, :end_of_hidden_states]
197
+
198
+
199
+ def expand_flux_attention_mask(
200
+ hidden_states: torch.Tensor,
201
+ attn_mask: torch.Tensor,
202
+ ) -> torch.Tensor:
203
+ """
204
+ Expand a mask so that the image is included.
205
+ """
206
+ bsz = attn_mask.shape[0]
207
+ assert bsz == hidden_states.shape[0]
208
+ residual_seq_len = hidden_states.shape[1]
209
+ mask_seq_len = attn_mask.shape[1]
210
+
211
+ expanded_mask = torch.ones(bsz, residual_seq_len)
212
+ expanded_mask[:, :mask_seq_len] = attn_mask
213
+
214
+ return expanded_mask
215
+
216
+
217
+ @maybe_allow_in_graph
218
+ class FluxSingleTransformerBlock(nn.Module):
219
+ def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0):
220
+ super().__init__()
221
+ self.mlp_hidden_dim = int(dim * mlp_ratio)
222
+
223
+ self.norm = AdaLayerNormZeroSingle(dim)
224
+ self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
225
+ self.act_mlp = nn.GELU(approximate="tanh")
226
+ self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
227
+
228
+ processor = FluxAttnProcessor2_0()
229
+ # processor = FluxSingleAttnProcessor3_0()
230
+
231
+ self.attn = Attention(
232
+ query_dim=dim,
233
+ cross_attention_dim=None,
234
+ dim_head=attention_head_dim,
235
+ heads=num_attention_heads,
236
+ out_dim=dim,
237
+ bias=True,
238
+ processor=processor,
239
+ qk_norm="rms_norm",
240
+ eps=1e-6,
241
+ pre_only=True,
242
+ )
243
+
244
+ def forward(
245
+ self,
246
+ hidden_states: torch.FloatTensor,
247
+ temb: torch.FloatTensor,
248
+ image_rotary_emb=None,
249
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
250
+ ):
251
+ dtype = hidden_states.dtype
252
+ residual = hidden_states
253
+ norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
254
+ mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
255
+
256
+ attn_output = self.attn(
257
+ hidden_states=norm_hidden_states.to(dtype),
258
+ image_rotary_emb=image_rotary_emb,
259
+ **joint_attention_kwargs,
260
+ )
261
+
262
+ hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
263
+ gate = gate.unsqueeze(1)
264
+ hidden_states = gate * self.proj_out(hidden_states)
265
+ hidden_states = residual + hidden_states
266
+
267
+ return hidden_states
268
+
269
+
270
+ @maybe_allow_in_graph
271
+ class FluxTransformerBlock(nn.Module):
272
+ def __init__(
273
+ self, dim, num_attention_heads, attention_head_dim, qk_norm="rms_norm", eps=1e-6
274
+ ):
275
+ super().__init__()
276
+
277
+ self.norm1 = AdaLayerNormZero(dim)
278
+
279
+ self.norm1_context = AdaLayerNormZero(dim)
280
+
281
+ if hasattr(F, "scaled_dot_product_attention"):
282
+ processor = FluxAttnProcessor2_0()
283
+ else:
284
+ raise ValueError(
285
+ "The current PyTorch version does not support the `scaled_dot_product_attention` function."
286
+ )
287
+ self.attn = Attention(
288
+ query_dim=dim,
289
+ cross_attention_dim=None,
290
+ added_kv_proj_dim=dim,
291
+ dim_head=attention_head_dim,
292
+ heads=num_attention_heads,
293
+ out_dim=dim,
294
+ context_pre_only=False,
295
+ bias=True,
296
+ processor=processor,
297
+ qk_norm=qk_norm,
298
+ eps=eps,
299
+ )
300
+
301
+ self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
302
+ self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
303
+
304
+ self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
305
+ self.ff_context = FeedForward(
306
+ dim=dim, dim_out=dim, activation_fn="gelu-approximate"
307
+ )
308
+
309
+ # let chunk size default to None
310
+ self._chunk_size = None
311
+ self._chunk_dim = 0
312
+
313
+ def forward(
314
+ self,
315
+ hidden_states: torch.FloatTensor,
316
+ encoder_hidden_states: torch.FloatTensor,
317
+ temb: torch.FloatTensor,
318
+ image_rotary_emb=None,
319
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None
320
+ ):
321
+ dtype = hidden_states.dtype
322
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
323
+ norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = (self.norm1_context(encoder_hidden_states, emb=temb))
324
+
325
+ # Attention.
326
+ attn_output, context_attn_output = self.attn(
327
+ hidden_states=norm_hidden_states.to(dtype),
328
+ encoder_hidden_states=norm_encoder_hidden_states.to(dtype),
329
+ image_rotary_emb=image_rotary_emb,
330
+ **joint_attention_kwargs,
331
+ )
332
+
333
+ # Process attention outputs for the `hidden_states`.
334
+ attn_output = gate_msa.unsqueeze(1) * attn_output
335
+ hidden_states = hidden_states + attn_output
336
+
337
+ norm_hidden_states = self.norm2(hidden_states)
338
+ norm_hidden_states = (norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None])
339
+
340
+ ff_output = self.ff(norm_hidden_states)
341
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
342
+
343
+ hidden_states = hidden_states + ff_output
344
+
345
+ context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
346
+ encoder_hidden_states = encoder_hidden_states + context_attn_output
347
+
348
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
349
+ norm_encoder_hidden_states = (
350
+ norm_encoder_hidden_states * (1 + c_scale_mlp[:, None])
351
+ + c_shift_mlp[:, None]
352
+ )
353
+
354
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
355
+ encoder_hidden_states = (
356
+ encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
357
+ )
358
+
359
+ return encoder_hidden_states, hidden_states
360
+
361
+
362
+ @contextmanager
363
+ def set_adapter_scale(model, alpha):
364
+ original_scaling = {}
365
+ for module in model.modules():
366
+ if isinstance(module, LoraLayer):
367
+ original_scaling[module] = module.scaling.copy()
368
+ module.scaling = {k: v * alpha for k, v in module.scaling.items()}
369
+
370
+ # check whether scaling is prohibited on model
371
+ # the original scaling dictionary should be empty
372
+ # if there were no lora layers
373
+ if not original_scaling:
374
+ raise ValueError("scaling is only supported for models with `LoraLayer`s")
375
+ try:
376
+ yield
377
+
378
+ finally:
379
+ # restore original scaling values after exiting the context
380
+ for module, scaling in original_scaling.items():
381
+ module.scaling = scaling
382
+
383
+
384
+ class FluxTransformer2DModelWithMasking(
385
+ ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin
386
+ ):
387
+ """
388
+ The Transformer model introduced in Flux.
389
+
390
+ Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
391
+
392
+ Parameters:
393
+ patch_size (`int`): Patch size to turn the input data into small patches.
394
+ in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
395
+ num_layers (`int`, *optional*, defaults to 18): The number of layers of MMDiT blocks to use.
396
+ num_single_layers (`int`, *optional*, defaults to 18): The number of layers of single DiT blocks to use.
397
+ attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
398
+ num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
399
+ joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
400
+ pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
401
+ guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings.
402
+ """
403
+
404
+ _supports_gradient_checkpointing = True
405
+
406
+ @register_to_config
407
+ def __init__(
408
+ self,
409
+ patch_size: int = 1,
410
+ in_channels: int = 64,
411
+ num_layers: int = 19,
412
+ num_single_layers: int = 38,
413
+ attention_head_dim: int = 128,
414
+ num_attention_heads: int = 24,
415
+ joint_attention_dim: int = 4096,
416
+ pooled_projection_dim: int = 768,
417
+ guidance_embeds: bool = False,
418
+ axes_dims_rope: Tuple[int] = (16, 56, 56),
419
+ ##
420
+ ):
421
+ super().__init__()
422
+ self.out_channels = in_channels
423
+ self.inner_dim = (
424
+ self.config.num_attention_heads * self.config.attention_head_dim
425
+ )
426
+
427
+ self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
428
+ text_time_guidance_cls = (
429
+ CombinedTimestepGuidanceTextProjEmbeddings
430
+ if guidance_embeds
431
+ else CombinedTimestepTextProjEmbeddings
432
+ )
433
+ self.time_text_embed = text_time_guidance_cls(
434
+ embedding_dim=self.inner_dim,
435
+ pooled_projection_dim=self.config.pooled_projection_dim,
436
+ )
437
+
438
+ self.context_embedder = nn.Linear(
439
+ self.config.joint_attention_dim, self.inner_dim
440
+ )
441
+ self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim)
442
+
443
+ self.transformer_blocks = nn.ModuleList(
444
+ [
445
+ FluxTransformerBlock(
446
+ dim=self.inner_dim,
447
+ num_attention_heads=self.config.num_attention_heads,
448
+ attention_head_dim=self.config.attention_head_dim,
449
+ )
450
+ for i in range(self.config.num_layers)
451
+ ]
452
+ )
453
+
454
+ self.single_transformer_blocks = nn.ModuleList(
455
+ [
456
+ FluxSingleTransformerBlock(
457
+ dim=self.inner_dim,
458
+ num_attention_heads=self.config.num_attention_heads,
459
+ attention_head_dim=self.config.attention_head_dim,
460
+ )
461
+ for i in range(self.config.num_single_layers)
462
+ ]
463
+ )
464
+
465
+ self.norm_out = AdaLayerNormContinuous(
466
+ self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6
467
+ )
468
+ self.proj_out = nn.Linear(
469
+ self.inner_dim, patch_size * patch_size * self.out_channels, bias=True
470
+ )
471
+
472
+ self.gradient_checkpointing = False
473
+
474
+ @property
475
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
476
+ r"""
477
+ Returns:
478
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
479
+ indexed by its weight name.
480
+ """
481
+ # set recursively
482
+ processors = {}
483
+
484
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
485
+ if hasattr(module, "get_processor"):
486
+ processors[f"{name}.processor"] = module.get_processor()
487
+
488
+ for sub_name, child in module.named_children():
489
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
490
+
491
+ return processors
492
+
493
+ for name, module in self.named_children():
494
+ fn_recursive_add_processors(name, module, processors)
495
+
496
+ return processors
497
+
498
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
499
+ r"""
500
+ Sets the attention processor to use to compute attention.
501
+
502
+ Parameters:
503
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
504
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
505
+ for **all** `Attention` layers.
506
+
507
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
508
+ processor. This is strongly recommended when setting trainable attention processors.
509
+
510
+ """
511
+ count = len(self.attn_processors.keys())
512
+
513
+ if isinstance(processor, dict) and len(processor) != count:
514
+ raise ValueError(
515
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
516
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
517
+ )
518
+
519
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
520
+ if hasattr(module, "set_processor"):
521
+ if not isinstance(processor, dict):
522
+ module.set_processor(processor)
523
+ else:
524
+ module.set_processor(processor.pop(f"{name}.processor"))
525
+
526
+ for sub_name, child in module.named_children():
527
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
528
+
529
+ for name, module in self.named_children():
530
+ fn_recursive_attn_processor(name, module, processor)

    def _set_gradient_checkpointing(self, module, value=False):
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value
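
    # In diffusers, `ModelMixin.enable_gradient_checkpointing()` typically applies the hook
    # above to every submodule, flipping `gradient_checkpointing` on each block that defines it.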

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor = None,
        pooled_projections: torch.Tensor = None,
        timestep: torch.LongTensor = None,
        img_ids: torch.Tensor = None,
        txt_ids: torch.Tensor = None,
        guidance: torch.Tensor = None,
        # `attention_mask` is assumed here so that the masking smoke test in __main__ can
        # call forward() with it; it is handed to the attention processors via
        # `joint_attention_kwargs` below.
        attention_mask: Optional[torch.Tensor] = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = True,
    ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
        """
        The [`FluxTransformer2DModelWithMasking`] forward method.

        Args:
            hidden_states (`torch.FloatTensor` of shape `(batch_size, image_sequence_length, in_channels)`):
                Packed (patchified) input latents.
            encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_len, embed_dims)`):
                Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
            pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`):
                Embeddings projected from the embeddings of input conditions.
            timestep (`torch.LongTensor`):
                Used to indicate the denoising step.
            attention_mask (`torch.Tensor`, *optional*):
                Token-level mask forwarded to the attention processors, e.g. to ignore padded text tokens.
            joint_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that, if specified, is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
                tuple.

        Returns:
            If `return_dict` is True, a [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
            `tuple` where the first element is the sample tensor.
        """
        if joint_attention_kwargs is not None:
            joint_attention_kwargs = joint_attention_kwargs.copy()
            lora_scale = joint_attention_kwargs.pop("scale", 1.0)
        else:
            lora_scale = 1.0

        if USE_PEFT_BACKEND:
            # weight the lora layers by setting `lora_scale` for each PEFT layer
            scale_lora_layers(self, lora_scale)
        else:
            if (
                joint_attention_kwargs is not None
                and joint_attention_kwargs.get("scale", None) is not None
            ):
                logger.warning(
                    "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
                )
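
        # Assumed wiring (not part of stock diffusers): expose `attention_mask` to the
        # masking attention processors through `joint_attention_kwargs`, since the blocks
        # below only receive that dict.
        if attention_mask is not None:
            joint_attention_kwargs = dict(joint_attention_kwargs or {})
            joint_attention_kwargs["attention_mask"] = attention_mask
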
        hidden_states = self.x_embedder(hidden_states)

        timestep = timestep.to(hidden_states.dtype) * 1000
        if guidance is not None:
            guidance = guidance.to(hidden_states.dtype) * 1000
        else:
            guidance = None
        temb = (
            self.time_text_embed(timestep, pooled_projections)
            if guidance is None
            else self.time_text_embed(timestep, guidance, pooled_projections)
        )
        encoder_hidden_states = self.context_embedder(encoder_hidden_states)

        if txt_ids.ndim == 3:
            txt_ids = txt_ids[0]
        if img_ids.ndim == 3:
            img_ids = img_ids[0]

        ids = torch.cat((txt_ids, img_ids), dim=0).to(hidden_states.dtype)

        image_rotary_emb = self.pos_embed(ids)

        for index_block, block in enumerate(self.transformer_blocks):
            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            return module(*inputs)

                    return custom_forward

                ckpt_kwargs: Dict[str, Any] = (
                    {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                )
                encoder_hidden_states, hidden_states = (
                    torch.utils.checkpoint.checkpoint(
                        create_custom_forward(block),
                        hidden_states,
                        encoder_hidden_states,
                        temb,
                        image_rotary_emb,
                        joint_attention_kwargs,
                        **ckpt_kwargs,
                    )
                )

            else:
                encoder_hidden_states, hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    temb=temb,
                    image_rotary_emb=image_rotary_emb,
                    joint_attention_kwargs=joint_attention_kwargs,
                )

        # Flux places the text tokens in front of the image tokens in the
        # sequence.
        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

        for index_block, block in enumerate(self.single_transformer_blocks):
            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            return module(*inputs)

                    return custom_forward

                ckpt_kwargs: Dict[str, Any] = (
                    {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                )
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    temb,
                    image_rotary_emb,
                    joint_attention_kwargs,
                    **ckpt_kwargs,
                )

            else:
                hidden_states = block(
                    hidden_states=hidden_states,
                    temb=temb,
                    image_rotary_emb=image_rotary_emb,
                    joint_attention_kwargs=joint_attention_kwargs,
                )

        # Drop the text tokens again and keep only the image tokens.
        hidden_states = hidden_states[:, encoder_hidden_states.shape[1]:, ...]

        hidden_states = self.norm_out(hidden_states, temb)
        output = self.proj_out(hidden_states)

        if USE_PEFT_BACKEND:
            # remove `lora_scale` from each PEFT layer
            unscale_lora_layers(self, lora_scale)

        if not return_dict:
            return (output,)

        return Transformer2DModelOutput(sample=output)


if __name__ == "__main__":
    dtype = torch.bfloat16
    bsz = 2
    img = torch.rand((bsz, 16, 64, 64)).to("cuda", dtype=dtype)
    timestep = torch.tensor([0.5, 0.5]).to("cuda", dtype=torch.float32)
    pooled = torch.rand(bsz, 768).to("cuda", dtype=dtype)
    text = torch.rand((bsz, 512, 4096)).to("cuda", dtype=dtype)
    attn_mask = torch.tensor([[1.0] * 384 + [0.0] * 128] * bsz).to(
        "cuda", dtype=dtype
    )  # Last 128 positions are masked

    def _pack_latents(latents, batch_size, num_channels_latents, height, width):
        latents = latents.view(
            batch_size, num_channels_latents, height // 2, 2, width // 2, 2
        )
        latents = latents.permute(0, 2, 4, 1, 3, 5)
        latents = latents.reshape(
            batch_size, (height // 2) * (width // 2), num_channels_latents * 4
        )

        return latents

    def _prepare_latent_image_ids(
        batch_size, height, width, device="cuda", dtype=dtype
    ):
        latent_image_ids = torch.zeros(height // 2, width // 2, 3)
        latent_image_ids[..., 1] = (
            latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
        )
        latent_image_ids[..., 2] = (
            latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
        )

        latent_image_id_height, latent_image_id_width, latent_image_id_channels = (
            latent_image_ids.shape
        )

        latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)
        latent_image_ids = latent_image_ids.reshape(
            batch_size,
            latent_image_id_height * latent_image_id_width,
            latent_image_id_channels,
        )

        return latent_image_ids.to(device=device, dtype=dtype)
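
    # Shape check for the two helpers above, using the 64x64 latents built below:
    # _pack_latents groups every 2x2 spatial patch of the (bsz, 16, 64, 64) latent into one
    # token, giving (bsz, 32 * 32, 16 * 4) = (bsz, 1024, 64); _prepare_latent_image_ids emits
    # one 3-component positional id (a zero slot plus row and column indices) per packed
    # token, i.e. a (bsz, 1024, 3) tensor.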

    txt_ids = torch.zeros(bsz, text.shape[1], 3).to(device="cuda", dtype=dtype)

    vae_scale_factor = 16
    height = 2 * (int(512) // vae_scale_factor)
    width = 2 * (int(512) // vae_scale_factor)
    img_ids = _prepare_latent_image_ids(bsz, height, width)
    img = _pack_latents(img, img.shape[0], 16, height, width)

    # Gotta go fast: a deliberately small config so the smoke test runs quickly.
    transformer = FluxTransformer2DModelWithMasking.from_config(
        {
            "attention_head_dim": 128,
            "guidance_embeds": True,
            "in_channels": 64,
            "joint_attention_dim": 4096,
            "num_attention_heads": 24,
            "num_layers": 4,
            "num_single_layers": 8,
            "patch_size": 1,
            "pooled_projection_dim": 768,
        }
    ).to("cuda", dtype=dtype)

    guidance = torch.tensor([2.0], device="cuda")
    guidance = guidance.expand(bsz)

    with torch.no_grad():
        no_mask = transformer(
            img,
            encoder_hidden_states=text,
            pooled_projections=pooled,
            timestep=timestep,
            img_ids=img_ids,
            txt_ids=txt_ids,
            guidance=guidance,
        )
        mask = transformer(
            img,
            encoder_hidden_states=text,
            pooled_projections=pooled,
            timestep=timestep,
            img_ids=img_ids,
            txt_ids=txt_ids,
            guidance=guidance,
            attention_mask=attn_mask,
        )

    # The masked and unmasked runs must differ for the mask to have had any effect.
    assert not torch.allclose(no_mask.sample, mask.sample)
    print("Attention masking test ran OK. Differences in output were detected.")
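
    # In real use the text mask would come from the tokenizer rather than being hard-coded,
    # e.g. (assuming the T5 tokenizer that produces the 512-token Flux prompt embeddings):
    #   attn_mask = tokenizer(
    #       prompt, padding="max_length", max_length=512, truncation=True, return_tensors="pt"
    #   ).attention_mask.to("cuda", dtype=dtype)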
requirements.txt ADDED
@@ -0,0 +1,9 @@
diffusers
torch
transformers
peft
einops
numpy
Pillow
sentencepiece
huggingface_hub