WonwoongCho committed on
Commit 4185a37 · 0 Parent(s):

Clean slate with jpg

.gitattributes ADDED
@@ -0,0 +1,36 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.png filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,13 @@
+ ---
+ title: IT Blender
+ emoji: 💡
+ colorFrom: gray
+ colorTo: green
+ sdk: gradio
+ sdk_version: 5.33.0
+ app_file: app.py
+ pinned: false
+ license: apache-2.0
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,177 @@
+ import gradio as gr
+ import torch
+ import spaces
+ from PIL import Image
+ from huggingface_hub import hf_hub_download
+
+ from diffusers import FluxPipeline
+ from src.attention_processor import FluxBlendedAttnProcessor2_0
+ from src.utils_sample import set_seed, resize_and_add_margin
+
+
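+ # Load the base FLUX pipeline once at startup; the IT-Blender weights are attached
+ # to its attention processors per request inside `process_image_and_text` below.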
+ dtype = torch.bfloat16
+
+ pipe = FluxPipeline.from_pretrained(
+     "black-forest-labs/FLUX.1-dev", torch_dtype=dtype
+ )
+ pipe = pipe.to("cuda")
+
+
+ @spaces.GPU
+ def process_image_and_text(image, scale, seed, text):
+     set_seed(seed)
+     blended_attn_procs = {}
+     for name, _ in pipe.transformer.attn_processors.items():
+         if "single" in name:
+             blended_attn_procs[name] = FluxBlendedAttnProcessor2_0(3072, ba_scale=scale, num_ref=1)
+         else:
+             blended_attn_procs[name] = pipe.transformer.attn_processors[name]
+
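+     # Register the processors (only the single-stream blocks were swapped above) and
+     # cast the newly added projection layers to the pipeline dtype.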
+     pipe.transformer.set_attn_processor(blended_attn_procs)
+     pipe.to(dtype)
+
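+     # Fetch the pretrained IT-Blender weights and remap the checkpoint keys onto the
+     # `blended_attention_{k,v}_proj` parameters registered by the processors above.
+     # The checkpoint indexes the single-stream blocks starting at 21, hence the offset.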
+     model_path = hf_hub_download(
+         repo_id="Wonwoong/IT-Blender",
+         filename="FLUX/it-blender.bin"  # adjust the filename as needed
+     )
+     pretrained_blended_attn_weights = torch.load(model_path, map_location=pipe._execution_device)
+
+     key_changed_blended_attn_weights = {}
+     for key, value in pretrained_blended_attn_weights.items():
+         block_idx = int(key.split(".")[0]) - 21
+         k_or_v = key.split("_")[2]
+         changed_key = f'single_transformer_blocks.{block_idx}.attn.processor.blended_attention_{k_or_v}_proj.weight'
+         key_changed_blended_attn_weights[changed_key] = value.to(dtype)
+
+     missing_keys, unexpected_keys = pipe.transformer.load_state_dict(key_changed_blended_attn_weights, strict=False)
+
+     # `image` already arrives as a PIL image from the Gradio component.
+     image = image.convert('RGB')
+     image = resize_and_add_margin(image, target_size=512)
+
+     image_list = [image]
+
+     out = pipe(
+         prompt=text,
+         height=512,
+         width=512,
+         max_sequence_length=256,
+         generator=torch.Generator().manual_seed(seed),
+         it_blender_image=image_list
+     ).images[0]
+
+     return out
+
+
+ def get_samples():
+     sample_list = [
+         {
+             "image": "assets/0.jpg",
+             "scale": 0.6,
+             "seed": 42,
+             "text": "A photo of a monster cartoon character, imaginative, creative, design",
+         },
+         {
+             "image": "assets/1.jpg",
+             "scale": 0.6,
+             "seed": 42,
+             "text": "A photo of an owl cartoon character, imaginative, creative, design",
+         },
+         {
+             "image": "assets/2.jpg",
+             "scale": 0.6,
+             "seed": 42,
+             "text": "A photo of a dragon, imaginative, creative, design",
+         },
+         {
+             "image": "assets/character1.jpg",
+             "scale": 0.6,
+             "seed": 42,
+             "text": "A photo of a dragon, imaginative, creative, design",
+         },
+         {
+             "image": "assets/character2.jpg",
+             "scale": 0.6,
+             "seed": 42,
+             "text": "A photo of a dragon, imaginative, creative, design",
+         },
+         {
+             "image": "assets/character3.jpg",
+             "scale": 0.6,
+             "seed": 42,
+             "text": "A photo of a dragon, imaginative, creative, design",
+         },
+         {
+             "image": "assets/graphic1.jpg",
+             "scale": 0.7,
+             "seed": 42,
+             "text": "A photo of a woman, imaginative, creative, design",
+         },
+         {
+             "image": "assets/product1.jpg",
+             "scale": 0.8,
+             "seed": 42,
+             "text": "A photo of a motorcycle, imaginative, creative, design",
+         },
+     ]
+     return [
+         [
+             Image.open(sample["image"]).resize((512, 512)),
+             sample["scale"],
+             sample["seed"],
+             sample["text"],
+         ]
+         for sample in sample_list
+     ]
+
+
+ header = """
+ # 💡 IT-Blender / FLUX
+ <div style="text-align: center; display: flex; justify-content: left; gap: 5px;">
+ <a href="https://arxiv.org/abs/2411.15098"><img src="https://img.shields.io/badge/ArXiv-Paper-A42C25.svg" alt="arXiv"></a>
+ <a href="https://imagineforme.github.io/"><img alt="Build" src="https://img.shields.io/badge/Project%20Page-ITBlender-yellow"></a>
+ <a href="https://github.com/WonwoongCho/IT-Blender"><img src="https://img.shields.io/badge/GitHub-Code-blue.svg?logo=github&" alt="GitHub"></a>
+ </div>
+ """
+
+
+ def create_app():
+     with gr.Blocks() as app:
+         gr.Markdown(header, elem_id="header")
+         with gr.Row(equal_height=False):
+             with gr.Column(variant="panel", elem_classes="inputPanel"):
+                 original_image = gr.Image(
+                     type="pil", label="Condition Image", width=300, elem_id="input"
+                 )
+
+                 scale = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.5, label="Guidance Scale")
+                 seed = gr.Number(value=42, label="seed", precision=0)
+                 text = gr.Textbox(lines=2, label="Text Prompt", elem_id="text")
+
+                 submit_btn = gr.Button("Run", elem_id="submit_btn")
+
+             with gr.Column(variant="panel", elem_classes="outputPanel"):
+                 output_image = gr.Image(type="pil", elem_id="output")
+
+         with gr.Row():
+             examples = gr.Examples(
+                 examples=get_samples(),
+                 inputs=[original_image, scale, seed, text],
+                 label="Examples",
+             )
+
+         submit_btn.click(
+             fn=process_image_and_text,
+             inputs=[original_image, scale, seed, text],
+             outputs=output_image,
+         )
+
+     return app
+
+
+ if __name__ == "__main__":
+     create_app().launch(debug=True, ssr_mode=False)
assets/0.jpg ADDED
assets/1.jpg ADDED
assets/2.jpg ADDED
assets/character1.jpg ADDED
assets/character2.jpg ADDED
assets/character3.jpg ADDED
assets/graphic1.jpg ADDED
assets/product1.jpg ADDED
convert_png_to_jpg.py ADDED
@@ -0,0 +1,23 @@
+ from PIL import Image
+ import os
+
+ def convert_png_to_jpg(input_folder, output_folder=None, quality=85):
+     if output_folder is None:
+         output_folder = input_folder
+
+     os.makedirs(output_folder, exist_ok=True)
+
+     for filename in os.listdir(input_folder):
+         if filename.lower().endswith(".png"):
+             png_path = os.path.join(input_folder, filename)
+             jpg_name = os.path.splitext(filename)[0] + ".jpg"
+             jpg_path = os.path.join(output_folder, jpg_name)
+
+             with Image.open(png_path) as img:
+                 rgb_img = img.convert("RGB")  # remove alpha
+                 rgb_img.save(jpg_path, "JPEG", quality=quality)
+
+             print(f"Converted: {filename} → {jpg_name}")
+
+ # Example usage
+ if __name__ == "__main__":
+     convert_png_to_jpg("assets/")
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ transformers
+ protobuf
+ sentencepiece
+ accelerate
+ einops
+ huggingface_hub
+ git+https://github.com/WonwoongCho/diffusers@main#egg=diffusers
src/attention_processor.py ADDED
@@ -0,0 +1,87 @@
+ import torch
+ import torch.nn as nn
+ import math
+ import torch.nn.functional as F
+ from typing import Callable, List, Optional, Tuple, Union
+
+ class FluxBlendedAttnProcessor2_0(nn.Module):
+     """Blended attention processor for FLUX's single-stream transformer blocks.
+
+     On top of the block's standard self-attention, the noisy samples additionally
+     attend to clean reference features through trainable key/value projections.
+     """
+
+     def __init__(self, hidden_dim, ba_scale=1.0, num_ref=1, temperature=1.2):
+         super().__init__()
+         if not hasattr(F, "scaled_dot_product_attention"):
+             raise ImportError("FluxBlendedAttnProcessor2_0 requires PyTorch 2.0; please upgrade PyTorch to use it.")
+
+         self.blended_attention_k_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)
+         self.blended_attention_v_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)
+         self.ba_scale = ba_scale
+         self.num_ref = num_ref
+         self.temperature = temperature  # used only when num_ref > 1
+
+     def __call__(
+         self,
+         attn,  #: Attention
+         hidden_states: torch.FloatTensor,
+         encoder_hidden_states: torch.FloatTensor = None,
+         attention_mask: Optional[torch.FloatTensor] = None,
+         image_rotary_emb: Optional[torch.Tensor] = None,
+         is_negative_prompt: bool = False
+     ) -> torch.FloatTensor:
+         assert encoder_hidden_states is None, "encoder_hidden_states must be None: IT-Blender is applied only to the single-stream blocks."
+         batch_size, _, _ = hidden_states.shape
+
+         # `sample` projections.
+         query = attn.to_q(hidden_states)
+         key = attn.to_k(hidden_states)
+         value = attn.to_v(hidden_states)
+
+         inner_dim = key.shape[-1]
+         head_dim = inner_dim // attn.heads
+
+         query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+         key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+         value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+         # Keep the normalized, pre-RoPE query around: blended attention reuses it below.
+         normalized_query = attn.norm_q(query) if attn.norm_q is not None else query
+         if attn.norm_k is not None:
+             key = attn.norm_k(key)
+
+         query = normalized_query
+         if image_rotary_emb is not None:
+             from diffusers.models.embeddings import apply_rotary_emb
+
+             query = apply_rotary_emb(query, image_rotary_emb)
+             key = apply_rotary_emb(key, image_rotary_emb)
+
+
+         hidden_states = F.scaled_dot_product_attention(
+             query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+         )
+
+         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+         hidden_states = hidden_states.to(query.dtype)
+
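+         # Blended attention: the noisy half of the batch attends to the clean
+         # reference features through the trainable k/v projections; the result is
+         # added back as a scaled residual on the noisy samples only.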
+         # [noisy, clean]
+         chunk = batch_size // (1 + self.num_ref)
+         ba_query = normalized_query[:chunk]  # noisy query
+
+         ba_key = self.blended_attention_k_proj(hidden_states[chunk:])    # clean key
+         ba_value = self.blended_attention_v_proj(hidden_states[chunk:])  # clean value
+
+         # The -1 resolves to num_ref * seq_len, so every reference token is attendable.
+         ba_key = ba_key.view(chunk, -1, attn.heads, head_dim).transpose(1, 2)
+         ba_value = ba_value.view(chunk, -1, attn.heads, head_dim).transpose(1, 2)
+
+         # Standard 1/sqrt(d) attention scale, sharpened by `temperature` when several
+         # references compete for attention mass.
+         sdpa_scale = 1 / math.sqrt(ba_query.size(-1))
+         if self.num_ref > 1:
+             sdpa_scale = sdpa_scale * self.temperature
+         ba_hidden_states = F.scaled_dot_product_attention(
+             ba_query, ba_key, ba_value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False, scale=sdpa_scale
+         )
+
+         ba_hidden_states = ba_hidden_states.transpose(1, 2).reshape(chunk, -1, attn.heads * head_dim)
+         ba_hidden_states = ba_hidden_states.to(query.dtype)
+
+         # Zero-pad the reference entries so only the noisy samples receive the residual.
+         zero_tensor_list = [torch.zeros_like(ba_hidden_states)] * self.num_ref
+         ba_hidden_states = torch.cat([ba_hidden_states] + zero_tensor_list, dim=0)
+
+         hidden_states = hidden_states + self.ba_scale * ba_hidden_states
+
+         return hidden_states
src/utils_sample.py ADDED
@@ -0,0 +1,60 @@
+ import random
+ import numpy as np
+ from PIL import Image
+ import torch
+
+ def set_seed(seed: int):
+     """
+     Set the seed for reproducibility across different libraries and devices.
+
+     Args:
+         seed (int): The seed value to set.
+     """
+     # Set seed for Python's random module
+     random.seed(seed)
+
+     # Set seed for NumPy
+     np.random.seed(seed)
+
+     # Set seed for PyTorch CPU
+     torch.manual_seed(seed)
+
+     # Set seed for PyTorch GPU (if using CUDA)
+     torch.cuda.manual_seed(seed)
+     torch.cuda.manual_seed_all(seed)  # For multi-GPU setups
+
+     # Ensure deterministic results for CUDA operations (optional)
+     torch.backends.cudnn.deterministic = True
+     torch.backends.cudnn.benchmark = False
+
+
+ def resize_and_center_crop(image, target_size=512):
+     """Resize so the short side equals `target_size`, then center-crop to a square."""
+     w, h = image.size
+     scale = target_size / min(w, h)
+     new_w = int(w * scale)
+     new_h = int(h * scale)
+     image_resized = image.resize((new_w, new_h), Image.Resampling.LANCZOS)
+
+     left = (new_w - target_size) // 2
+     top = (new_h - target_size) // 2
+     right = left + target_size
+     bottom = top + target_size
+     image_cropped = image_resized.crop((left, top, right, bottom))
+
+     return image_cropped
+
+
+ def resize_and_add_margin(image, target_size=512, background_color=(255, 255, 255)):
+     """Resize so the long side equals `target_size`, then pad to a square with `background_color`."""
+     w, h = image.size
+     scale = target_size / max(w, h)
+     new_w = int(w * scale)
+     new_h = int(h * scale)
+     image_resized = image.resize((new_w, new_h), Image.Resampling.LANCZOS)
+
+     new_image = Image.new("RGB", (target_size, target_size), background_color)
+
+     left = (target_size - new_w) // 2
+     top = (target_size - new_h) // 2
+     new_image.paste(image_resized, (left, top))
+
+     return new_image