Jitesh Dhamaniya commited on
Commit
8c46fdc
·
1 Parent(s): f0357cb
Files changed (5) hide show
  1. README.md +98 -0
  2. config.json +20 -0
  3. handler.py +57 -0
  4. main.py +59 -0
  5. transformer_flux.py +525 -0
README.md ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: other
3
+ license_name: flux-1-dev-non-commercial-license
4
+ license_link: https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md
5
+ language:
6
+ - en
7
+ base_model: black-forest-labs/FLUX.1-dev
8
+ library_name: diffusers
9
+ tags:
10
+ - Text-to-Image
11
+ - ControlNet
12
+ - Inpainting
13
+ - FLUX
14
+ - Stable Diffusion
15
+ ---
16
+ <div style="display: flex; justify-content: center; align-items: center;">
17
+ <img src="images/alibaba.png" alt="alibaba" style="width: 20%; height: auto; margin-right: 5%;">
18
+ <img src="images/alimama.png" alt="alimama" style="width: 20%; height: auto;">
19
+ </div>
20
+
21
+ This repository provides a Inpainting ControlNet checkpoint for [FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev) model released by AlimamaCreative Team.
22
+
23
+ ## Beta Version Now Available
24
+
25
+ We are excited to announce the release of our beta version, which brings further enhancements to our inpainting capabilities:
26
+
27
+ To access and test the beta version, please visit our [alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta](https://huggingface.co/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta). We welcome your feedback and suggestions as we continue to refine and improve our model.
28
+
29
+ ## News
30
+
31
+
32
+ 🎉 Thanks to @comfyanonymous,ComfyUI now supports inference for Alimama inpainting ControlNet. Workflow can be downloaded from [here](https://huggingface.co/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha/resolve/main/images/alimama-flux-controlnet-inpaint.json).
33
+
34
+ ComfyUI Usage Tips:
35
+
36
+ * Using the `t5xxl-FP16` and `flux1-dev-fp8` models for 28-step inference, the GPU memory usage is 27GB. The inference time with `cfg=3.5` is 27 seconds, while without `cfg=1` it is 15 seconds. `Hyper-FLUX-lora` can be used to accelerate inference.
37
+ * You can try adjusting(lower) the parameters `control-strength`, `control-end-percent`, and `cfg` to achieve better results.
38
+ * The following example uses `control-strength` = 0.9 & `control-end-percent` = 1.0 & `cfg` = 3.5
39
+
40
+ | Input | Output | Prompt |
41
+ |------------------------------|------------------------------|-------------|
42
+ | ![Image1](https://huggingface.co/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha/resolve/main/images/comfy_in_1.png) | ![Image2](https://huggingface.co/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha/resolve/main/images/comfy_out_1.png) | <small><i>The image depicts a scene from the anime series Dragon Ball Z, with the characters Goku, <span style="color:red; font-weight:bold;">Elon Musk</span>, and a child version of Gohan sharing a meal of ramen noodles. They are all sitting around a dining table, with Goku and Gohan on one side and Naruto on the other. They are all holding chopsticks and eating the noodles. The table is set with bowls of ramen, cups, and bowls of drinks. The arrangement of the characters and the food creates a sense of camaraderie and shared enjoyment of the meal. |
43
+ | ![Image3](https://huggingface.co/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha/resolve/main/images/comfy_in_2.png) | ![Image4](https://huggingface.co/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha/resolve/main/images/comfy_out_2.png) | <small><i>The image is an illustration of a man standing in a cafe. He is wearing a white turtleneck, a camel-colored trench coat, and brown shoes. He is holding a cell phone and appears to be looking at it. There is a small table with <span style="color:red; font-weight:bold;">a cat</span> on it to his right. In the background, there is another man sitting at a table with a laptop. The man is wearing a black turtleneck and a tie. </i></small>|
44
+ | ![Image5](https://huggingface.co/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha/resolve/main/images/comfy_in_3.png) | ![Image6](https://huggingface.co/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha/resolve/main/images/comfy_out_3.png) | <small><i>A woman with blonde hair is sitting on a table wearing a <span style="color:red; font-weight:bold;">red and white long dress</span>. She is holding a green phone in her hand and appears to be taking a photo. There is a bag next to her on the table and a handbag beside her on the chair. The woman is looking at the phone with a smile on her face. The background includes a TV on the left wall and a couch on the right. A chair is also present in the scene. </i></small>|
45
+ | ![Image7](https://huggingface.co/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha/resolve/main/images/comfy_in_4.png) | ![Image8](https://huggingface.co/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha/resolve/main/images/comfy_out_4.png) | <small><i>The image depicts a beautiful young woman sitting at a desk, reading a book. She has long, wavy brown hair and is wearing a grey shirt with a black cardigan. She is holding a <span style="color:red; font-weight:bold;">red pencil</span> in her left hand and appears to be deep in thought. Surrounding her are numerous books, some stacked on the desk and others placed on a shelf behind her. A potted plant is also visible in the background, adding a touch of greenery to the scene. The image conveys a sense of serenity and intellectual pursuits. </i></small>|
46
+
47
+
48
+ ## Model Cards
49
+
50
+ * The model was trained on 12M laion2B and internal source images at resolution 768x768. The inference performs best at this size, with other sizes yielding suboptimal results.
51
+
52
+ * The recommended controlnet_conditioning_scale is 0.9 - 0.95.
53
+
54
+ * **Please note: This is only the alpha version during the training process. We will release an updated version when we feel ready.**
55
+
56
+ ## Showcase
57
+
58
+ ![flux1](images/flux1.jpg)
59
+ ![flux2](images/flux2.jpg)
60
+ ![flux3](images/flux3.jpg)
61
+
62
+ ## Comparison with SDXL-Inpainting
63
+
64
+ Compared with [SDXL-Inpainting](https://huggingface.co/diffusers/stable-diffusion-xl-1.0-inpainting-0.1)
65
+
66
+ From left to right: Input image | Masked image | SDXL inpainting | Ours
67
+
68
+ ![0](images/0.jpg)
69
+ <small><i>*The image depicts a beautiful young woman sitting at a desk, reading a book. She has long, wavy brown hair and is wearing a grey shirt with a black cardigan. She is holding a pencil in her left hand and appears to be deep in thought. Surrounding her are numerous books, some stacked on the desk and others placed on a shelf behind her. A potted plant is also visible in the background, adding a touch of greenery to the scene. The image conveys a sense of serenity and intellectual pursuits.*</i></small>
70
+
71
+ ![0](images/1.jpg)
72
+ <small><i>A woman with blonde hair is sitting on a table wearing a blue and white long dress. She is holding a green phone in her hand and appears to be taking a photo. There is a bag next to her on the table and a handbag beside her on the chair. The woman is looking at the phone with a smile on her face. The background includes a TV on the left wall and a couch on the right. A chair is also present in the scene.</i></small>
73
+
74
+ ![0](images/2.jpg)
75
+ <small><i>The image is an illustration of a man standing in a cafe. He is wearing a white turtleneck, a camel-colored trench coat, and brown shoes. He is holding a cell phone and appears to be looking at it. There is a small table with a cup of coffee on it to his right. In the background, there is another man sitting at a table with a laptop. The man is wearing a black turtleneck and a tie. There are several cups and a cake on the table in the background. The man sitting at the table appears to be typing on the laptop.</i></small>
76
+
77
+ ![0](images/3.jpg)
78
+ <small><i>The image depicts a scene from the anime series Dragon Ball Z, with the characters Goku, Naruto, and a child version of Gohan sharing a meal of ramen noodles. They are all sitting around a dining table, with Goku and Gohan on one side and Naruto on the other. They are all holding chopsticks and eating the noodles. The table is set with bowls of ramen, cups, and bowls of drinks. The arrangement of the characters and the food creates a sense of camaraderie and shared enjoyment of the meal.</i></small>
79
+
80
+ ## Using with Diffusers
81
+
82
+ Step1: install diffusers
83
+ ``` Shell
84
+ pip install diffusers==0.30.2
85
+ ```
86
+
87
+ Step2: clone repo from github
88
+ ``` Shell
89
+ git clone https://github.com/alimama-creative/FLUX-Controlnet-Inpainting.git
90
+ ```
91
+
92
+ Step3: modify the image_path, mask_path, prompt and run
93
+ ``` Shell
94
+ python main.py
95
+ ```
96
+
97
+ ## LICENSE
98
+ Our weights fall under the [FLUX.1 [dev]](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md) Non-Commercial License.
config.json ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "FluxControlNetModel",
3
+ "_diffusers_version": "0.30.2",
4
+ "_name_or_path": "/data/oss_bucket_0/linjinpeng.ljp/exp_flux/r768_bs2_ga6_adamw_lr5e-6_bf16_cfg3.5_sin0_dou6_s1/checkpoint-77500",
5
+ "attention_head_dim": 128,
6
+ "axes_dims_rope": [
7
+ 16,
8
+ 56,
9
+ 56
10
+ ],
11
+ "extra_condition_channels": 4,
12
+ "guidance_embeds": true,
13
+ "in_channels": 64,
14
+ "joint_attention_dim": 4096,
15
+ "num_attention_heads": 24,
16
+ "num_layers": 6,
17
+ "num_single_layers": 0,
18
+ "patch_size": 1,
19
+ "pooled_projection_dim": 768
20
+ }
handler.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ from diffusers import DiffusionPipeline
4
+ from PIL import Image
5
+ from io import BytesIO
6
+ import base64
7
+
8
+ # Load the model once (caching for efficiency)
9
+ MODEL_ID = "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha"
10
+ CONTROLNET_MODEL = "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha"
11
+ TRANSFORMER_MODEL = "black-forest-labs/FLUX.1-dev"
12
+
13
+ controlnet = DiffusionPipeline.from_pretrained(CONTROLNET_MODEL, torch_dtype=torch.bfloat16)
14
+ transformer = DiffusionPipeline.from_pretrained(TRANSFORMER_MODEL, subfolder="transformer", torch_dtype=torch.bfloat16)
15
+
16
+ pipeline = DiffusionPipeline.from_pretrained(
17
+ MODEL_ID,
18
+ controlnet=controlnet,
19
+ transformer=transformer,
20
+ torch_dtype=torch.bfloat16
21
+ ).to("cuda" if torch.cuda.is_available() else "cpu")
22
+
23
+ # Function to handle inference
24
+ def handle(inputs, context):
25
+ try:
26
+ # Parse inputs
27
+ prompt = inputs.get("prompt", "default prompt text")
28
+ control_image_base64 = inputs.get("control_image")
29
+ mask_image_base64 = inputs.get("mask_image")
30
+ num_inference_steps = inputs.get("num_inference_steps", 28)
31
+ guidance_scale = inputs.get("guidance_scale", 3.5)
32
+ controlnet_conditioning_scale = inputs.get("controlnet_conditioning_scale", 0.9)
33
+
34
+ # Convert Base64 images to PIL format
35
+ control_image = Image.open(BytesIO(base64.b64decode(control_image_base64))).convert("RGB")
36
+ mask_image = Image.open(BytesIO(base64.b64decode(mask_image_base64))).convert("RGB")
37
+
38
+ # Perform inference
39
+ result = pipeline(
40
+ prompt=prompt,
41
+ control_image=control_image,
42
+ control_mask=mask_image,
43
+ num_inference_steps=num_inference_steps,
44
+ guidance_scale=guidance_scale,
45
+ controlnet_conditioning_scale=controlnet_conditioning_scale,
46
+ ).images[0]
47
+
48
+ # Convert result to Base64 string
49
+ buffered = BytesIO()
50
+ result.save(buffered, format="PNG")
51
+ result_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
52
+
53
+ # Return the result
54
+ return {"status": "success", "image": result_base64}
55
+
56
+ except Exception as e:
57
+ return {"status": "error", "message": str(e)}
main.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from diffusers.utils import load_image, check_min_version
3
+ from controlnet_flux import FluxControlNetModel
4
+ from transformer_flux import FluxTransformer2DModel
5
+ from pipeline_flux_controlnet_inpaint import FluxControlNetInpaintingPipeline
6
+ import os
7
+
8
+ check_min_version("0.30.2")
9
+
10
+ # Set image path , mask path and prompt
11
+ # image_path='https://huggingface.co/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha/resolve/main/images/bucket.png',
12
+ # mask_path='https://huggingface.co/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha/resolve/main/images/bucket_mask.jpeg',
13
+
14
+ current_dir = os.path.dirname(os.path.abspath(__file__))
15
+ image_path = os.path.join(current_dir, "images_bucket.png")
16
+ mask_path = os.path.join(current_dir, "images_bucket_mask.jpeg")
17
+
18
+
19
+ prompt='a person wearing a white shoe, carrying a white bucket with text "Jitesh" on it with red color'
20
+
21
+ # print(f"Checkpoint file: {checkpoint_file}")
22
+
23
+ # Build pipeline
24
+ controlnet = FluxControlNetModel.from_pretrained("alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha", torch_dtype=torch.bfloat16, local_files_only=True)
25
+ transformer = FluxTransformer2DModel.from_pretrained(
26
+ "black-forest-labs/FLUX.1-dev", subfolder='transformer', torch_dtype=torch.bfloat16, local_files_only=True
27
+ )
28
+ pipe = FluxControlNetInpaintingPipeline.from_pretrained(
29
+ "black-forest-labs/FLUX.1-dev",
30
+ controlnet=controlnet,
31
+ transformer=transformer,
32
+ torch_dtype=torch.bfloat16
33
+ ).to("mps")
34
+ pipe.transformer.to(torch.bfloat16)
35
+ pipe.controlnet.to(torch.bfloat16)
36
+
37
+ # Load image and mask
38
+ size = (768, 768)
39
+ image = load_image(image_path).convert("RGB").resize(size)
40
+ mask = load_image(mask_path).convert("RGB").resize(size)
41
+ generator = torch.Generator(device="mps").manual_seed(24)
42
+
43
+ # Inpaint
44
+ result = pipe(
45
+ prompt=prompt,
46
+ height=size[1],
47
+ width=size[0],
48
+ control_image=image,
49
+ control_mask=mask,
50
+ num_inference_steps=28,
51
+ generator=generator,
52
+ controlnet_conditioning_scale=0.9,
53
+ guidance_scale=3.5,
54
+ negative_prompt="",
55
+ true_guidance_scale=1.0 # default: 3.5 for alpha and 1.0 for beta
56
+ ).images[0]
57
+
58
+ result.save('flux_inpaint.png')
59
+ print("Successfully inpaint image")
transformer_flux.py ADDED
@@ -0,0 +1,525 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, List, Optional, Union
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
9
+ from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
10
+ from diffusers.models.attention import FeedForward
11
+ from diffusers.models.attention_processor import (
12
+ Attention,
13
+ FluxAttnProcessor2_0,
14
+ FluxSingleAttnProcessor2_0,
15
+ )
16
+ from diffusers.models.modeling_utils import ModelMixin
17
+ from diffusers.models.normalization import (
18
+ AdaLayerNormContinuous,
19
+ AdaLayerNormZero,
20
+ AdaLayerNormZeroSingle,
21
+ )
22
+ from diffusers.utils import (
23
+ USE_PEFT_BACKEND,
24
+ is_torch_version,
25
+ logging,
26
+ scale_lora_layers,
27
+ unscale_lora_layers,
28
+ )
29
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
30
+ from diffusers.models.embeddings import (
31
+ CombinedTimestepGuidanceTextProjEmbeddings,
32
+ CombinedTimestepTextProjEmbeddings,
33
+ )
34
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
35
+
36
+
37
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
38
+
39
+
40
+ # YiYi to-do: refactor rope related functions/classes
41
+ def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
42
+ assert dim % 2 == 0, "The dimension must be even."
43
+
44
+ scale = torch.arange(0, dim, 2, dtype=torch.float32, device=pos.device) / dim
45
+ omega = 1.0 / (theta**scale)
46
+
47
+ batch_size, seq_length = pos.shape
48
+ out = torch.einsum("...n,d->...nd", pos, omega)
49
+ cos_out = torch.cos(out)
50
+ sin_out = torch.sin(out)
51
+
52
+ stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
53
+ out = stacked_out.view(batch_size, -1, dim // 2, 2, 2)
54
+ return out.float()
55
+
56
+
57
+ # YiYi to-do: refactor rope related functions/classes
58
+ class EmbedND(nn.Module):
59
+ def __init__(self, dim: int, theta: int, axes_dim: List[int]):
60
+ super().__init__()
61
+ self.dim = dim
62
+ self.theta = theta
63
+ self.axes_dim = axes_dim
64
+
65
+ def forward(self, ids: torch.Tensor) -> torch.Tensor:
66
+ n_axes = ids.shape[-1]
67
+ emb = torch.cat(
68
+ [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
69
+ dim=-3,
70
+ )
71
+ return emb.unsqueeze(1)
72
+
73
+
74
+ @maybe_allow_in_graph
75
+ class FluxSingleTransformerBlock(nn.Module):
76
+ r"""
77
+ A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
78
+
79
+ Reference: https://arxiv.org/abs/2403.03206
80
+
81
+ Parameters:
82
+ dim (`int`): The number of channels in the input and output.
83
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
84
+ attention_head_dim (`int`): The number of channels in each head.
85
+ context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
86
+ processing of `context` conditions.
87
+ """
88
+
89
+ def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0):
90
+ super().__init__()
91
+ self.mlp_hidden_dim = int(dim * mlp_ratio)
92
+
93
+ self.norm = AdaLayerNormZeroSingle(dim)
94
+ self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
95
+ self.act_mlp = nn.GELU(approximate="tanh")
96
+ self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
97
+
98
+ processor = FluxSingleAttnProcessor2_0()
99
+ self.attn = Attention(
100
+ query_dim=dim,
101
+ cross_attention_dim=None,
102
+ dim_head=attention_head_dim,
103
+ heads=num_attention_heads,
104
+ out_dim=dim,
105
+ bias=True,
106
+ processor=processor,
107
+ qk_norm="rms_norm",
108
+ eps=1e-6,
109
+ pre_only=True,
110
+ )
111
+
112
+ def forward(
113
+ self,
114
+ hidden_states: torch.FloatTensor,
115
+ temb: torch.FloatTensor,
116
+ image_rotary_emb=None,
117
+ ):
118
+ residual = hidden_states
119
+ norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
120
+ mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
121
+
122
+ attn_output = self.attn(
123
+ hidden_states=norm_hidden_states,
124
+ image_rotary_emb=image_rotary_emb,
125
+ )
126
+
127
+ hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
128
+ gate = gate.unsqueeze(1)
129
+ hidden_states = gate * self.proj_out(hidden_states)
130
+ hidden_states = residual + hidden_states
131
+ if hidden_states.dtype == torch.float16:
132
+ hidden_states = hidden_states.clip(-65504, 65504)
133
+
134
+ return hidden_states
135
+
136
+
137
+ @maybe_allow_in_graph
138
+ class FluxTransformerBlock(nn.Module):
139
+ r"""
140
+ A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
141
+
142
+ Reference: https://arxiv.org/abs/2403.03206
143
+
144
+ Parameters:
145
+ dim (`int`): The number of channels in the input and output.
146
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
147
+ attention_head_dim (`int`): The number of channels in each head.
148
+ context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
149
+ processing of `context` conditions.
150
+ """
151
+
152
+ def __init__(
153
+ self, dim, num_attention_heads, attention_head_dim, qk_norm="rms_norm", eps=1e-6
154
+ ):
155
+ super().__init__()
156
+
157
+ self.norm1 = AdaLayerNormZero(dim)
158
+
159
+ self.norm1_context = AdaLayerNormZero(dim)
160
+
161
+ if hasattr(F, "scaled_dot_product_attention"):
162
+ processor = FluxAttnProcessor2_0()
163
+ else:
164
+ raise ValueError(
165
+ "The current PyTorch version does not support the `scaled_dot_product_attention` function."
166
+ )
167
+ self.attn = Attention(
168
+ query_dim=dim,
169
+ cross_attention_dim=None,
170
+ added_kv_proj_dim=dim,
171
+ dim_head=attention_head_dim,
172
+ heads=num_attention_heads,
173
+ out_dim=dim,
174
+ context_pre_only=False,
175
+ bias=True,
176
+ processor=processor,
177
+ qk_norm=qk_norm,
178
+ eps=eps,
179
+ )
180
+
181
+ self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
182
+ self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
183
+
184
+ self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
185
+ self.ff_context = FeedForward(
186
+ dim=dim, dim_out=dim, activation_fn="gelu-approximate"
187
+ )
188
+
189
+ # let chunk size default to None
190
+ self._chunk_size = None
191
+ self._chunk_dim = 0
192
+
193
+ def forward(
194
+ self,
195
+ hidden_states: torch.FloatTensor,
196
+ encoder_hidden_states: torch.FloatTensor,
197
+ temb: torch.FloatTensor,
198
+ image_rotary_emb=None,
199
+ ):
200
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
201
+ hidden_states, emb=temb
202
+ )
203
+
204
+ (
205
+ norm_encoder_hidden_states,
206
+ c_gate_msa,
207
+ c_shift_mlp,
208
+ c_scale_mlp,
209
+ c_gate_mlp,
210
+ ) = self.norm1_context(encoder_hidden_states, emb=temb)
211
+
212
+ # Attention.
213
+ attn_output, context_attn_output = self.attn(
214
+ hidden_states=norm_hidden_states,
215
+ encoder_hidden_states=norm_encoder_hidden_states,
216
+ image_rotary_emb=image_rotary_emb,
217
+ )
218
+
219
+ # Process attention outputs for the `hidden_states`.
220
+ attn_output = gate_msa.unsqueeze(1) * attn_output
221
+ hidden_states = hidden_states + attn_output
222
+
223
+ norm_hidden_states = self.norm2(hidden_states)
224
+ norm_hidden_states = (
225
+ norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
226
+ )
227
+
228
+ ff_output = self.ff(norm_hidden_states)
229
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
230
+
231
+ hidden_states = hidden_states + ff_output
232
+
233
+ # Process attention outputs for the `encoder_hidden_states`.
234
+
235
+ context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
236
+ encoder_hidden_states = encoder_hidden_states + context_attn_output
237
+
238
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
239
+ norm_encoder_hidden_states = (
240
+ norm_encoder_hidden_states * (1 + c_scale_mlp[:, None])
241
+ + c_shift_mlp[:, None]
242
+ )
243
+
244
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
245
+ encoder_hidden_states = (
246
+ encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
247
+ )
248
+ if encoder_hidden_states.dtype == torch.float16:
249
+ encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
250
+
251
+ return encoder_hidden_states, hidden_states
252
+
253
+
254
+ class FluxTransformer2DModel(
255
+ ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin
256
+ ):
257
+ """
258
+ The Transformer model introduced in Flux.
259
+
260
+ Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
261
+
262
+ Parameters:
263
+ patch_size (`int`): Patch size to turn the input data into small patches.
264
+ in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
265
+ num_layers (`int`, *optional*, defaults to 18): The number of layers of MMDiT blocks to use.
266
+ num_single_layers (`int`, *optional*, defaults to 18): The number of layers of single DiT blocks to use.
267
+ attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
268
+ num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
269
+ joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
270
+ pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
271
+ guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings.
272
+ """
273
+
274
+ _supports_gradient_checkpointing = True
275
+
276
+ @register_to_config
277
+ def __init__(
278
+ self,
279
+ patch_size: int = 1,
280
+ in_channels: int = 64,
281
+ num_layers: int = 19,
282
+ num_single_layers: int = 38,
283
+ attention_head_dim: int = 128,
284
+ num_attention_heads: int = 24,
285
+ joint_attention_dim: int = 4096,
286
+ pooled_projection_dim: int = 768,
287
+ guidance_embeds: bool = False,
288
+ axes_dims_rope: List[int] = [16, 56, 56],
289
+ ):
290
+ super().__init__()
291
+ self.out_channels = in_channels
292
+ self.inner_dim = (
293
+ self.config.num_attention_heads * self.config.attention_head_dim
294
+ )
295
+
296
+ self.pos_embed = EmbedND(
297
+ dim=self.inner_dim, theta=10000, axes_dim=axes_dims_rope
298
+ )
299
+ text_time_guidance_cls = (
300
+ CombinedTimestepGuidanceTextProjEmbeddings
301
+ if guidance_embeds
302
+ else CombinedTimestepTextProjEmbeddings
303
+ )
304
+ self.time_text_embed = text_time_guidance_cls(
305
+ embedding_dim=self.inner_dim,
306
+ pooled_projection_dim=self.config.pooled_projection_dim,
307
+ )
308
+
309
+ self.context_embedder = nn.Linear(
310
+ self.config.joint_attention_dim, self.inner_dim
311
+ )
312
+ self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim)
313
+
314
+ self.transformer_blocks = nn.ModuleList(
315
+ [
316
+ FluxTransformerBlock(
317
+ dim=self.inner_dim,
318
+ num_attention_heads=self.config.num_attention_heads,
319
+ attention_head_dim=self.config.attention_head_dim,
320
+ )
321
+ for i in range(self.config.num_layers)
322
+ ]
323
+ )
324
+
325
+ self.single_transformer_blocks = nn.ModuleList(
326
+ [
327
+ FluxSingleTransformerBlock(
328
+ dim=self.inner_dim,
329
+ num_attention_heads=self.config.num_attention_heads,
330
+ attention_head_dim=self.config.attention_head_dim,
331
+ )
332
+ for i in range(self.config.num_single_layers)
333
+ ]
334
+ )
335
+
336
+ self.norm_out = AdaLayerNormContinuous(
337
+ self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6
338
+ )
339
+ self.proj_out = nn.Linear(
340
+ self.inner_dim, patch_size * patch_size * self.out_channels, bias=True
341
+ )
342
+
343
+ self.gradient_checkpointing = False
344
+
345
+ def _set_gradient_checkpointing(self, module, value=False):
346
+ if hasattr(module, "gradient_checkpointing"):
347
+ module.gradient_checkpointing = value
348
+
349
+ def forward(
350
+ self,
351
+ hidden_states: torch.Tensor,
352
+ encoder_hidden_states: torch.Tensor = None,
353
+ pooled_projections: torch.Tensor = None,
354
+ timestep: torch.LongTensor = None,
355
+ img_ids: torch.Tensor = None,
356
+ txt_ids: torch.Tensor = None,
357
+ guidance: torch.Tensor = None,
358
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
359
+ controlnet_block_samples=None,
360
+ controlnet_single_block_samples=None,
361
+ return_dict: bool = True,
362
+ ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
363
+ """
364
+ The [`FluxTransformer2DModel`] forward method.
365
+
366
+ Args:
367
+ hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
368
+ Input `hidden_states`.
369
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
370
+ Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
371
+ pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
372
+ from the embeddings of input conditions.
373
+ timestep ( `torch.LongTensor`):
374
+ Used to indicate denoising step.
375
+ block_controlnet_hidden_states: (`list` of `torch.Tensor`):
376
+ A list of tensors that if specified are added to the residuals of transformer blocks.
377
+ joint_attention_kwargs (`dict`, *optional*):
378
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
379
+ `self.processor` in
380
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
381
+ return_dict (`bool`, *optional*, defaults to `True`):
382
+ Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
383
+ tuple.
384
+
385
+ Returns:
386
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
387
+ `tuple` where the first element is the sample tensor.
388
+ """
389
+ if joint_attention_kwargs is not None:
390
+ joint_attention_kwargs = joint_attention_kwargs.copy()
391
+ lora_scale = joint_attention_kwargs.pop("scale", 1.0)
392
+ else:
393
+ lora_scale = 1.0
394
+
395
+ if USE_PEFT_BACKEND:
396
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
397
+ scale_lora_layers(self, lora_scale)
398
+ else:
399
+ if (
400
+ joint_attention_kwargs is not None
401
+ and joint_attention_kwargs.get("scale", None) is not None
402
+ ):
403
+ logger.warning(
404
+ "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
405
+ )
406
+ hidden_states = self.x_embedder(hidden_states)
407
+
408
+ timestep = timestep.to(hidden_states.dtype) * 1000
409
+ if guidance is not None:
410
+ guidance = guidance.to(hidden_states.dtype) * 1000
411
+ else:
412
+ guidance = None
413
+ temb = (
414
+ self.time_text_embed(timestep, pooled_projections)
415
+ if guidance is None
416
+ else self.time_text_embed(timestep, guidance, pooled_projections)
417
+ )
418
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
419
+
420
+ txt_ids = txt_ids.expand(img_ids.size(0), -1, -1)
421
+ ids = torch.cat((txt_ids, img_ids), dim=1)
422
+ image_rotary_emb = self.pos_embed(ids)
423
+
424
+ for index_block, block in enumerate(self.transformer_blocks):
425
+ if self.training and self.gradient_checkpointing:
426
+
427
+ def create_custom_forward(module, return_dict=None):
428
+ def custom_forward(*inputs):
429
+ if return_dict is not None:
430
+ return module(*inputs, return_dict=return_dict)
431
+ else:
432
+ return module(*inputs)
433
+
434
+ return custom_forward
435
+
436
+ ckpt_kwargs: Dict[str, Any] = (
437
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
438
+ )
439
+ (
440
+ encoder_hidden_states,
441
+ hidden_states,
442
+ ) = torch.utils.checkpoint.checkpoint(
443
+ create_custom_forward(block),
444
+ hidden_states,
445
+ encoder_hidden_states,
446
+ temb,
447
+ image_rotary_emb,
448
+ **ckpt_kwargs,
449
+ )
450
+
451
+ else:
452
+ encoder_hidden_states, hidden_states = block(
453
+ hidden_states=hidden_states,
454
+ encoder_hidden_states=encoder_hidden_states,
455
+ temb=temb,
456
+ image_rotary_emb=image_rotary_emb,
457
+ )
458
+
459
+ # controlnet residual
460
+ if controlnet_block_samples is not None:
461
+ interval_control = len(self.transformer_blocks) / len(
462
+ controlnet_block_samples
463
+ )
464
+ interval_control = int(np.ceil(interval_control))
465
+ hidden_states = (
466
+ hidden_states
467
+ + controlnet_block_samples[index_block // interval_control]
468
+ )
469
+
470
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
471
+
472
+ for index_block, block in enumerate(self.single_transformer_blocks):
473
+ if self.training and self.gradient_checkpointing:
474
+
475
+ def create_custom_forward(module, return_dict=None):
476
+ def custom_forward(*inputs):
477
+ if return_dict is not None:
478
+ return module(*inputs, return_dict=return_dict)
479
+ else:
480
+ return module(*inputs)
481
+
482
+ return custom_forward
483
+
484
+ ckpt_kwargs: Dict[str, Any] = (
485
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
486
+ )
487
+ hidden_states = torch.utils.checkpoint.checkpoint(
488
+ create_custom_forward(block),
489
+ hidden_states,
490
+ temb,
491
+ image_rotary_emb,
492
+ **ckpt_kwargs,
493
+ )
494
+
495
+ else:
496
+ hidden_states = block(
497
+ hidden_states=hidden_states,
498
+ temb=temb,
499
+ image_rotary_emb=image_rotary_emb,
500
+ )
501
+
502
+ # controlnet residual
503
+ if controlnet_single_block_samples is not None:
504
+ interval_control = len(self.single_transformer_blocks) / len(
505
+ controlnet_single_block_samples
506
+ )
507
+ interval_control = int(np.ceil(interval_control))
508
+ hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
509
+ hidden_states[:, encoder_hidden_states.shape[1] :, ...]
510
+ + controlnet_single_block_samples[index_block // interval_control]
511
+ )
512
+
513
+ hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
514
+
515
+ hidden_states = self.norm_out(hidden_states, temb)
516
+ output = self.proj_out(hidden_states)
517
+
518
+ if USE_PEFT_BACKEND:
519
+ # remove `lora_scale` from each PEFT layer
520
+ unscale_lora_layers(self, lora_scale)
521
+
522
+ if not return_dict:
523
+ return (output,)
524
+
525
+ return Transformer2DModelOutput(sample=output)