Commit 8c46fdc
Jitesh Dhamaniya committed
Parent(s): f0357cb

Init

Files changed:
- README.md +98 -0
- config.json +20 -0
- handler.py +57 -0
- main.py +59 -0
- transformer_flux.py +525 -0
README.md
ADDED
@@ -0,0 +1,98 @@
---
license: other
license_name: flux-1-dev-non-commercial-license
license_link: https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md
language:
- en
base_model: black-forest-labs/FLUX.1-dev
library_name: diffusers
tags:
- Text-to-Image
- ControlNet
- Inpainting
- FLUX
- Stable Diffusion
---
<div style="display: flex; justify-content: center; align-items: center;">
    <img src="images/alibaba.png" alt="alibaba" style="width: 20%; height: auto; margin-right: 5%;">
    <img src="images/alimama.png" alt="alimama" style="width: 20%; height: auto;">
</div>

This repository provides an Inpainting ControlNet checkpoint for the [FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev) model, released by the AlimamaCreative Team.

## Beta Version Now Available

We are excited to announce the release of our beta version, which brings further enhancements to our inpainting capabilities.

To access and test the beta version, please visit [alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta](https://huggingface.co/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta). We welcome your feedback and suggestions as we continue to refine and improve the model.

## News

🎉 Thanks to @comfyanonymous, ComfyUI now supports inference for the Alimama inpainting ControlNet. The workflow can be downloaded from [here](https://huggingface.co/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha/resolve/main/images/alimama-flux-controlnet-inpaint.json).

ComfyUI usage tips:

* Using the `t5xxl-FP16` and `flux1-dev-fp8` models for 28-step inference, GPU memory usage is 27 GB. Inference takes 27 seconds with `cfg=3.5` and 15 seconds without CFG (`cfg=1`). `Hyper-FLUX-lora` can be used to accelerate inference further.
* You can try lowering the `control-strength`, `control-end-percent`, and `cfg` parameters to achieve better results.
* The following examples use `control-strength` = 0.9, `control-end-percent` = 1.0, and `cfg` = 3.5.

| Input | Output | Prompt |
|------------------------------|------------------------------|-------------|
|  |  | <small><i>The image depicts a scene from the anime series Dragon Ball Z, with the characters Goku, <span style="color:red; font-weight:bold;">Elon Musk</span>, and a child version of Gohan sharing a meal of ramen noodles. They are all sitting around a dining table, with Goku and Gohan on one side and Naruto on the other. They are all holding chopsticks and eating the noodles. The table is set with bowls of ramen, cups, and bowls of drinks. The arrangement of the characters and the food creates a sense of camaraderie and shared enjoyment of the meal.</i></small> |
|  |  | <small><i>The image is an illustration of a man standing in a cafe. He is wearing a white turtleneck, a camel-colored trench coat, and brown shoes. He is holding a cell phone and appears to be looking at it. There is a small table with <span style="color:red; font-weight:bold;">a cat</span> on it to his right. In the background, there is another man sitting at a table with a laptop. The man is wearing a black turtleneck and a tie.</i></small> |
|  |  | <small><i>A woman with blonde hair is sitting on a table wearing a <span style="color:red; font-weight:bold;">red and white long dress</span>. She is holding a green phone in her hand and appears to be taking a photo. There is a bag next to her on the table and a handbag beside her on the chair. The woman is looking at the phone with a smile on her face. The background includes a TV on the left wall and a couch on the right. A chair is also present in the scene.</i></small> |
|  |  | <small><i>The image depicts a beautiful young woman sitting at a desk, reading a book. She has long, wavy brown hair and is wearing a grey shirt with a black cardigan. She is holding a <span style="color:red; font-weight:bold;">red pencil</span> in her left hand and appears to be deep in thought. Surrounding her are numerous books, some stacked on the desk and others placed on a shelf behind her. A potted plant is also visible in the background, adding a touch of greenery to the scene. The image conveys a sense of serenity and intellectual pursuits.</i></small> |

## Model Cards

* The model was trained on 12M images from LAION-2B and internal sources at a resolution of 768x768. Inference performs best at this resolution; other sizes yield suboptimal results.

* The recommended `controlnet_conditioning_scale` is 0.9 - 0.95.

* **Please note: this is only an alpha version released during the training process. We will publish an updated version when we feel ready.**

## Showcase





## Comparison with SDXL-Inpainting

Compared with [SDXL-Inpainting](https://huggingface.co/diffusers/stable-diffusion-xl-1.0-inpainting-0.1).

From left to right: Input image | Masked image | SDXL inpainting | Ours


<small><i>The image depicts a beautiful young woman sitting at a desk, reading a book. She has long, wavy brown hair and is wearing a grey shirt with a black cardigan. She is holding a pencil in her left hand and appears to be deep in thought. Surrounding her are numerous books, some stacked on the desk and others placed on a shelf behind her. A potted plant is also visible in the background, adding a touch of greenery to the scene. The image conveys a sense of serenity and intellectual pursuits.</i></small>


<small><i>A woman with blonde hair is sitting on a table wearing a blue and white long dress. She is holding a green phone in her hand and appears to be taking a photo. There is a bag next to her on the table and a handbag beside her on the chair. The woman is looking at the phone with a smile on her face. The background includes a TV on the left wall and a couch on the right. A chair is also present in the scene.</i></small>


<small><i>The image is an illustration of a man standing in a cafe. He is wearing a white turtleneck, a camel-colored trench coat, and brown shoes. He is holding a cell phone and appears to be looking at it. There is a small table with a cup of coffee on it to his right. In the background, there is another man sitting at a table with a laptop. The man is wearing a black turtleneck and a tie. There are several cups and a cake on the table in the background. The man sitting at the table appears to be typing on the laptop.</i></small>


<small><i>The image depicts a scene from the anime series Dragon Ball Z, with the characters Goku, Naruto, and a child version of Gohan sharing a meal of ramen noodles. They are all sitting around a dining table, with Goku and Gohan on one side and Naruto on the other. They are all holding chopsticks and eating the noodles. The table is set with bowls of ramen, cups, and bowls of drinks. The arrangement of the characters and the food creates a sense of camaraderie and shared enjoyment of the meal.</i></small>

## Using with Diffusers

Step 1: install diffusers
```shell
pip install diffusers==0.30.2
```

Step 2: clone the inference repo from GitHub
```shell
git clone https://github.com/alimama-creative/FLUX-Controlnet-Inpainting.git
```

Step 3: modify the `image_path`, `mask_path`, and `prompt` in `main.py`, then run
```shell
python main.py
```

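For reference, the snippet below is a condensed sketch of what `main.py` does (the `controlnet_flux`, `transformer_flux`, and `pipeline_flux_controlnet_inpaint` modules come from the GitHub repo cloned in Step 2; the paths, prompt, and `"cuda"` device are placeholders to adapt to your setup):

```python
import torch
from diffusers.utils import load_image
from controlnet_flux import FluxControlNetModel
from transformer_flux import FluxTransformer2DModel
from pipeline_flux_controlnet_inpaint import FluxControlNetInpaintingPipeline

controlnet = FluxControlNetModel.from_pretrained(
    "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha", torch_dtype=torch.bfloat16
)
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
)
pipe = FluxControlNetInpaintingPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    controlnet=controlnet,
    transformer=transformer,
    torch_dtype=torch.bfloat16,
).to("cuda")

size = (768, 768)  # the checkpoint was trained at 768x768
image = load_image("image.png").convert("RGB").resize(size)  # placeholder path
mask = load_image("mask.png").convert("RGB").resize(size)    # placeholder path

result = pipe(
    prompt="your prompt here",  # placeholder prompt
    height=size[1],
    width=size[0],
    control_image=image,
    control_mask=mask,
    num_inference_steps=28,
    guidance_scale=3.5,
    controlnet_conditioning_scale=0.9,
    generator=torch.Generator(device="cuda").manual_seed(24),
).images[0]
result.save("flux_inpaint.png")
```
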
## LICENSE

Our weights fall under the [FLUX.1 [dev]](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md) Non-Commercial License.
config.json
ADDED
@@ -0,0 +1,20 @@
{
  "_class_name": "FluxControlNetModel",
  "_diffusers_version": "0.30.2",
  "_name_or_path": "/data/oss_bucket_0/linjinpeng.ljp/exp_flux/r768_bs2_ga6_adamw_lr5e-6_bf16_cfg3.5_sin0_dou6_s1/checkpoint-77500",
  "attention_head_dim": 128,
  "axes_dims_rope": [
    16,
    56,
    56
  ],
  "extra_condition_channels": 4,
  "guidance_embeds": true,
  "in_channels": 64,
  "joint_attention_dim": 4096,
  "num_attention_heads": 24,
  "num_layers": 6,
  "num_single_layers": 0,
  "patch_size": 1,
  "pooled_projection_dim": 768
}
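For orientation: with `num_attention_heads` = 24 and `attention_head_dim` = 128, the ControlNet uses the same 3072-wide hidden size as the base FLUX transformer, but stacks only 6 double-stream blocks, no single-stream blocks, and declares 4 extra condition channels on the input. A minimal sketch that just reads the file and checks those numbers:

```python
import json

with open("config.json") as f:
    cfg = json.load(f)

# Hidden size implied by the attention configuration: 24 heads x 128 dims = 3072.
print(cfg["num_attention_heads"] * cfg["attention_head_dim"])  # 3072
# Depth of the ControlNet: 6 double-stream blocks, 0 single-stream blocks.
print(cfg["num_layers"], cfg["num_single_layers"])              # 6 0
```
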
handler.py
ADDED
@@ -0,0 +1,57 @@
import torch
from PIL import Image
from io import BytesIO
import base64

# controlnet_flux and pipeline_flux_controlnet_inpaint come from the
# FLUX-Controlnet-Inpainting GitHub repo referenced in the README.
from controlnet_flux import FluxControlNetModel
from transformer_flux import FluxTransformer2DModel
from pipeline_flux_controlnet_inpaint import FluxControlNetInpaintingPipeline

# Load the models once at import time (cached for the lifetime of the worker)
MODEL_ID = "black-forest-labs/FLUX.1-dev"  # base pipeline weights
CONTROLNET_MODEL = "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha"
TRANSFORMER_MODEL = "black-forest-labs/FLUX.1-dev"

controlnet = FluxControlNetModel.from_pretrained(CONTROLNET_MODEL, torch_dtype=torch.bfloat16)
transformer = FluxTransformer2DModel.from_pretrained(
    TRANSFORMER_MODEL, subfolder="transformer", torch_dtype=torch.bfloat16
)

pipeline = FluxControlNetInpaintingPipeline.from_pretrained(
    MODEL_ID,
    controlnet=controlnet,
    transformer=transformer,
    torch_dtype=torch.bfloat16,
).to("cuda" if torch.cuda.is_available() else "cpu")


# Function to handle a single inference request
def handle(inputs, context):
    try:
        # Parse inputs
        prompt = inputs.get("prompt", "default prompt text")
        control_image_base64 = inputs.get("control_image")
        mask_image_base64 = inputs.get("mask_image")
        num_inference_steps = inputs.get("num_inference_steps", 28)
        guidance_scale = inputs.get("guidance_scale", 3.5)
        controlnet_conditioning_scale = inputs.get("controlnet_conditioning_scale", 0.9)

        # Decode the Base64-encoded images into PIL format
        control_image = Image.open(BytesIO(base64.b64decode(control_image_base64))).convert("RGB")
        mask_image = Image.open(BytesIO(base64.b64decode(mask_image_base64))).convert("RGB")

        # Perform inference
        result = pipeline(
            prompt=prompt,
            control_image=control_image,
            control_mask=mask_image,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            controlnet_conditioning_scale=controlnet_conditioning_scale,
        ).images[0]

        # Encode the result as a Base64 PNG string
        buffered = BytesIO()
        result.save(buffered, format="PNG")
        result_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")

        # Return the result
        return {"status": "success", "image": result_base64}

    except Exception as e:
        return {"status": "error", "message": str(e)}
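A hypothetical local invocation of the handler above, for reference (the file names and prompt are placeholders, `context` is unused by this handler, and the mask convention assumed here is white for the region to repaint):

```python
import base64
from handler import handle

def encode(path: str) -> str:
    with open(path, "rb") as f:
        return base64.b64encode(f.read()).decode("utf-8")

payload = {
    "prompt": "a white bucket on a wooden table",  # placeholder prompt
    "control_image": encode("input.png"),          # placeholder input image
    "mask_image": encode("mask.png"),              # placeholder mask image
    "num_inference_steps": 28,
    "guidance_scale": 3.5,
    "controlnet_conditioning_scale": 0.9,
}

response = handle(payload, context=None)
if response["status"] == "success":
    with open("output.png", "wb") as f:
        f.write(base64.b64decode(response["image"]))
else:
    print("Inference failed:", response["message"])
```
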
main.py
ADDED
@@ -0,0 +1,59 @@
import torch
from diffusers.utils import load_image, check_min_version
from controlnet_flux import FluxControlNetModel
from transformer_flux import FluxTransformer2DModel
from pipeline_flux_controlnet_inpaint import FluxControlNetInpaintingPipeline
import os

check_min_version("0.30.2")

# Set image path, mask path, and prompt. The hosted originals are:
# image_path = 'https://huggingface.co/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha/resolve/main/images/bucket.png'
# mask_path = 'https://huggingface.co/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha/resolve/main/images/bucket_mask.jpeg'

current_dir = os.path.dirname(os.path.abspath(__file__))
image_path = os.path.join(current_dir, "images_bucket.png")
mask_path = os.path.join(current_dir, "images_bucket_mask.jpeg")

prompt = 'a person wearing a white shoe, carrying a white bucket with text "Jitesh" on it with red color'

# Build pipeline (local_files_only=True assumes the weights are already downloaded)
controlnet = FluxControlNetModel.from_pretrained(
    "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha",
    torch_dtype=torch.bfloat16,
    local_files_only=True,
)
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16, local_files_only=True
)
pipe = FluxControlNetInpaintingPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    controlnet=controlnet,
    transformer=transformer,
    torch_dtype=torch.bfloat16,
).to("mps")  # running on Apple Silicon here; use "cuda" on an NVIDIA GPU
pipe.transformer.to(torch.bfloat16)
pipe.controlnet.to(torch.bfloat16)

# Load image and mask at the training resolution
size = (768, 768)
image = load_image(image_path).convert("RGB").resize(size)
mask = load_image(mask_path).convert("RGB").resize(size)
generator = torch.Generator(device="mps").manual_seed(24)

# Inpaint
result = pipe(
    prompt=prompt,
    height=size[1],
    width=size[0],
    control_image=image,
    control_mask=mask,
    num_inference_steps=28,
    generator=generator,
    controlnet_conditioning_scale=0.9,
    guidance_scale=3.5,
    negative_prompt="",
    true_guidance_scale=1.0,  # recommended: 3.5 for the alpha checkpoint, 1.0 for beta
).images[0]

result.save("flux_inpaint.png")
print("Successfully inpainted image")
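Note that the script above hard-codes the "mps" device for Apple Silicon. A small sketch of picking the device dynamically instead (an addition for illustration, not part of the original script):

```python
import torch

if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

# Then move the pipeline and seed the generator on the same device:
# pipe = pipe.to(device)
# generator = torch.Generator(device=device).manual_seed(24)
```
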
transformer_flux.py
ADDED
@@ -0,0 +1,525 @@
from typing import Any, Dict, List, Optional, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
from diffusers.models.attention import FeedForward
from diffusers.models.attention_processor import (
    Attention,
    FluxAttnProcessor2_0,
    FluxSingleAttnProcessor2_0,
)
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.normalization import (
    AdaLayerNormContinuous,
    AdaLayerNormZero,
    AdaLayerNormZeroSingle,
)
from diffusers.utils import (
    USE_PEFT_BACKEND,
    is_torch_version,
    logging,
    scale_lora_layers,
    unscale_lora_layers,
)
from diffusers.utils.torch_utils import maybe_allow_in_graph
from diffusers.models.embeddings import (
    CombinedTimestepGuidanceTextProjEmbeddings,
    CombinedTimestepTextProjEmbeddings,
)
from diffusers.models.modeling_outputs import Transformer2DModelOutput


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


# YiYi to-do: refactor rope related functions/classes
def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
    assert dim % 2 == 0, "The dimension must be even."

    scale = torch.arange(0, dim, 2, dtype=torch.float32, device=pos.device) / dim
    omega = 1.0 / (theta**scale)

    batch_size, seq_length = pos.shape
    out = torch.einsum("...n,d->...nd", pos, omega)
    cos_out = torch.cos(out)
    sin_out = torch.sin(out)

    stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
    out = stacked_out.view(batch_size, -1, dim // 2, 2, 2)
    return out.float()


# YiYi to-do: refactor rope related functions/classes
class EmbedND(nn.Module):
    def __init__(self, dim: int, theta: int, axes_dim: List[int]):
        super().__init__()
        self.dim = dim
        self.theta = theta
        self.axes_dim = axes_dim

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        n_axes = ids.shape[-1]
        emb = torch.cat(
            [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
            dim=-3,
        )
        return emb.unsqueeze(1)

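# Shape note (added for orientation, not executed anywhere): for position ids of shape
# (batch, seq_len, 3) and axes_dim=[16, 56, 56], each rope(...) call returns a tensor of
# shape (batch, seq_len, axes_dim[i] // 2, 2, 2); concatenating along dim=-3 and
# unsqueezing dim 1 yields rotary embeddings of shape (batch, 1, seq_len, 64, 2, 2),
# i.e. one 2x2 rotation matrix per pair of head channels.
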
@maybe_allow_in_graph
class FluxSingleTransformerBlock(nn.Module):
    r"""
    A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.

    Reference: https://arxiv.org/abs/2403.03206

    Parameters:
        dim (`int`): The number of channels in the input and output.
        num_attention_heads (`int`): The number of heads to use for multi-head attention.
        attention_head_dim (`int`): The number of channels in each head.
        context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
            processing of `context` conditions.
    """

    def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0):
        super().__init__()
        self.mlp_hidden_dim = int(dim * mlp_ratio)

        self.norm = AdaLayerNormZeroSingle(dim)
        self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
        self.act_mlp = nn.GELU(approximate="tanh")
        self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)

        processor = FluxSingleAttnProcessor2_0()
        self.attn = Attention(
            query_dim=dim,
            cross_attention_dim=None,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            out_dim=dim,
            bias=True,
            processor=processor,
            qk_norm="rms_norm",
            eps=1e-6,
            pre_only=True,
        )

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        temb: torch.FloatTensor,
        image_rotary_emb=None,
    ):
        residual = hidden_states
        norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
        mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))

        attn_output = self.attn(
            hidden_states=norm_hidden_states,
            image_rotary_emb=image_rotary_emb,
        )

        hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
        gate = gate.unsqueeze(1)
        hidden_states = gate * self.proj_out(hidden_states)
        hidden_states = residual + hidden_states
        if hidden_states.dtype == torch.float16:
            hidden_states = hidden_states.clip(-65504, 65504)

        return hidden_states


@maybe_allow_in_graph
class FluxTransformerBlock(nn.Module):
    r"""
    A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.

    Reference: https://arxiv.org/abs/2403.03206

    Parameters:
        dim (`int`): The number of channels in the input and output.
        num_attention_heads (`int`): The number of heads to use for multi-head attention.
        attention_head_dim (`int`): The number of channels in each head.
        context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
            processing of `context` conditions.
    """

    def __init__(
        self, dim, num_attention_heads, attention_head_dim, qk_norm="rms_norm", eps=1e-6
    ):
        super().__init__()

        self.norm1 = AdaLayerNormZero(dim)

        self.norm1_context = AdaLayerNormZero(dim)

        if hasattr(F, "scaled_dot_product_attention"):
            processor = FluxAttnProcessor2_0()
        else:
            raise ValueError(
                "The current PyTorch version does not support the `scaled_dot_product_attention` function."
            )
        self.attn = Attention(
            query_dim=dim,
            cross_attention_dim=None,
            added_kv_proj_dim=dim,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            out_dim=dim,
            context_pre_only=False,
            bias=True,
            processor=processor,
            qk_norm=qk_norm,
            eps=eps,
        )

        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")

        self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff_context = FeedForward(
            dim=dim, dim_out=dim, activation_fn="gelu-approximate"
        )

        # let chunk size default to None
        self._chunk_size = None
        self._chunk_dim = 0

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor,
        temb: torch.FloatTensor,
        image_rotary_emb=None,
    ):
        norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
            hidden_states, emb=temb
        )

        (
            norm_encoder_hidden_states,
            c_gate_msa,
            c_shift_mlp,
            c_scale_mlp,
            c_gate_mlp,
        ) = self.norm1_context(encoder_hidden_states, emb=temb)

        # Attention.
        attn_output, context_attn_output = self.attn(
            hidden_states=norm_hidden_states,
            encoder_hidden_states=norm_encoder_hidden_states,
            image_rotary_emb=image_rotary_emb,
        )

        # Process attention outputs for the `hidden_states`.
        attn_output = gate_msa.unsqueeze(1) * attn_output
        hidden_states = hidden_states + attn_output

        norm_hidden_states = self.norm2(hidden_states)
        norm_hidden_states = (
            norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
        )

        ff_output = self.ff(norm_hidden_states)
        ff_output = gate_mlp.unsqueeze(1) * ff_output

        hidden_states = hidden_states + ff_output

        # Process attention outputs for the `encoder_hidden_states`.

        context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
        encoder_hidden_states = encoder_hidden_states + context_attn_output

        norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
        norm_encoder_hidden_states = (
            norm_encoder_hidden_states * (1 + c_scale_mlp[:, None])
            + c_shift_mlp[:, None]
        )

        context_ff_output = self.ff_context(norm_encoder_hidden_states)
        encoder_hidden_states = (
            encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
        )
        if encoder_hidden_states.dtype == torch.float16:
            encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)

        return encoder_hidden_states, hidden_states

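# Note (added for orientation): FluxTransformerBlock above is the dual-stream ("double") block that
# keeps image tokens (hidden_states) and text tokens (encoder_hidden_states) in separate streams
# joined through shared attention, while FluxSingleTransformerBlock operates on the single
# concatenated sequence. The model below stacks num_layers dual-stream blocks followed by
# num_single_layers single-stream blocks.
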
class FluxTransformer2DModel(
    ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin
):
    """
    The Transformer model introduced in Flux.

    Reference: https://blackforestlabs.ai/announcing-black-forest-labs/

    Parameters:
        patch_size (`int`): Patch size to turn the input data into small patches.
        in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
        num_layers (`int`, *optional*, defaults to 18): The number of layers of MMDiT blocks to use.
        num_single_layers (`int`, *optional*, defaults to 18): The number of layers of single DiT blocks to use.
        attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
        num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
        joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
        pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
        guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings.
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        patch_size: int = 1,
        in_channels: int = 64,
        num_layers: int = 19,
        num_single_layers: int = 38,
        attention_head_dim: int = 128,
        num_attention_heads: int = 24,
        joint_attention_dim: int = 4096,
        pooled_projection_dim: int = 768,
        guidance_embeds: bool = False,
        axes_dims_rope: List[int] = [16, 56, 56],
    ):
        super().__init__()
        self.out_channels = in_channels
        self.inner_dim = (
            self.config.num_attention_heads * self.config.attention_head_dim
        )

        self.pos_embed = EmbedND(
            dim=self.inner_dim, theta=10000, axes_dim=axes_dims_rope
        )
        text_time_guidance_cls = (
            CombinedTimestepGuidanceTextProjEmbeddings
            if guidance_embeds
            else CombinedTimestepTextProjEmbeddings
        )
        self.time_text_embed = text_time_guidance_cls(
            embedding_dim=self.inner_dim,
            pooled_projection_dim=self.config.pooled_projection_dim,
        )

        self.context_embedder = nn.Linear(
            self.config.joint_attention_dim, self.inner_dim
        )
        self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim)

        self.transformer_blocks = nn.ModuleList(
            [
                FluxTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=self.config.num_attention_heads,
                    attention_head_dim=self.config.attention_head_dim,
                )
                for i in range(self.config.num_layers)
            ]
        )

        self.single_transformer_blocks = nn.ModuleList(
            [
                FluxSingleTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=self.config.num_attention_heads,
                    attention_head_dim=self.config.attention_head_dim,
                )
                for i in range(self.config.num_single_layers)
            ]
        )

        self.norm_out = AdaLayerNormContinuous(
            self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6
        )
        self.proj_out = nn.Linear(
            self.inner_dim, patch_size * patch_size * self.out_channels, bias=True
        )

        self.gradient_checkpointing = False

    def _set_gradient_checkpointing(self, module, value=False):
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor = None,
        pooled_projections: torch.Tensor = None,
        timestep: torch.LongTensor = None,
        img_ids: torch.Tensor = None,
        txt_ids: torch.Tensor = None,
        guidance: torch.Tensor = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        controlnet_block_samples=None,
        controlnet_single_block_samples=None,
        return_dict: bool = True,
    ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
        """
        The [`FluxTransformer2DModel`] forward method.

        Args:
            hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
                Input `hidden_states`.
            encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
                Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
            pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
                from the embeddings of input conditions.
            timestep ( `torch.LongTensor`):
                Used to indicate denoising step.
            block_controlnet_hidden_states: (`list` of `torch.Tensor`):
                A list of tensors that if specified are added to the residuals of transformer blocks.
            joint_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
                tuple.

        Returns:
            If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
            `tuple` where the first element is the sample tensor.
        """
        if joint_attention_kwargs is not None:
            joint_attention_kwargs = joint_attention_kwargs.copy()
            lora_scale = joint_attention_kwargs.pop("scale", 1.0)
        else:
            lora_scale = 1.0

        if USE_PEFT_BACKEND:
            # weight the lora layers by setting `lora_scale` for each PEFT layer
            scale_lora_layers(self, lora_scale)
        else:
            if (
                joint_attention_kwargs is not None
                and joint_attention_kwargs.get("scale", None) is not None
            ):
                logger.warning(
                    "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
                )
        hidden_states = self.x_embedder(hidden_states)

        timestep = timestep.to(hidden_states.dtype) * 1000
        if guidance is not None:
            guidance = guidance.to(hidden_states.dtype) * 1000
        else:
            guidance = None
        temb = (
            self.time_text_embed(timestep, pooled_projections)
            if guidance is None
            else self.time_text_embed(timestep, guidance, pooled_projections)
        )
        encoder_hidden_states = self.context_embedder(encoder_hidden_states)

        txt_ids = txt_ids.expand(img_ids.size(0), -1, -1)
        ids = torch.cat((txt_ids, img_ids), dim=1)
        image_rotary_emb = self.pos_embed(ids)

        for index_block, block in enumerate(self.transformer_blocks):
            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            return module(*inputs)

                    return custom_forward

                ckpt_kwargs: Dict[str, Any] = (
                    {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                )
                (
                    encoder_hidden_states,
                    hidden_states,
                ) = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    encoder_hidden_states,
                    temb,
                    image_rotary_emb,
                    **ckpt_kwargs,
                )

            else:
                encoder_hidden_states, hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    temb=temb,
                    image_rotary_emb=image_rotary_emb,
                )

            # controlnet residual
            if controlnet_block_samples is not None:
                interval_control = len(self.transformer_blocks) / len(
                    controlnet_block_samples
                )
                interval_control = int(np.ceil(interval_control))
                hidden_states = (
                    hidden_states
                    + controlnet_block_samples[index_block // interval_control]
                )

        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

        for index_block, block in enumerate(self.single_transformer_blocks):
            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            return module(*inputs)

                    return custom_forward

                ckpt_kwargs: Dict[str, Any] = (
                    {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                )
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    temb,
                    image_rotary_emb,
                    **ckpt_kwargs,
                )

            else:
                hidden_states = block(
                    hidden_states=hidden_states,
                    temb=temb,
                    image_rotary_emb=image_rotary_emb,
                )

            # controlnet residual
            if controlnet_single_block_samples is not None:
                interval_control = len(self.single_transformer_blocks) / len(
                    controlnet_single_block_samples
                )
                interval_control = int(np.ceil(interval_control))
                hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
                    hidden_states[:, encoder_hidden_states.shape[1] :, ...]
                    + controlnet_single_block_samples[index_block // interval_control]
                )

        hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]

        hidden_states = self.norm_out(hidden_states, temb)
        output = self.proj_out(hidden_states)

        if USE_PEFT_BACKEND:
            # remove `lora_scale` from each PEFT layer
            unscale_lora_layers(self, lora_scale)

        if not return_dict:
            return (output,)

        return Transformer2DModelOutput(sample=output)