Spaces:
Runtime error
Runtime error
Upload 15 files
Browse files- .gitattributes +3 -0
- README.md +15 -7
- app.py +287 -0
- imgs/test_cases/action_figure/0.jpg +0 -0
- imgs/test_cases/action_figure/1.jpg +0 -0
- imgs/test_cases/action_figure/2.jpg +0 -0
- imgs/test_cases/penguin/0.jpg +0 -0
- imgs/test_cases/penguin/1.jpg +0 -0
- imgs/test_cases/penguin/2.jpg +0 -0
- imgs/test_cases/rc_car/02.jpg +3 -0
- imgs/test_cases/rc_car/03.jpg +3 -0
- imgs/test_cases/rc_car/04.jpg +3 -0
- models/pytorch_model.bin +3 -0
- pipelines/flux_pipeline/pipeline.py +470 -0
- pipelines/flux_pipeline/transformer.py +795 -0
- requirements.txt +9 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
imgs/test_cases/rc_car/02.jpg filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
imgs/test_cases/rc_car/03.jpg filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
imgs/test_cases/rc_car/04.jpg filter=lfs diff=lfs merge=lfs -text
|
README.md
CHANGED
|
@@ -1,12 +1,20 @@
|
|
| 1 |
---
|
| 2 |
-
title: SynCD
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: gradio
|
| 7 |
-
sdk_version:
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
---
|
| 11 |
-
|
| 12 |
-
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
|
| 1 |
---
|
| 2 |
+
title: SynCD
|
| 3 |
+
emoji: 🖼
|
| 4 |
+
colorFrom: purple
|
| 5 |
+
colorTo: red
|
| 6 |
sdk: gradio
|
| 7 |
+
sdk_version: 5.17.1
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
+
license: mit
|
| 11 |
+
tags:
|
| 12 |
+
- dwpose
|
| 13 |
+
- pose
|
| 14 |
+
- Text-to-Image
|
| 15 |
+
- Image-to-Image
|
| 16 |
+
- language models
|
| 17 |
+
- LLMs
|
| 18 |
+
short_description: Image generator/customization/personalization
|
| 19 |
---
|
| 20 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
app.py
ADDED
|
@@ -0,0 +1,287 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import random
|
| 3 |
+
|
| 4 |
+
import gradio as gr
|
| 5 |
+
import numpy as np
|
| 6 |
+
import spaces
|
| 7 |
+
import torch
|
| 8 |
+
from einops import rearrange
|
| 9 |
+
from huggingface_hub import login
|
| 10 |
+
from peft import LoraConfig
|
| 11 |
+
from PIL import Image
|
| 12 |
+
from pipelines.flux_pipeline.pipeline import SynCDFluxPipeline
|
| 13 |
+
from pipelines.flux_pipeline.transformer import FluxTransformer2DModelWithMasking
|
| 14 |
+
|
| 15 |
+
HF_TOKEN = os.getenv('HF_TOKEN')
|
| 16 |
+
login(token=HF_TOKEN)
|
| 17 |
+
torch_dtype = torch.bfloat16
|
| 18 |
+
transformer = FluxTransformer2DModelWithMasking.from_pretrained(
|
| 19 |
+
'black-forest-labs/FLUX.1-dev',
|
| 20 |
+
subfolder='transformer',
|
| 21 |
+
torch_dtype=torch_dtype
|
| 22 |
+
)
|
| 23 |
+
pipeline = SynCDFluxPipeline.from_pretrained('black-forest-labs/FLUX.1-dev', transformer=transformer, torch_dtype=torch_dtype)
|
| 24 |
+
for name, attn_proc in pipeline.transformer.attn_processors.items():
|
| 25 |
+
attn_proc.name = name
|
| 26 |
+
|
| 27 |
+
target_modules=[
|
| 28 |
+
"to_k",
|
| 29 |
+
"to_q",
|
| 30 |
+
"to_v",
|
| 31 |
+
"add_k_proj",
|
| 32 |
+
"add_q_proj",
|
| 33 |
+
"add_v_proj",
|
| 34 |
+
"to_out.0",
|
| 35 |
+
"to_add_out",
|
| 36 |
+
"ff.net.0.proj",
|
| 37 |
+
"ff.net.2",
|
| 38 |
+
"ff_context.net.0.proj",
|
| 39 |
+
"ff_context.net.2",
|
| 40 |
+
"proj_mlp",
|
| 41 |
+
"proj_out",
|
| 42 |
+
]
|
| 43 |
+
lora_rank = 32
|
| 44 |
+
lora_config = LoraConfig(
|
| 45 |
+
r=lora_rank,
|
| 46 |
+
lora_alpha=lora_rank,
|
| 47 |
+
init_lora_weights="gaussian",
|
| 48 |
+
target_modules=target_modules,
|
| 49 |
+
)
|
| 50 |
+
pipeline.transformer.add_adapter(lora_config)
|
| 51 |
+
finetuned_path = torch.load('models/pytorch_model.bin', map_location='cpu')
|
| 52 |
+
transformer_dict = {}
|
| 53 |
+
for key,value in finetuned_path.items():
|
| 54 |
+
if 'transformer.base_model.model.' in key:
|
| 55 |
+
transformer_dict[key.replace('transformer.base_model.model.', '')] = value
|
| 56 |
+
pipeline.transformer.load_state_dict(transformer_dict, strict=False)
|
| 57 |
+
pipeline.to('cuda')
|
| 58 |
+
pipeline.enable_vae_slicing()
|
| 59 |
+
pipeline.enable_vae_tiling()
|
| 60 |
+
|
| 61 |
+
@torch.no_grad()
|
| 62 |
+
def decode(latents, pipeline):
|
| 63 |
+
latents = latents / pipeline.vae.config.scaling_factor
|
| 64 |
+
image = pipeline.vae.decode(latents, return_dict=False)[0]
|
| 65 |
+
return image
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
@torch.no_grad()
|
| 69 |
+
def encode_target_images(images, pipeline):
|
| 70 |
+
latents = pipeline.vae.encode(images).latent_dist.sample()
|
| 71 |
+
latents = latents * pipeline.vae.config.scaling_factor
|
| 72 |
+
return latents
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
@spaces.GPU(duration=120)
|
| 76 |
+
def generate_image(text, img1, img2, img3, guidance_scale, inference_steps, seed, enable_cpu_offload=False, neg_prompt="", true_cfg=1.0, image_cfg=0.0):
|
| 77 |
+
if neg_prompt == "":
|
| 78 |
+
neg_prompt = "3d render, cartoon, low resolution, illustration, blurry, unrealistic"
|
| 79 |
+
if enable_cpu_offload:
|
| 80 |
+
pipeline.enable_sequential_cpu_offload()
|
| 81 |
+
input_images = [img1, img2, img3]
|
| 82 |
+
# Delete None
|
| 83 |
+
input_images = [img for img in input_images if img is not None]
|
| 84 |
+
if len(input_images) == 0:
|
| 85 |
+
return "Please upload at least one image"
|
| 86 |
+
numref = len(input_images) + 1
|
| 87 |
+
images = torch.cat([2. * torch.from_numpy(np.array(Image.open(img).convert('RGB').resize((512, 512)))).permute(2, 0, 1).unsqueeze(0).to(torch_dtype)/255. -1. for img in input_images])
|
| 88 |
+
images = images.to(pipeline.device)
|
| 89 |
+
latents = encode_target_images(images, pipeline)
|
| 90 |
+
latents = torch.cat([torch.zeros_like(latents[:1]), latents], dim=0)
|
| 91 |
+
masklatent = torch.zeros_like(latents)
|
| 92 |
+
masklatent[:1] = 1.
|
| 93 |
+
latents = rearrange(latents, "(b n) c h w -> b c h (n w)", n=numref)
|
| 94 |
+
masklatent = rearrange(masklatent, "(b n) c h w -> b c h (n w)", n=numref)
|
| 95 |
+
B, C, H, W = latents.shape
|
| 96 |
+
latents = pipeline._pack_latents(latents, B, C, H, W)
|
| 97 |
+
masklatent = pipeline._pack_latents(masklatent.expand(-1, C, -1, -1) ,B, C, H, W)
|
| 98 |
+
output = pipeline(
|
| 99 |
+
text,
|
| 100 |
+
latents_ref=latents,
|
| 101 |
+
latents_mask=masklatent,
|
| 102 |
+
guidance_scale=guidance_scale,
|
| 103 |
+
num_inference_steps=inference_steps,
|
| 104 |
+
height=512,
|
| 105 |
+
width=numref * 512,
|
| 106 |
+
generator = torch.Generator(device="cuda").manual_seed(seed),
|
| 107 |
+
joint_attention_kwargs={'shared_attn': True, 'num': numref},
|
| 108 |
+
return_dict=False,
|
| 109 |
+
negative_prompt=neg_prompt,
|
| 110 |
+
true_cfg_scale=true_cfg,
|
| 111 |
+
image_cfg_scale=image_cfg,
|
| 112 |
+
)[0][0]
|
| 113 |
+
output = rearrange(output, "b c h (n w) -> (b n) c h w", n=numref)[::numref]
|
| 114 |
+
img = Image.fromarray( (( torch.clip(output[0].float(), -1., 1.).permute(1,2,0).cpu().numpy()*0.5+0.5)*255).astype(np.uint8) )
|
| 115 |
+
return img
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def get_example():
|
| 120 |
+
case = [
|
| 121 |
+
[
|
| 122 |
+
"An action figure on top of a mountain. Sunset in the background. Realistic shot.",
|
| 123 |
+
"./imgs/test_cases/action_figure/0.jpg",
|
| 124 |
+
"./imgs/test_cases/action_figure/1.jpg",
|
| 125 |
+
"./imgs/test_cases/action_figure/2.jpg",
|
| 126 |
+
3.5,
|
| 127 |
+
42,
|
| 128 |
+
False,
|
| 129 |
+
"",
|
| 130 |
+
1.0,
|
| 131 |
+
0.0,
|
| 132 |
+
],
|
| 133 |
+
[
|
| 134 |
+
"A penguin plushie wearing pink sunglasses is lounging on a beach. Realistic shot.",
|
| 135 |
+
"./imgs/test_cases/penguin/0.jpg",
|
| 136 |
+
"./imgs/test_cases/penguin/1.jpg",
|
| 137 |
+
"./imgs/test_cases/penguin/2.jpg",
|
| 138 |
+
3.5,
|
| 139 |
+
42,
|
| 140 |
+
False,
|
| 141 |
+
"",
|
| 142 |
+
1.0,
|
| 143 |
+
0.0,
|
| 144 |
+
],
|
| 145 |
+
[
|
| 146 |
+
"A toy on a beach. Waves in the background. Realistic shot.",
|
| 147 |
+
"./imgs/test_cases/rc_car/02.jpg",
|
| 148 |
+
"./imgs/test_cases/rc_car/03.jpg",
|
| 149 |
+
"./imgs/test_cases/rc_car/04.jpg",
|
| 150 |
+
3.5,
|
| 151 |
+
42,
|
| 152 |
+
False,
|
| 153 |
+
"",
|
| 154 |
+
1.0,
|
| 155 |
+
0.0,
|
| 156 |
+
],
|
| 157 |
+
]
|
| 158 |
+
return case
|
| 159 |
+
|
| 160 |
+
def run_for_examples(text, img1, img2, img3, guidance_scale, seed, enable_cpu_offload=False, neg_prompt="", true_cfg=1.0, image_cfg=0.0):
|
| 161 |
+
inference_steps = 30
|
| 162 |
+
|
| 163 |
+
return generate_image(
|
| 164 |
+
text, img1, img2, img3, guidance_scale, inference_steps, seed, enable_cpu_offload, neg_prompt, true_cfg, image_cfg
|
| 165 |
+
)
|
| 166 |
+
|
| 167 |
+
description = """
|
| 168 |
+
Synthetic Customization Dataset (SynCD) consists of multiple images of the same object in different contexts. We achieve it by promoting similar object identity using either explicit 3D object assets or, more implicitly, using masked shared attention across different views while generating images. Given this training data, we train a new encoder-based model for the task, which can successfully generate new compositions of a reference object using text prompts. You can download our dataset [here](https://huggingface.co/datasets/nupurkmr9/syncd).
|
| 169 |
+
|
| 170 |
+
Our model supports multiple input images of the same object as references. You can upload up to 3 images, with better results on 3 images vs 1 image.
|
| 171 |
+
|
| 172 |
+
**HF Spaces often encounter errors due to quota limitations, so recommend to run it locally.**
|
| 173 |
+
"""
|
| 174 |
+
|
| 175 |
+
article = """
|
| 176 |
+
---
|
| 177 |
+
**Citation**
|
| 178 |
+
<br>
|
| 179 |
+
If you find this repository useful, please consider giving a star ⭐ and a citation
|
| 180 |
+
```
|
| 181 |
+
@article{kumari2025syncd,
|
| 182 |
+
title={Generating Multi-Image Synthetic Data for Text-to-Image Customization},
|
| 183 |
+
author={Kumari, Nupur and Yin, Xi and Zhu, Jun-Yan and Misra, Ishan and Azadi, Samaneh},
|
| 184 |
+
journal={ArXiv},
|
| 185 |
+
year={2025}
|
| 186 |
+
}
|
| 187 |
+
```
|
| 188 |
+
**Contact**
|
| 189 |
+
<br>
|
| 190 |
+
If you have any questions, please feel free to open an issue or directly reach us out via email.
|
| 191 |
+
|
| 192 |
+
**Acknowledgement**
|
| 193 |
+
<br>
|
| 194 |
+
This space was modified from [OmniGen](https://huggingface.co/spaces/Shitao/OmniGen) space.
|
| 195 |
+
"""
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
# Gradio
|
| 199 |
+
with gr.Blocks() as demo:
|
| 200 |
+
gr.Markdown("# SynCD: Generating Multi-Image Synthetic Data for Text-to-Image Customization [[paper](https://arxiv.org/abs/2502.01720)] [[code](https://github.com/nupurkmr9/syncd)]")
|
| 201 |
+
gr.Markdown(description)
|
| 202 |
+
with gr.Row():
|
| 203 |
+
with gr.Column():
|
| 204 |
+
# text prompt
|
| 205 |
+
prompt_input = gr.Textbox(
|
| 206 |
+
label="Enter your prompt, more descriptive prompt will lead to better results", placeholder="Type your prompt here..."
|
| 207 |
+
)
|
| 208 |
+
|
| 209 |
+
with gr.Row(equal_height=True):
|
| 210 |
+
# input images
|
| 211 |
+
image_input_1 = gr.Image(label="img1", type="filepath")
|
| 212 |
+
image_input_2 = gr.Image(label="img2", type="filepath")
|
| 213 |
+
image_input_3 = gr.Image(label="img3", type="filepath")
|
| 214 |
+
|
| 215 |
+
guidance_scale_input = gr.Slider(
|
| 216 |
+
label="Guidance Scale", minimum=1.0, maximum=5.0, value=3.5, step=0.1
|
| 217 |
+
)
|
| 218 |
+
|
| 219 |
+
num_inference_steps = gr.Slider(
|
| 220 |
+
label="Inference Steps", minimum=1, maximum=100, value=30, step=1
|
| 221 |
+
)
|
| 222 |
+
|
| 223 |
+
seed_input = gr.Slider(
|
| 224 |
+
label="Seed", minimum=0, maximum=2147483647, value=42, step=1
|
| 225 |
+
)
|
| 226 |
+
|
| 227 |
+
enable_cpu_offload = gr.Checkbox(
|
| 228 |
+
label="Enable CPU Offload", info="Enable CPU Offload to avoid memory issues", value=False,
|
| 229 |
+
)
|
| 230 |
+
|
| 231 |
+
with gr.Accordion("Advanced Options (True CFG, true_cfg_scale=1 means use fake CFG, >1 means use true CFG", open=False): # noqa E501
|
| 232 |
+
neg_prompt = gr.Textbox(
|
| 233 |
+
label="Negative Prompt",
|
| 234 |
+
value="")
|
| 235 |
+
true_cfg = gr.Slider(1.0, 10.0, 1.5, step=0.1, label="true CFG. Recommended to be 1.5")
|
| 236 |
+
image_cfg = gr.Slider(0.0, 10.0, 0.0, step=0.1, label="image CFG scale, will increase the image alignment but longer run time and lower text alignment. Recommended to be 1.0")
|
| 237 |
+
|
| 238 |
+
# generate
|
| 239 |
+
generate_button = gr.Button("Generate Image")
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
with gr.Column():
|
| 243 |
+
# output image
|
| 244 |
+
output_image = gr.Image(label="Output Image")
|
| 245 |
+
|
| 246 |
+
# click
|
| 247 |
+
generate_button.click(
|
| 248 |
+
generate_image,
|
| 249 |
+
inputs=[
|
| 250 |
+
prompt_input,
|
| 251 |
+
image_input_1,
|
| 252 |
+
image_input_2,
|
| 253 |
+
image_input_3,
|
| 254 |
+
guidance_scale_input,
|
| 255 |
+
num_inference_steps,
|
| 256 |
+
seed_input,
|
| 257 |
+
enable_cpu_offload,
|
| 258 |
+
neg_prompt,
|
| 259 |
+
true_cfg,
|
| 260 |
+
image_cfg,
|
| 261 |
+
],
|
| 262 |
+
outputs=output_image,
|
| 263 |
+
)
|
| 264 |
+
|
| 265 |
+
gr.Examples(
|
| 266 |
+
examples=get_example(),
|
| 267 |
+
fn=run_for_examples,
|
| 268 |
+
inputs=[
|
| 269 |
+
prompt_input,
|
| 270 |
+
image_input_1,
|
| 271 |
+
image_input_2,
|
| 272 |
+
image_input_3,
|
| 273 |
+
guidance_scale_input,
|
| 274 |
+
seed_input,
|
| 275 |
+
enable_cpu_offload,
|
| 276 |
+
neg_prompt,
|
| 277 |
+
true_cfg,
|
| 278 |
+
image_cfg,
|
| 279 |
+
],
|
| 280 |
+
outputs=output_image,
|
| 281 |
+
)
|
| 282 |
+
|
| 283 |
+
gr.Markdown(article)
|
| 284 |
+
|
| 285 |
+
# launch
|
| 286 |
+
demo.launch(ssr_mode=False)
|
| 287 |
+
|
imgs/test_cases/action_figure/0.jpg
ADDED
|
imgs/test_cases/action_figure/1.jpg
ADDED
|
imgs/test_cases/action_figure/2.jpg
ADDED
|
imgs/test_cases/penguin/0.jpg
ADDED
|
imgs/test_cases/penguin/1.jpg
ADDED
|
imgs/test_cases/penguin/2.jpg
ADDED
|
imgs/test_cases/rc_car/02.jpg
ADDED
|
Git LFS Details
|
imgs/test_cases/rc_car/03.jpg
ADDED
|
Git LFS Details
|
imgs/test_cases/rc_car/04.jpg
ADDED
|
Git LFS Details
|
models/pytorch_model.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:0de7be527b2bf604f679a8c4a0545af4a371e6559aff8bfa28f2a47510872da9
|
| 3 |
+
size 134
|
pipelines/flux_pipeline/pipeline.py
ADDED
|
@@ -0,0 +1,470 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import inspect
|
| 16 |
+
from typing import Any, Callable, Dict, List, Optional, Union
|
| 17 |
+
|
| 18 |
+
import numpy as np
|
| 19 |
+
import torch
|
| 20 |
+
from diffusers import FluxPipeline
|
| 21 |
+
from diffusers.models.autoencoders import AutoencoderKL
|
| 22 |
+
from diffusers.models.transformers import FluxTransformer2DModel
|
| 23 |
+
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
|
| 24 |
+
from diffusers.utils import is_torch_xla_available
|
| 25 |
+
from transformers import (
|
| 26 |
+
CLIPImageProcessor,
|
| 27 |
+
CLIPTextModel,
|
| 28 |
+
CLIPTokenizer,
|
| 29 |
+
CLIPVisionModelWithProjection,
|
| 30 |
+
T5EncoderModel,
|
| 31 |
+
T5TokenizerFast,
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
if is_torch_xla_available():
|
| 35 |
+
import torch_xla.core.xla_model as xm
|
| 36 |
+
|
| 37 |
+
XLA_AVAILABLE = True
|
| 38 |
+
else:
|
| 39 |
+
XLA_AVAILABLE = False
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def calculate_shift(
|
| 43 |
+
image_seq_len,
|
| 44 |
+
base_seq_len: int = 256,
|
| 45 |
+
max_seq_len: int = 4096,
|
| 46 |
+
base_shift: float = 0.5,
|
| 47 |
+
max_shift: float = 1.16,
|
| 48 |
+
):
|
| 49 |
+
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
|
| 50 |
+
b = base_shift - m * base_seq_len
|
| 51 |
+
mu = image_seq_len * m + b
|
| 52 |
+
return mu
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
| 56 |
+
def retrieve_timesteps(
|
| 57 |
+
scheduler,
|
| 58 |
+
num_inference_steps: Optional[int] = None,
|
| 59 |
+
device: Optional[Union[str, torch.device]] = None,
|
| 60 |
+
timesteps: Optional[List[int]] = None,
|
| 61 |
+
sigmas: Optional[List[float]] = None,
|
| 62 |
+
**kwargs,):
|
| 63 |
+
if timesteps is not None and sigmas is not None:
|
| 64 |
+
raise ValueError(
|
| 65 |
+
"Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values"
|
| 66 |
+
)
|
| 67 |
+
if timesteps is not None:
|
| 68 |
+
accepts_timesteps = "timesteps" in set(
|
| 69 |
+
inspect.signature(scheduler.set_timesteps).parameters.keys()
|
| 70 |
+
)
|
| 71 |
+
if not accepts_timesteps:
|
| 72 |
+
raise ValueError(
|
| 73 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
| 74 |
+
f" timestep schedules. Please check whether you are using the correct scheduler."
|
| 75 |
+
)
|
| 76 |
+
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
| 77 |
+
timesteps = scheduler.timesteps
|
| 78 |
+
num_inference_steps = len(timesteps)
|
| 79 |
+
elif sigmas is not None:
|
| 80 |
+
accept_sigmas = "sigmas" in set(
|
| 81 |
+
inspect.signature(scheduler.set_timesteps).parameters.keys()
|
| 82 |
+
)
|
| 83 |
+
if not accept_sigmas:
|
| 84 |
+
raise ValueError(
|
| 85 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
| 86 |
+
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
| 87 |
+
)
|
| 88 |
+
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
| 89 |
+
timesteps = scheduler.timesteps
|
| 90 |
+
num_inference_steps = len(timesteps)
|
| 91 |
+
else:
|
| 92 |
+
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
| 93 |
+
timesteps = scheduler.timesteps
|
| 94 |
+
return timesteps, num_inference_steps
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def normalized_guidance_image(neg_noise_pred, noise_pred, image_noise_pred, true_cfg_scale, image_cfg_scale):
|
| 98 |
+
diff_img = image_noise_pred - neg_noise_pred
|
| 99 |
+
diff_txt = noise_pred - image_noise_pred
|
| 100 |
+
|
| 101 |
+
diff_norm_txt = diff_txt.norm(p=2, dim=[-1, -2], keepdim=True)
|
| 102 |
+
diff_norm_img = diff_img.norm(p=2, dim=[-1, -2], keepdim=True)
|
| 103 |
+
min_norm = torch.minimum(diff_norm_img, diff_norm_txt)
|
| 104 |
+
diff_txt = diff_txt * torch.minimum(torch.ones_like(diff_txt), min_norm / diff_norm_txt)
|
| 105 |
+
diff_img = diff_img * torch.minimum(torch.ones_like(diff_txt), min_norm / diff_norm_img)
|
| 106 |
+
pred_guided = image_noise_pred + image_cfg_scale * diff_img + true_cfg_scale * diff_txt
|
| 107 |
+
return pred_guided
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
class SynCDFluxPipeline(FluxPipeline):
|
| 111 |
+
|
| 112 |
+
model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
|
| 113 |
+
_optional_components = []
|
| 114 |
+
_callback_tensor_inputs = ["latents", "prompt_embeds"]
|
| 115 |
+
|
| 116 |
+
def __init__(
|
| 117 |
+
self,
|
| 118 |
+
scheduler: FlowMatchEulerDiscreteScheduler,
|
| 119 |
+
vae: AutoencoderKL,
|
| 120 |
+
text_encoder: CLIPTextModel,
|
| 121 |
+
tokenizer: CLIPTokenizer,
|
| 122 |
+
text_encoder_2: T5EncoderModel,
|
| 123 |
+
tokenizer_2: T5TokenizerFast,
|
| 124 |
+
transformer: FluxTransformer2DModel,
|
| 125 |
+
image_encoder: CLIPVisionModelWithProjection = None,
|
| 126 |
+
feature_extractor: CLIPImageProcessor = None,
|
| 127 |
+
###
|
| 128 |
+
num=2,
|
| 129 |
+
):
|
| 130 |
+
super().__init__(
|
| 131 |
+
vae=vae,
|
| 132 |
+
text_encoder=text_encoder,
|
| 133 |
+
text_encoder_2=text_encoder_2,
|
| 134 |
+
tokenizer=tokenizer,
|
| 135 |
+
tokenizer_2=tokenizer_2,
|
| 136 |
+
transformer=transformer,
|
| 137 |
+
scheduler=scheduler,
|
| 138 |
+
image_encoder=image_encoder,
|
| 139 |
+
feature_extractor=feature_extractor
|
| 140 |
+
)
|
| 141 |
+
self.default_sample_size = 64
|
| 142 |
+
self.num = num
|
| 143 |
+
|
| 144 |
+
@torch.no_grad()
|
| 145 |
+
def __call__(
|
| 146 |
+
self,
|
| 147 |
+
prompt: Union[str, List[str]] = None,
|
| 148 |
+
prompt_2: Optional[Union[str, List[str]]] = None,
|
| 149 |
+
negative_prompt: Union[str, List[str]] = None,
|
| 150 |
+
negative_prompt_2: Optional[Union[str, List[str]]] = None,
|
| 151 |
+
true_cfg_scale: float = 1.0,
|
| 152 |
+
height: Optional[int] = None,
|
| 153 |
+
width: Optional[int] = None,
|
| 154 |
+
num_inference_steps: int = 28,
|
| 155 |
+
sigmas: Optional[List[float]] = None,
|
| 156 |
+
guidance_scale: float = 3.5,
|
| 157 |
+
num_images_per_prompt: Optional[int] = 1,
|
| 158 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 159 |
+
latents: Optional[torch.FloatTensor] = None,
|
| 160 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 161 |
+
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 162 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 163 |
+
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 164 |
+
output_type: Optional[str] = "pil",
|
| 165 |
+
return_dict: bool = True,
|
| 166 |
+
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 167 |
+
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
| 168 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
| 169 |
+
max_sequence_length: int = 512,
|
| 170 |
+
#####
|
| 171 |
+
latents_ref: Optional[torch.Tensor] = None,
|
| 172 |
+
latents_mask: Optional[torch.Tensor] = None,
|
| 173 |
+
return_latents: bool = False,
|
| 174 |
+
image_cfg_scale: float = 0.0,
|
| 175 |
+
):
|
| 176 |
+
r"""
|
| 177 |
+
Function invoked when calling the pipeline for generation.
|
| 178 |
+
|
| 179 |
+
Args:
|
| 180 |
+
prompt (`str` or `List[str]`, *optional*):
|
| 181 |
+
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
| 182 |
+
instead.
|
| 183 |
+
prompt_2 (`str` or `List[str]`, *optional*):
|
| 184 |
+
The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
|
| 185 |
+
will be used instead.
|
| 186 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
| 187 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
| 188 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
|
| 189 |
+
not greater than `1`).
|
| 190 |
+
negative_prompt_2 (`str` or `List[str]`, *optional*):
|
| 191 |
+
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
|
| 192 |
+
`text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
|
| 193 |
+
true_cfg_scale (`float`, *optional*, defaults to 1.0):
|
| 194 |
+
When > 1.0 and a provided `negative_prompt`, enables true classifier-free guidance.
|
| 195 |
+
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
| 196 |
+
The height in pixels of the generated image. This is set to 1024 by default for the best results.
|
| 197 |
+
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
| 198 |
+
The width in pixels of the generated image. This is set to 1024 by default for the best results.
|
| 199 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
| 200 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
| 201 |
+
expense of slower inference.
|
| 202 |
+
sigmas (`List[float]`, *optional*):
|
| 203 |
+
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
|
| 204 |
+
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
| 205 |
+
will be used.
|
| 206 |
+
guidance_scale (`float`, *optional*, defaults to 7.0):
|
| 207 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
| 208 |
+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
| 209 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
| 210 |
+
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
| 211 |
+
usually at the expense of lower image quality.
|
| 212 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
| 213 |
+
The number of images to generate per prompt.
|
| 214 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
| 215 |
+
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
| 216 |
+
to make generation deterministic.
|
| 217 |
+
latents (`torch.FloatTensor`, *optional*):
|
| 218 |
+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
| 219 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
| 220 |
+
tensor will ge generated by sampling using the supplied random `generator`.
|
| 221 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 222 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
| 223 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
| 224 |
+
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 225 |
+
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
| 226 |
+
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
| 227 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 228 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
| 229 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
| 230 |
+
argument.
|
| 231 |
+
negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 232 |
+
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
| 233 |
+
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
|
| 234 |
+
input argument.
|
| 235 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
| 236 |
+
The output format of the generate image. Choose between
|
| 237 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
| 238 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 239 |
+
Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
|
| 240 |
+
joint_attention_kwargs (`dict`, *optional*):
|
| 241 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
| 242 |
+
`self.processor` in
|
| 243 |
+
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
| 244 |
+
callback_on_step_end (`Callable`, *optional*):
|
| 245 |
+
A function that calls at the end of each denoising steps during the inference. The function is called
|
| 246 |
+
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
| 247 |
+
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
| 248 |
+
`callback_on_step_end_tensor_inputs`.
|
| 249 |
+
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
| 250 |
+
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
| 251 |
+
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
| 252 |
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
| 253 |
+
max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
|
| 254 |
+
|
| 255 |
+
Examples:
|
| 256 |
+
|
| 257 |
+
Returns:
|
| 258 |
+
[`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
|
| 259 |
+
is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
|
| 260 |
+
images.
|
| 261 |
+
"""
|
| 262 |
+
|
| 263 |
+
height = height or self.default_sample_size * self.vae_scale_factor
|
| 264 |
+
width = width or self.default_sample_size * self.vae_scale_factor
|
| 265 |
+
|
| 266 |
+
# 1. Check inputs. Raise error if not correct
|
| 267 |
+
self.check_inputs(
|
| 268 |
+
prompt,
|
| 269 |
+
prompt_2,
|
| 270 |
+
height,
|
| 271 |
+
width,
|
| 272 |
+
negative_prompt=negative_prompt,
|
| 273 |
+
negative_prompt_2=negative_prompt_2,
|
| 274 |
+
prompt_embeds=prompt_embeds,
|
| 275 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
| 276 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
| 277 |
+
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
| 278 |
+
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
| 279 |
+
max_sequence_length=max_sequence_length,
|
| 280 |
+
)
|
| 281 |
+
|
| 282 |
+
self._guidance_scale = guidance_scale
|
| 283 |
+
self._joint_attention_kwargs = joint_attention_kwargs
|
| 284 |
+
self._current_timestep = None
|
| 285 |
+
self._interrupt = False
|
| 286 |
+
|
| 287 |
+
# 2. Define call parameters
|
| 288 |
+
if prompt is not None and isinstance(prompt, str):
|
| 289 |
+
batch_size = 1
|
| 290 |
+
elif prompt is not None and isinstance(prompt, list):
|
| 291 |
+
batch_size = len(prompt)
|
| 292 |
+
else:
|
| 293 |
+
batch_size = prompt_embeds.shape[0]
|
| 294 |
+
|
| 295 |
+
device = self._execution_device
|
| 296 |
+
|
| 297 |
+
lora_scale = (
|
| 298 |
+
self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
|
| 299 |
+
)
|
| 300 |
+
has_neg_prompt = negative_prompt is not None or (
|
| 301 |
+
negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None
|
| 302 |
+
)
|
| 303 |
+
do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
|
| 304 |
+
(
|
| 305 |
+
prompt_embeds,
|
| 306 |
+
pooled_prompt_embeds,
|
| 307 |
+
text_ids,
|
| 308 |
+
) = self.encode_prompt(
|
| 309 |
+
prompt=prompt,
|
| 310 |
+
prompt_2=prompt_2,
|
| 311 |
+
prompt_embeds=prompt_embeds,
|
| 312 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
| 313 |
+
device=device,
|
| 314 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 315 |
+
max_sequence_length=max_sequence_length,
|
| 316 |
+
lora_scale=lora_scale,
|
| 317 |
+
)
|
| 318 |
+
if do_true_cfg:
|
| 319 |
+
(
|
| 320 |
+
negative_prompt_embeds,
|
| 321 |
+
negative_pooled_prompt_embeds,
|
| 322 |
+
_,
|
| 323 |
+
) = self.encode_prompt(
|
| 324 |
+
prompt=negative_prompt,
|
| 325 |
+
prompt_2=negative_prompt_2,
|
| 326 |
+
prompt_embeds=negative_prompt_embeds,
|
| 327 |
+
pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
| 328 |
+
device=device,
|
| 329 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 330 |
+
max_sequence_length=max_sequence_length,
|
| 331 |
+
lora_scale=lora_scale,
|
| 332 |
+
)
|
| 333 |
+
|
| 334 |
+
# 4. Prepare latent variables
|
| 335 |
+
num_channels_latents = self.transformer.config.in_channels // 4
|
| 336 |
+
latents, latent_image_ids = self.prepare_latents(
|
| 337 |
+
batch_size * num_images_per_prompt,
|
| 338 |
+
num_channels_latents,
|
| 339 |
+
height,
|
| 340 |
+
width,
|
| 341 |
+
prompt_embeds.dtype,
|
| 342 |
+
device,
|
| 343 |
+
generator,
|
| 344 |
+
latents,
|
| 345 |
+
)
|
| 346 |
+
|
| 347 |
+
# 5. Prepare timesteps
|
| 348 |
+
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
|
| 349 |
+
image_seq_len = latents.shape[1]
|
| 350 |
+
mu = calculate_shift(
|
| 351 |
+
image_seq_len,
|
| 352 |
+
self.scheduler.config.get("base_image_seq_len", 256),
|
| 353 |
+
self.scheduler.config.get("max_image_seq_len", 4096),
|
| 354 |
+
self.scheduler.config.get("base_shift", 0.5),
|
| 355 |
+
self.scheduler.config.get("max_shift", 1.15),
|
| 356 |
+
)
|
| 357 |
+
timesteps, num_inference_steps = retrieve_timesteps(
|
| 358 |
+
self.scheduler,
|
| 359 |
+
num_inference_steps,
|
| 360 |
+
device,
|
| 361 |
+
sigmas=sigmas,
|
| 362 |
+
mu=mu,
|
| 363 |
+
)
|
| 364 |
+
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
| 365 |
+
self._num_timesteps = len(timesteps)
|
| 366 |
+
|
| 367 |
+
# handle guidance
|
| 368 |
+
if self.transformer.config.guidance_embeds:
|
| 369 |
+
guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
|
| 370 |
+
guidance = guidance.expand(latents.shape[0])
|
| 371 |
+
else:
|
| 372 |
+
guidance = None
|
| 373 |
+
|
| 374 |
+
if self.joint_attention_kwargs is None:
|
| 375 |
+
self._joint_attention_kwargs = {}
|
| 376 |
+
|
| 377 |
+
# 6. Denoising loop
|
| 378 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
| 379 |
+
for i, t in enumerate(timesteps):
|
| 380 |
+
if self.interrupt:
|
| 381 |
+
continue
|
| 382 |
+
|
| 383 |
+
self._current_timestep = t
|
| 384 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
| 385 |
+
timestep = t.expand(latents.shape[0]).to(latents.dtype)
|
| 386 |
+
self.joint_attention_kwargs.update({'timestep': t/1000})
|
| 387 |
+
if self.joint_attention_kwargs is not None and self.joint_attention_kwargs['shared_attn'] and latents_ref is not None and latents_mask is not None:
|
| 388 |
+
latents = (1 - latents_mask) * latents_ref + latents_mask * latents
|
| 389 |
+
|
| 390 |
+
noise_pred = self.transformer(
|
| 391 |
+
hidden_states=latents,
|
| 392 |
+
timestep=timestep / 1000,
|
| 393 |
+
guidance=guidance,
|
| 394 |
+
pooled_projections=pooled_prompt_embeds,
|
| 395 |
+
encoder_hidden_states=prompt_embeds,
|
| 396 |
+
txt_ids=text_ids,
|
| 397 |
+
img_ids=latent_image_ids,
|
| 398 |
+
joint_attention_kwargs=self.joint_attention_kwargs,
|
| 399 |
+
return_dict=False,
|
| 400 |
+
)[0]
|
| 401 |
+
|
| 402 |
+
if do_true_cfg and i>=1:
|
| 403 |
+
neg_noise_pred = self.transformer(
|
| 404 |
+
hidden_states=latents,
|
| 405 |
+
timestep=timestep / 1000,
|
| 406 |
+
guidance=guidance,
|
| 407 |
+
pooled_projections=negative_pooled_prompt_embeds,
|
| 408 |
+
encoder_hidden_states=negative_prompt_embeds,
|
| 409 |
+
txt_ids=text_ids,
|
| 410 |
+
img_ids=latent_image_ids,
|
| 411 |
+
joint_attention_kwargs={**self.joint_attention_kwargs, 'neg_mode': True},
|
| 412 |
+
return_dict=False,
|
| 413 |
+
)[0]
|
| 414 |
+
|
| 415 |
+
if image_cfg_scale > 0:
|
| 416 |
+
image_noise_pred = self.transformer(
|
| 417 |
+
hidden_states=latents,
|
| 418 |
+
timestep=timestep / 1000,
|
| 419 |
+
guidance=guidance,
|
| 420 |
+
pooled_projections=negative_pooled_prompt_embeds,
|
| 421 |
+
encoder_hidden_states=negative_prompt_embeds,
|
| 422 |
+
txt_ids=text_ids,
|
| 423 |
+
img_ids=latent_image_ids,
|
| 424 |
+
joint_attention_kwargs=self.joint_attention_kwargs,
|
| 425 |
+
return_dict=False,
|
| 426 |
+
)[0]
|
| 427 |
+
|
| 428 |
+
if image_cfg_scale == 0:
|
| 429 |
+
noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
|
| 430 |
+
else:
|
| 431 |
+
noise_pred = normalized_guidance_image(neg_noise_pred, noise_pred, image_noise_pred, true_cfg_scale, image_cfg_scale)
|
| 432 |
+
|
| 433 |
+
# compute the previous noisy sample x_t -> x_t-1
|
| 434 |
+
latents_dtype = latents.dtype
|
| 435 |
+
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
| 436 |
+
|
| 437 |
+
if latents.dtype != latents_dtype:
|
| 438 |
+
if torch.backends.mps.is_available():
|
| 439 |
+
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
| 440 |
+
latents = latents.to(latents_dtype)
|
| 441 |
+
|
| 442 |
+
if callback_on_step_end is not None:
|
| 443 |
+
callback_kwargs = {}
|
| 444 |
+
for k in callback_on_step_end_tensor_inputs:
|
| 445 |
+
callback_kwargs[k] = locals()[k]
|
| 446 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
| 447 |
+
|
| 448 |
+
latents = callback_outputs.pop("latents", latents)
|
| 449 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
| 450 |
+
|
| 451 |
+
# call the callback, if provided
|
| 452 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
| 453 |
+
progress_bar.update()
|
| 454 |
+
|
| 455 |
+
if XLA_AVAILABLE:
|
| 456 |
+
xm.mark_step()
|
| 457 |
+
|
| 458 |
+
self._current_timestep = None
|
| 459 |
+
|
| 460 |
+
if output_type == "latent":
|
| 461 |
+
image = latents
|
| 462 |
+
else:
|
| 463 |
+
latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
|
| 464 |
+
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
|
| 465 |
+
image = self.vae.decode(latents, return_dict=False)
|
| 466 |
+
|
| 467 |
+
# Offload all models
|
| 468 |
+
self.maybe_free_model_hooks()
|
| 469 |
+
|
| 470 |
+
return (image,)
|
pipelines/flux_pipeline/transformer.py
ADDED
|
@@ -0,0 +1,795 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# https://github.com/bghira/SimpleTuner/blob/d0b5f37913a80aabdb0cac893937072dfa3e6a4b/helpers/models/flux/transformer.py#L404
|
| 2 |
+
# Copyright 2024 Stability AI, The HuggingFace Team, The InstantX Team, and Terminus Research Group. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Originally licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# Updated to "Affero GENERAL PUBLIC LICENSE Version 3, 19 November 2007" via extensive updates to attn_mask usage.
|
| 6 |
+
|
| 7 |
+
import math
|
| 8 |
+
from contextlib import contextmanager
|
| 9 |
+
from typing import Any, Dict, Optional, Tuple, Union
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
import torch.nn.functional as F
|
| 14 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 15 |
+
from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
|
| 16 |
+
from diffusers.models.attention import FeedForward
|
| 17 |
+
from diffusers.models.attention_processor import Attention, AttentionProcessor
|
| 18 |
+
from diffusers.models.embeddings import (
|
| 19 |
+
CombinedTimestepGuidanceTextProjEmbeddings,
|
| 20 |
+
CombinedTimestepTextProjEmbeddings,
|
| 21 |
+
FluxPosEmbed,
|
| 22 |
+
)
|
| 23 |
+
from diffusers.models.modeling_outputs import Transformer2DModelOutput
|
| 24 |
+
from diffusers.models.modeling_utils import ModelMixin
|
| 25 |
+
from diffusers.models.normalization import (
|
| 26 |
+
AdaLayerNormContinuous,
|
| 27 |
+
AdaLayerNormZero,
|
| 28 |
+
AdaLayerNormZeroSingle,
|
| 29 |
+
)
|
| 30 |
+
from diffusers.utils import (
|
| 31 |
+
USE_PEFT_BACKEND,
|
| 32 |
+
is_torch_version,
|
| 33 |
+
logging,
|
| 34 |
+
scale_lora_layers,
|
| 35 |
+
unscale_lora_layers,
|
| 36 |
+
)
|
| 37 |
+
from diffusers.utils.torch_utils import maybe_allow_in_graph
|
| 38 |
+
from einops import rearrange
|
| 39 |
+
from peft.tuners.lora.layer import LoraLayer
|
| 40 |
+
|
| 41 |
+
# Import flex_attention for optimized attention with fixed masks
|
| 42 |
+
try:
|
| 43 |
+
from torch.nn.attention.flex_attention import flex_attention, create_block_mask
|
| 44 |
+
FLEX_ATTENTION_AVAILABLE = True
|
| 45 |
+
except ImportError:
|
| 46 |
+
FLEX_ATTENTION_AVAILABLE = False
|
| 47 |
+
|
| 48 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 49 |
+
|
| 50 |
+
flex_attention_func = None
|
| 51 |
+
block_mask = None
|
| 52 |
+
|
| 53 |
+
class FluxAttnProcessor2_0:
|
| 54 |
+
"""Attention processor used typically in processing the SD3-like self-attention projections."""
|
| 55 |
+
|
| 56 |
+
def __init__(self):
|
| 57 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
| 58 |
+
raise ImportError(
|
| 59 |
+
"FluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
|
| 60 |
+
)
|
| 61 |
+
self.name = None
|
| 62 |
+
|
| 63 |
+
def __call__(
|
| 64 |
+
self,
|
| 65 |
+
attn: Attention,
|
| 66 |
+
hidden_states: torch.FloatTensor,
|
| 67 |
+
encoder_hidden_states: torch.FloatTensor = None,
|
| 68 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
| 69 |
+
image_rotary_emb: Optional[torch.Tensor] = None,
|
| 70 |
+
shared_attn: bool = False, num=2,
|
| 71 |
+
scale: float = 1.0,
|
| 72 |
+
timestep: float = 0,
|
| 73 |
+
neg_mode: bool = False,
|
| 74 |
+
) -> torch.FloatTensor:
|
| 75 |
+
|
| 76 |
+
batch_size, _, _ = (
|
| 77 |
+
hidden_states.shape
|
| 78 |
+
if encoder_hidden_states is None
|
| 79 |
+
else encoder_hidden_states.shape
|
| 80 |
+
)
|
| 81 |
+
end_of_hidden_states = hidden_states.shape[1]
|
| 82 |
+
text_seq = 512
|
| 83 |
+
mask = None
|
| 84 |
+
query = attn.to_q(hidden_states)
|
| 85 |
+
key = attn.to_k(hidden_states)
|
| 86 |
+
value = attn.to_v(hidden_states)
|
| 87 |
+
|
| 88 |
+
inner_dim = key.shape[-1]
|
| 89 |
+
head_dim = inner_dim // attn.heads
|
| 90 |
+
|
| 91 |
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
| 92 |
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
| 93 |
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
| 94 |
+
|
| 95 |
+
if attn.norm_q is not None:
|
| 96 |
+
query = attn.norm_q(query)
|
| 97 |
+
if attn.norm_k is not None:
|
| 98 |
+
key = attn.norm_k(key)
|
| 99 |
+
|
| 100 |
+
# the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
|
| 101 |
+
if encoder_hidden_states is not None:
|
| 102 |
+
# `context` projections.
|
| 103 |
+
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
|
| 104 |
+
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
|
| 105 |
+
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
|
| 106 |
+
|
| 107 |
+
encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
| 108 |
+
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
| 109 |
+
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
| 110 |
+
|
| 111 |
+
if attn.norm_added_q is not None:
|
| 112 |
+
encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
|
| 113 |
+
if attn.norm_added_k is not None:
|
| 114 |
+
encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
|
| 115 |
+
|
| 116 |
+
# attention
|
| 117 |
+
query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
|
| 118 |
+
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
|
| 119 |
+
value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
|
| 120 |
+
|
| 121 |
+
if image_rotary_emb is not None:
|
| 122 |
+
from diffusers.models.embeddings import apply_rotary_emb
|
| 123 |
+
query = apply_rotary_emb(query, image_rotary_emb).to(hidden_states.dtype)
|
| 124 |
+
key = apply_rotary_emb(key, image_rotary_emb).to(hidden_states.dtype)
|
| 125 |
+
|
| 126 |
+
if neg_mode and FLEX_ATTENTION_AVAILABLE:
|
| 127 |
+
# Apply flex_attention with the block mask
|
| 128 |
+
global block_mask
|
| 129 |
+
need_new_mask = block_mask is None
|
| 130 |
+
|
| 131 |
+
if need_new_mask:
|
| 132 |
+
res = int(math.sqrt((end_of_hidden_states-(text_seq if encoder_hidden_states is None else 0)) // num))
|
| 133 |
+
seq_len = query.shape[2]
|
| 134 |
+
|
| 135 |
+
def block_diagonal_mask(b, h, q_idx, kv_idx):
|
| 136 |
+
text_offset = 512
|
| 137 |
+
# Text tokens (first 512) can attend to everything
|
| 138 |
+
# Use tensor operations instead of if statements
|
| 139 |
+
is_text = (q_idx < text_offset) | (kv_idx < text_offset)
|
| 140 |
+
|
| 141 |
+
# For spatial tokens, compute which block they belong to
|
| 142 |
+
q_spatial = q_idx - text_offset
|
| 143 |
+
kv_spatial = kv_idx - text_offset
|
| 144 |
+
|
| 145 |
+
# Determine block indices
|
| 146 |
+
q_block = (q_spatial // res) % num
|
| 147 |
+
kv_block = (kv_spatial // res) % num
|
| 148 |
+
|
| 149 |
+
# Only attend within the same block
|
| 150 |
+
same_block = (q_block == kv_block)
|
| 151 |
+
|
| 152 |
+
# Return: text can attend to everything OR same block
|
| 153 |
+
return is_text | same_block
|
| 154 |
+
|
| 155 |
+
# Create block mask for efficiency
|
| 156 |
+
block_mask = create_block_mask(block_diagonal_mask, B=1, H=None,
|
| 157 |
+
Q_LEN=seq_len, KV_LEN=seq_len, device=query.device)
|
| 158 |
+
|
| 159 |
+
hidden_states = flex_attention(query, key, value, block_mask=block_mask)
|
| 160 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
| 161 |
+
elif neg_mode:
|
| 162 |
+
# Fallback to original implementation if flex_attention is not available
|
| 163 |
+
res = int(math.sqrt((end_of_hidden_states-(text_seq if encoder_hidden_states is None else 0)) // num))
|
| 164 |
+
hw = res*res
|
| 165 |
+
mask_ = torch.zeros(1, res, num*res, res, num*res).to(query.device)
|
| 166 |
+
for i in range(num):
|
| 167 |
+
mask_[:, :, i*res:(i+1)*res, :, i*res:(i+1)*res] = 1
|
| 168 |
+
mask_ = rearrange(mask_, "b h w h1 w1 -> b (h w) (h1 w1)")
|
| 169 |
+
mask = torch.ones(1, num*hw + 512, num*hw + 512, device=query.device, dtype=query.dtype)
|
| 170 |
+
mask[:, 512:, 512:] = mask_
|
| 171 |
+
mask = mask.bool()
|
| 172 |
+
mask = rearrange(mask.unsqueeze(0).expand(attn.heads, -1, -1, -1), "nh b ... -> b nh ...")
|
| 173 |
+
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False, attn_mask=mask)
|
| 174 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
| 175 |
+
else:
|
| 176 |
+
# No masking needed
|
| 177 |
+
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False, attn_mask=None)
|
| 178 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
| 179 |
+
|
| 180 |
+
hidden_states = hidden_states.to(query.dtype)
|
| 181 |
+
|
| 182 |
+
if encoder_hidden_states is not None:
|
| 183 |
+
encoder_hidden_states, hidden_states = (
|
| 184 |
+
hidden_states[:, : encoder_hidden_states.shape[1]],
|
| 185 |
+
hidden_states[:, encoder_hidden_states.shape[1]:],
|
| 186 |
+
)
|
| 187 |
+
hidden_states = hidden_states[:, :end_of_hidden_states]
|
| 188 |
+
|
| 189 |
+
# linear proj
|
| 190 |
+
hidden_states = attn.to_out[0](hidden_states)
|
| 191 |
+
# dropout
|
| 192 |
+
hidden_states = attn.to_out[1](hidden_states)
|
| 193 |
+
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
|
| 194 |
+
return hidden_states, encoder_hidden_states
|
| 195 |
+
else:
|
| 196 |
+
return hidden_states[:, :end_of_hidden_states]
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
def expand_flux_attention_mask(
|
| 200 |
+
hidden_states: torch.Tensor,
|
| 201 |
+
attn_mask: torch.Tensor,
|
| 202 |
+
) -> torch.Tensor:
|
| 203 |
+
"""
|
| 204 |
+
Expand a mask so that the image is included.
|
| 205 |
+
"""
|
| 206 |
+
bsz = attn_mask.shape[0]
|
| 207 |
+
assert bsz == hidden_states.shape[0]
|
| 208 |
+
residual_seq_len = hidden_states.shape[1]
|
| 209 |
+
mask_seq_len = attn_mask.shape[1]
|
| 210 |
+
|
| 211 |
+
expanded_mask = torch.ones(bsz, residual_seq_len)
|
| 212 |
+
expanded_mask[:, :mask_seq_len] = attn_mask
|
| 213 |
+
|
| 214 |
+
return expanded_mask
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
@maybe_allow_in_graph
|
| 218 |
+
class FluxSingleTransformerBlock(nn.Module):
|
| 219 |
+
def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0):
|
| 220 |
+
super().__init__()
|
| 221 |
+
self.mlp_hidden_dim = int(dim * mlp_ratio)
|
| 222 |
+
|
| 223 |
+
self.norm = AdaLayerNormZeroSingle(dim)
|
| 224 |
+
self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
|
| 225 |
+
self.act_mlp = nn.GELU(approximate="tanh")
|
| 226 |
+
self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
|
| 227 |
+
|
| 228 |
+
processor = FluxAttnProcessor2_0()
|
| 229 |
+
# processor = FluxSingleAttnProcessor3_0()
|
| 230 |
+
|
| 231 |
+
self.attn = Attention(
|
| 232 |
+
query_dim=dim,
|
| 233 |
+
cross_attention_dim=None,
|
| 234 |
+
dim_head=attention_head_dim,
|
| 235 |
+
heads=num_attention_heads,
|
| 236 |
+
out_dim=dim,
|
| 237 |
+
bias=True,
|
| 238 |
+
processor=processor,
|
| 239 |
+
qk_norm="rms_norm",
|
| 240 |
+
eps=1e-6,
|
| 241 |
+
pre_only=True,
|
| 242 |
+
)
|
| 243 |
+
|
| 244 |
+
def forward(
|
| 245 |
+
self,
|
| 246 |
+
hidden_states: torch.FloatTensor,
|
| 247 |
+
temb: torch.FloatTensor,
|
| 248 |
+
image_rotary_emb=None,
|
| 249 |
+
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 250 |
+
):
|
| 251 |
+
dtype = hidden_states.dtype
|
| 252 |
+
residual = hidden_states
|
| 253 |
+
norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
|
| 254 |
+
mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
|
| 255 |
+
|
| 256 |
+
attn_output = self.attn(
|
| 257 |
+
hidden_states=norm_hidden_states.to(dtype),
|
| 258 |
+
image_rotary_emb=image_rotary_emb,
|
| 259 |
+
**joint_attention_kwargs,
|
| 260 |
+
)
|
| 261 |
+
|
| 262 |
+
hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
|
| 263 |
+
gate = gate.unsqueeze(1)
|
| 264 |
+
hidden_states = gate * self.proj_out(hidden_states)
|
| 265 |
+
hidden_states = residual + hidden_states
|
| 266 |
+
|
| 267 |
+
return hidden_states
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
@maybe_allow_in_graph
|
| 271 |
+
class FluxTransformerBlock(nn.Module):
|
| 272 |
+
def __init__(
|
| 273 |
+
self, dim, num_attention_heads, attention_head_dim, qk_norm="rms_norm", eps=1e-6
|
| 274 |
+
):
|
| 275 |
+
super().__init__()
|
| 276 |
+
|
| 277 |
+
self.norm1 = AdaLayerNormZero(dim)
|
| 278 |
+
|
| 279 |
+
self.norm1_context = AdaLayerNormZero(dim)
|
| 280 |
+
|
| 281 |
+
if hasattr(F, "scaled_dot_product_attention"):
|
| 282 |
+
processor = FluxAttnProcessor2_0()
|
| 283 |
+
else:
|
| 284 |
+
raise ValueError(
|
| 285 |
+
"The current PyTorch version does not support the `scaled_dot_product_attention` function."
|
| 286 |
+
)
|
| 287 |
+
self.attn = Attention(
|
| 288 |
+
query_dim=dim,
|
| 289 |
+
cross_attention_dim=None,
|
| 290 |
+
added_kv_proj_dim=dim,
|
| 291 |
+
dim_head=attention_head_dim,
|
| 292 |
+
heads=num_attention_heads,
|
| 293 |
+
out_dim=dim,
|
| 294 |
+
context_pre_only=False,
|
| 295 |
+
bias=True,
|
| 296 |
+
processor=processor,
|
| 297 |
+
qk_norm=qk_norm,
|
| 298 |
+
eps=eps,
|
| 299 |
+
)
|
| 300 |
+
|
| 301 |
+
self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
| 302 |
+
self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
|
| 303 |
+
|
| 304 |
+
self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
| 305 |
+
self.ff_context = FeedForward(
|
| 306 |
+
dim=dim, dim_out=dim, activation_fn="gelu-approximate"
|
| 307 |
+
)
|
| 308 |
+
|
| 309 |
+
# let chunk size default to None
|
| 310 |
+
self._chunk_size = None
|
| 311 |
+
self._chunk_dim = 0
|
| 312 |
+
|
| 313 |
+
def forward(
|
| 314 |
+
self,
|
| 315 |
+
hidden_states: torch.FloatTensor,
|
| 316 |
+
encoder_hidden_states: torch.FloatTensor,
|
| 317 |
+
temb: torch.FloatTensor,
|
| 318 |
+
image_rotary_emb=None,
|
| 319 |
+
joint_attention_kwargs: Optional[Dict[str, Any]] = None
|
| 320 |
+
):
|
| 321 |
+
dtype = hidden_states.dtype
|
| 322 |
+
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
|
| 323 |
+
norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = (self.norm1_context(encoder_hidden_states, emb=temb))
|
| 324 |
+
|
| 325 |
+
# Attention.
|
| 326 |
+
attn_output, context_attn_output = self.attn(
|
| 327 |
+
hidden_states=norm_hidden_states.to(dtype),
|
| 328 |
+
encoder_hidden_states=norm_encoder_hidden_states.to(dtype),
|
| 329 |
+
image_rotary_emb=image_rotary_emb,
|
| 330 |
+
**joint_attention_kwargs,
|
| 331 |
+
)
|
| 332 |
+
|
| 333 |
+
# Process attention outputs for the `hidden_states`.
|
| 334 |
+
attn_output = gate_msa.unsqueeze(1) * attn_output
|
| 335 |
+
hidden_states = hidden_states + attn_output
|
| 336 |
+
|
| 337 |
+
norm_hidden_states = self.norm2(hidden_states)
|
| 338 |
+
norm_hidden_states = (norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None])
|
| 339 |
+
|
| 340 |
+
ff_output = self.ff(norm_hidden_states)
|
| 341 |
+
ff_output = gate_mlp.unsqueeze(1) * ff_output
|
| 342 |
+
|
| 343 |
+
hidden_states = hidden_states + ff_output
|
| 344 |
+
|
| 345 |
+
context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
|
| 346 |
+
encoder_hidden_states = encoder_hidden_states + context_attn_output
|
| 347 |
+
|
| 348 |
+
norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
|
| 349 |
+
norm_encoder_hidden_states = (
|
| 350 |
+
norm_encoder_hidden_states * (1 + c_scale_mlp[:, None])
|
| 351 |
+
+ c_shift_mlp[:, None]
|
| 352 |
+
)
|
| 353 |
+
|
| 354 |
+
context_ff_output = self.ff_context(norm_encoder_hidden_states)
|
| 355 |
+
encoder_hidden_states = (
|
| 356 |
+
encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
|
| 357 |
+
)
|
| 358 |
+
|
| 359 |
+
return encoder_hidden_states, hidden_states
|
| 360 |
+
|
| 361 |
+
|
| 362 |
+
@contextmanager
|
| 363 |
+
def set_adapter_scale(model, alpha):
|
| 364 |
+
original_scaling = {}
|
| 365 |
+
for module in model.modules():
|
| 366 |
+
if isinstance(module, LoraLayer):
|
| 367 |
+
original_scaling[module] = module.scaling.copy()
|
| 368 |
+
module.scaling = {k: v * alpha for k, v in module.scaling.items()}
|
| 369 |
+
|
| 370 |
+
# check whether scaling is prohibited on model
|
| 371 |
+
# the original scaling dictionary should be empty
|
| 372 |
+
# if there were no lora layers
|
| 373 |
+
if not original_scaling:
|
| 374 |
+
raise ValueError("scaling is only supported for models with `LoraLayer`s")
|
| 375 |
+
try:
|
| 376 |
+
yield
|
| 377 |
+
|
| 378 |
+
finally:
|
| 379 |
+
# restore original scaling values after exiting the context
|
| 380 |
+
for module, scaling in original_scaling.items():
|
| 381 |
+
module.scaling = scaling
|
| 382 |
+
|
| 383 |
+
|
| 384 |
+
class FluxTransformer2DModelWithMasking(
|
| 385 |
+
ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin
|
| 386 |
+
):
|
| 387 |
+
"""
|
| 388 |
+
The Transformer model introduced in Flux.
|
| 389 |
+
|
| 390 |
+
Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
|
| 391 |
+
|
| 392 |
+
Parameters:
|
| 393 |
+
patch_size (`int`): Patch size to turn the input data into small patches.
|
| 394 |
+
in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
|
| 395 |
+
num_layers (`int`, *optional*, defaults to 18): The number of layers of MMDiT blocks to use.
|
| 396 |
+
num_single_layers (`int`, *optional*, defaults to 18): The number of layers of single DiT blocks to use.
|
| 397 |
+
attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
|
| 398 |
+
num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
|
| 399 |
+
joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
|
| 400 |
+
pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
|
| 401 |
+
guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings.
|
| 402 |
+
"""
|
| 403 |
+
|
| 404 |
+
_supports_gradient_checkpointing = True
|
| 405 |
+
|
| 406 |
+
@register_to_config
|
| 407 |
+
def __init__(
|
| 408 |
+
self,
|
| 409 |
+
patch_size: int = 1,
|
| 410 |
+
in_channels: int = 64,
|
| 411 |
+
num_layers: int = 19,
|
| 412 |
+
num_single_layers: int = 38,
|
| 413 |
+
attention_head_dim: int = 128,
|
| 414 |
+
num_attention_heads: int = 24,
|
| 415 |
+
joint_attention_dim: int = 4096,
|
| 416 |
+
pooled_projection_dim: int = 768,
|
| 417 |
+
guidance_embeds: bool = False,
|
| 418 |
+
axes_dims_rope: Tuple[int] = (16, 56, 56),
|
| 419 |
+
##
|
| 420 |
+
):
|
| 421 |
+
super().__init__()
|
| 422 |
+
self.out_channels = in_channels
|
| 423 |
+
self.inner_dim = (
|
| 424 |
+
self.config.num_attention_heads * self.config.attention_head_dim
|
| 425 |
+
)
|
| 426 |
+
|
| 427 |
+
self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
|
| 428 |
+
text_time_guidance_cls = (
|
| 429 |
+
CombinedTimestepGuidanceTextProjEmbeddings
|
| 430 |
+
if guidance_embeds
|
| 431 |
+
else CombinedTimestepTextProjEmbeddings
|
| 432 |
+
)
|
| 433 |
+
self.time_text_embed = text_time_guidance_cls(
|
| 434 |
+
embedding_dim=self.inner_dim,
|
| 435 |
+
pooled_projection_dim=self.config.pooled_projection_dim,
|
| 436 |
+
)
|
| 437 |
+
|
| 438 |
+
self.context_embedder = nn.Linear(
|
| 439 |
+
self.config.joint_attention_dim, self.inner_dim
|
| 440 |
+
)
|
| 441 |
+
self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim)
|
| 442 |
+
|
| 443 |
+
self.transformer_blocks = nn.ModuleList(
|
| 444 |
+
[
|
| 445 |
+
FluxTransformerBlock(
|
| 446 |
+
dim=self.inner_dim,
|
| 447 |
+
num_attention_heads=self.config.num_attention_heads,
|
| 448 |
+
attention_head_dim=self.config.attention_head_dim,
|
| 449 |
+
)
|
| 450 |
+
for i in range(self.config.num_layers)
|
| 451 |
+
]
|
| 452 |
+
)
|
| 453 |
+
|
| 454 |
+
self.single_transformer_blocks = nn.ModuleList(
|
| 455 |
+
[
|
| 456 |
+
FluxSingleTransformerBlock(
|
| 457 |
+
dim=self.inner_dim,
|
| 458 |
+
num_attention_heads=self.config.num_attention_heads,
|
| 459 |
+
attention_head_dim=self.config.attention_head_dim,
|
| 460 |
+
)
|
| 461 |
+
for i in range(self.config.num_single_layers)
|
| 462 |
+
]
|
| 463 |
+
)
|
| 464 |
+
|
| 465 |
+
self.norm_out = AdaLayerNormContinuous(
|
| 466 |
+
self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6
|
| 467 |
+
)
|
| 468 |
+
self.proj_out = nn.Linear(
|
| 469 |
+
self.inner_dim, patch_size * patch_size * self.out_channels, bias=True
|
| 470 |
+
)
|
| 471 |
+
|
| 472 |
+
self.gradient_checkpointing = False
|
| 473 |
+
|
| 474 |
+
@property
|
| 475 |
+
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
| 476 |
+
r"""
|
| 477 |
+
Returns:
|
| 478 |
+
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
| 479 |
+
indexed by its weight name.
|
| 480 |
+
"""
|
| 481 |
+
# set recursively
|
| 482 |
+
processors = {}
|
| 483 |
+
|
| 484 |
+
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
| 485 |
+
if hasattr(module, "get_processor"):
|
| 486 |
+
processors[f"{name}.processor"] = module.get_processor()
|
| 487 |
+
|
| 488 |
+
for sub_name, child in module.named_children():
|
| 489 |
+
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
| 490 |
+
|
| 491 |
+
return processors
|
| 492 |
+
|
| 493 |
+
for name, module in self.named_children():
|
| 494 |
+
fn_recursive_add_processors(name, module, processors)
|
| 495 |
+
|
| 496 |
+
return processors
|
| 497 |
+
|
| 498 |
+
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
| 499 |
+
r"""
|
| 500 |
+
Sets the attention processor to use to compute attention.
|
| 501 |
+
|
| 502 |
+
Parameters:
|
| 503 |
+
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
| 504 |
+
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
| 505 |
+
for **all** `Attention` layers.
|
| 506 |
+
|
| 507 |
+
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
| 508 |
+
processor. This is strongly recommended when setting trainable attention processors.
|
| 509 |
+
|
| 510 |
+
"""
|
| 511 |
+
count = len(self.attn_processors.keys())
|
| 512 |
+
|
| 513 |
+
if isinstance(processor, dict) and len(processor) != count:
|
| 514 |
+
raise ValueError(
|
| 515 |
+
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
| 516 |
+
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
| 517 |
+
)
|
| 518 |
+
|
| 519 |
+
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
| 520 |
+
if hasattr(module, "set_processor"):
|
| 521 |
+
if not isinstance(processor, dict):
|
| 522 |
+
module.set_processor(processor)
|
| 523 |
+
else:
|
| 524 |
+
module.set_processor(processor.pop(f"{name}.processor"))
|
| 525 |
+
|
| 526 |
+
for sub_name, child in module.named_children():
|
| 527 |
+
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
| 528 |
+
|
| 529 |
+
for name, module in self.named_children():
|
| 530 |
+
fn_recursive_attn_processor(name, module, processor)
|
| 531 |
+
|
| 532 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
| 533 |
+
if hasattr(module, "gradient_checkpointing"):
|
| 534 |
+
module.gradient_checkpointing = value
|
| 535 |
+
|
| 536 |
+
def forward(
|
| 537 |
+
self,
|
| 538 |
+
hidden_states: torch.Tensor,
|
| 539 |
+
encoder_hidden_states: torch.Tensor = None,
|
| 540 |
+
pooled_projections: torch.Tensor = None,
|
| 541 |
+
timestep: torch.LongTensor = None,
|
| 542 |
+
img_ids: torch.Tensor = None,
|
| 543 |
+
txt_ids: torch.Tensor = None,
|
| 544 |
+
guidance: torch.Tensor = None,
|
| 545 |
+
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 546 |
+
return_dict: bool = True,
|
| 547 |
+
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
|
| 548 |
+
"""
|
| 549 |
+
The [`FluxTransformer2DModelWithMasking`] forward method.
|
| 550 |
+
|
| 551 |
+
Args:
|
| 552 |
+
hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
|
| 553 |
+
Input `hidden_states`.
|
| 554 |
+
encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
|
| 555 |
+
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
|
| 556 |
+
pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
|
| 557 |
+
from the embeddings of input conditions.
|
| 558 |
+
timestep ( `torch.LongTensor`):
|
| 559 |
+
Used to indicate denoising step.
|
| 560 |
+
block_controlnet_hidden_states: (`list` of `torch.Tensor`):
|
| 561 |
+
A list of tensors that if specified are added to the residuals of transformer blocks.
|
| 562 |
+
joint_attention_kwargs (`dict`, *optional*):
|
| 563 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
| 564 |
+
`self.processor` in
|
| 565 |
+
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
| 566 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 567 |
+
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
|
| 568 |
+
tuple.
|
| 569 |
+
|
| 570 |
+
Returns:
|
| 571 |
+
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
|
| 572 |
+
`tuple` where the first element is the sample tensor.
|
| 573 |
+
"""
|
| 574 |
+
if joint_attention_kwargs is not None:
|
| 575 |
+
joint_attention_kwargs = joint_attention_kwargs.copy()
|
| 576 |
+
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
|
| 577 |
+
else:
|
| 578 |
+
lora_scale = 1.0
|
| 579 |
+
|
| 580 |
+
if USE_PEFT_BACKEND:
|
| 581 |
+
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
| 582 |
+
scale_lora_layers(self, lora_scale)
|
| 583 |
+
else:
|
| 584 |
+
if (
|
| 585 |
+
joint_attention_kwargs is not None
|
| 586 |
+
and joint_attention_kwargs.get("scale", None) is not None
|
| 587 |
+
):
|
| 588 |
+
logger.warning(
|
| 589 |
+
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
|
| 590 |
+
)
|
| 591 |
+
hidden_states = self.x_embedder(hidden_states)
|
| 592 |
+
|
| 593 |
+
timestep = timestep.to(hidden_states.dtype) * 1000
|
| 594 |
+
if guidance is not None:
|
| 595 |
+
guidance = guidance.to(hidden_states.dtype) * 1000
|
| 596 |
+
else:
|
| 597 |
+
guidance = None
|
| 598 |
+
temb = (
|
| 599 |
+
self.time_text_embed(timestep, pooled_projections)
|
| 600 |
+
if guidance is None
|
| 601 |
+
else self.time_text_embed(timestep, guidance, pooled_projections)
|
| 602 |
+
)
|
| 603 |
+
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
|
| 604 |
+
|
| 605 |
+
if txt_ids.ndim == 3:
|
| 606 |
+
txt_ids = txt_ids[0]
|
| 607 |
+
if img_ids.ndim == 3:
|
| 608 |
+
img_ids = img_ids[0]
|
| 609 |
+
|
| 610 |
+
ids = torch.cat((txt_ids, img_ids), dim=0).to(hidden_states.dtype)
|
| 611 |
+
|
| 612 |
+
image_rotary_emb = self.pos_embed(ids)
|
| 613 |
+
|
| 614 |
+
for index_block, block in enumerate(self.transformer_blocks):
|
| 615 |
+
if self.training and self.gradient_checkpointing:
|
| 616 |
+
|
| 617 |
+
def create_custom_forward(module, return_dict=None):
|
| 618 |
+
def custom_forward(*inputs):
|
| 619 |
+
if return_dict is not None:
|
| 620 |
+
return module(*inputs, return_dict=return_dict)
|
| 621 |
+
else:
|
| 622 |
+
return module(*inputs)
|
| 623 |
+
|
| 624 |
+
return custom_forward
|
| 625 |
+
|
| 626 |
+
ckpt_kwargs: Dict[str, Any] = (
|
| 627 |
+
{"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
| 628 |
+
)
|
| 629 |
+
encoder_hidden_states, hidden_states = (
|
| 630 |
+
torch.utils.checkpoint.checkpoint(
|
| 631 |
+
create_custom_forward(block),
|
| 632 |
+
hidden_states,
|
| 633 |
+
encoder_hidden_states,
|
| 634 |
+
temb,
|
| 635 |
+
image_rotary_emb,
|
| 636 |
+
joint_attention_kwargs,
|
| 637 |
+
**ckpt_kwargs,
|
| 638 |
+
)
|
| 639 |
+
)
|
| 640 |
+
|
| 641 |
+
else:
|
| 642 |
+
encoder_hidden_states, hidden_states = block(
|
| 643 |
+
hidden_states=hidden_states,
|
| 644 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 645 |
+
temb=temb,
|
| 646 |
+
image_rotary_emb=image_rotary_emb,
|
| 647 |
+
joint_attention_kwargs=joint_attention_kwargs,
|
| 648 |
+
)
|
| 649 |
+
|
| 650 |
+
# Flux places the text tokens in front of the image tokens in the
|
| 651 |
+
# sequence.
|
| 652 |
+
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
| 653 |
+
|
| 654 |
+
for index_block, block in enumerate(self.single_transformer_blocks):
|
| 655 |
+
if self.training and self.gradient_checkpointing:
|
| 656 |
+
|
| 657 |
+
def create_custom_forward(module, return_dict=None):
|
| 658 |
+
def custom_forward(*inputs):
|
| 659 |
+
if return_dict is not None:
|
| 660 |
+
return module(*inputs, return_dict=return_dict)
|
| 661 |
+
else:
|
| 662 |
+
return module(*inputs)
|
| 663 |
+
|
| 664 |
+
return custom_forward
|
| 665 |
+
|
| 666 |
+
ckpt_kwargs: Dict[str, Any] = (
|
| 667 |
+
{"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
| 668 |
+
)
|
| 669 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
| 670 |
+
create_custom_forward(block),
|
| 671 |
+
hidden_states,
|
| 672 |
+
temb,
|
| 673 |
+
image_rotary_emb,
|
| 674 |
+
joint_attention_kwargs,
|
| 675 |
+
**ckpt_kwargs,
|
| 676 |
+
)
|
| 677 |
+
|
| 678 |
+
else:
|
| 679 |
+
hidden_states = block(
|
| 680 |
+
hidden_states=hidden_states,
|
| 681 |
+
temb=temb,
|
| 682 |
+
image_rotary_emb=image_rotary_emb,
|
| 683 |
+
joint_attention_kwargs=joint_attention_kwargs,
|
| 684 |
+
)
|
| 685 |
+
|
| 686 |
+
hidden_states = hidden_states[:, encoder_hidden_states.shape[1]:, ...]
|
| 687 |
+
|
| 688 |
+
hidden_states = self.norm_out(hidden_states, temb)
|
| 689 |
+
output = self.proj_out(hidden_states)
|
| 690 |
+
|
| 691 |
+
if USE_PEFT_BACKEND:
|
| 692 |
+
# remove `lora_scale` from each PEFT layer
|
| 693 |
+
unscale_lora_layers(self, lora_scale)
|
| 694 |
+
|
| 695 |
+
if not return_dict:
|
| 696 |
+
return (output,)
|
| 697 |
+
|
| 698 |
+
return Transformer2DModelOutput(sample=output)
|
| 699 |
+
|
| 700 |
+
|
| 701 |
+
if __name__ == "__main__":
|
| 702 |
+
dtype = torch.bfloat16
|
| 703 |
+
bsz = 2
|
| 704 |
+
img = torch.rand((bsz, 16, 64, 64)).to("cuda", dtype=dtype)
|
| 705 |
+
timestep = torch.tensor([0.5, 0.5]).to("cuda", dtype=torch.float32)
|
| 706 |
+
pooled = torch.rand(bsz, 768).to("cuda", dtype=dtype)
|
| 707 |
+
text = torch.rand((bsz, 512, 4096)).to("cuda", dtype=dtype)
|
| 708 |
+
attn_mask = torch.tensor([[1.0] * 384 + [0.0] * 128] * bsz).to(
|
| 709 |
+
"cuda", dtype=dtype
|
| 710 |
+
) # Last 128 positions are masked
|
| 711 |
+
|
| 712 |
+
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
|
| 713 |
+
latents = latents.view(
|
| 714 |
+
batch_size, num_channels_latents, height // 2, 2, width // 2, 2
|
| 715 |
+
)
|
| 716 |
+
latents = latents.permute(0, 2, 4, 1, 3, 5)
|
| 717 |
+
latents = latents.reshape(
|
| 718 |
+
batch_size, (height // 2) * (width // 2), num_channels_latents * 4
|
| 719 |
+
)
|
| 720 |
+
|
| 721 |
+
return latents
|
| 722 |
+
|
| 723 |
+
def _prepare_latent_image_ids(
|
| 724 |
+
batch_size, height, width, device="cuda", dtype=dtype
|
| 725 |
+
):
|
| 726 |
+
latent_image_ids = torch.zeros(height // 2, width // 2, 3)
|
| 727 |
+
latent_image_ids[..., 1] = (
|
| 728 |
+
latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
|
| 729 |
+
)
|
| 730 |
+
latent_image_ids[..., 2] = (
|
| 731 |
+
latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
|
| 732 |
+
)
|
| 733 |
+
|
| 734 |
+
latent_image_id_height, latent_image_id_width, latent_image_id_channels = (
|
| 735 |
+
latent_image_ids.shape
|
| 736 |
+
)
|
| 737 |
+
|
| 738 |
+
latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)
|
| 739 |
+
latent_image_ids = latent_image_ids.reshape(
|
| 740 |
+
batch_size,
|
| 741 |
+
latent_image_id_height * latent_image_id_width,
|
| 742 |
+
latent_image_id_channels,
|
| 743 |
+
)
|
| 744 |
+
|
| 745 |
+
return latent_image_ids.to(device=device, dtype=dtype)
|
| 746 |
+
|
| 747 |
+
txt_ids = torch.zeros(bsz, text.shape[1], 3).to(device="cuda", dtype=dtype)
|
| 748 |
+
|
| 749 |
+
vae_scale_factor = 16
|
| 750 |
+
height = 2 * (int(512) // vae_scale_factor)
|
| 751 |
+
width = 2 * (int(512) // vae_scale_factor)
|
| 752 |
+
img_ids = _prepare_latent_image_ids(bsz, height, width)
|
| 753 |
+
img = _pack_latents(img, img.shape[0], 16, height, width)
|
| 754 |
+
|
| 755 |
+
# Gotta go fast
|
| 756 |
+
transformer = FluxTransformer2DModelWithMasking.from_config(
|
| 757 |
+
{
|
| 758 |
+
"attention_head_dim": 128,
|
| 759 |
+
"guidance_embeds": True,
|
| 760 |
+
"in_channels": 64,
|
| 761 |
+
"joint_attention_dim": 4096,
|
| 762 |
+
"num_attention_heads": 24,
|
| 763 |
+
"num_layers": 4,
|
| 764 |
+
"num_single_layers": 8,
|
| 765 |
+
"patch_size": 1,
|
| 766 |
+
"pooled_projection_dim": 768,
|
| 767 |
+
}
|
| 768 |
+
).to("cuda", dtype=dtype)
|
| 769 |
+
|
| 770 |
+
guidance = torch.tensor([2.0], device="cuda")
|
| 771 |
+
guidance = guidance.expand(bsz)
|
| 772 |
+
|
| 773 |
+
with torch.no_grad():
|
| 774 |
+
no_mask = transformer(
|
| 775 |
+
img,
|
| 776 |
+
encoder_hidden_states=text,
|
| 777 |
+
pooled_projections=pooled,
|
| 778 |
+
timestep=timestep,
|
| 779 |
+
img_ids=img_ids,
|
| 780 |
+
txt_ids=txt_ids,
|
| 781 |
+
guidance=guidance,
|
| 782 |
+
)
|
| 783 |
+
mask = transformer(
|
| 784 |
+
img,
|
| 785 |
+
encoder_hidden_states=text,
|
| 786 |
+
pooled_projections=pooled,
|
| 787 |
+
timestep=timestep,
|
| 788 |
+
img_ids=img_ids,
|
| 789 |
+
txt_ids=txt_ids,
|
| 790 |
+
guidance=guidance,
|
| 791 |
+
attention_mask=attn_mask,
|
| 792 |
+
)
|
| 793 |
+
|
| 794 |
+
assert torch.allclose(no_mask.sample, mask.sample) is False
|
| 795 |
+
print("Attention masking test ran OK. Differences in output were detected.")
|
requirements.txt
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
diffusers
|
| 2 |
+
torch
|
| 3 |
+
transformers
|
| 4 |
+
peft
|
| 5 |
+
einops
|
| 6 |
+
numpy
|
| 7 |
+
Pillow
|
| 8 |
+
sentencepiece
|
| 9 |
+
huggingface_hub
|