# IT-Blender / app.py
import os

import gradio as gr
import torch
import numpy as np
import spaces
from PIL import Image
from huggingface_hub import hf_hub_download
from diffusers import FluxPipeline

from src.attention_processor import FluxBlendedAttnProcessor2_0
from src.utils_sample import set_seed, resize_and_add_margin

dtype = torch.bfloat16
token = os.environ.get("HF_TOKEN")
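# FLUX.1-dev is a gated checkpoint on the Hugging Face Hub, so HF_TOKEN must
# grant access to black-forest-labs/FLUX.1-dev (e.g. set as a Space secret).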
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=dtype,
    token=token,
)
pipe = pipe.to("cuda")
@spaces.GPU
def process_image_and_text(image, scale, seed, text):
    set_seed(seed)
    # Pad/resize the reference image to the 512x512 resolution the model expects.
    image = resize_and_add_margin(image, target_size=512)
    image_list = [image]

    # Dynamically swap in blended-attention processors for the single transformer
    # blocks, using the user-specified scale; all other blocks keep their
    # original processors.
    blended_attn_procs = {}
    for name, _ in pipe.transformer.attn_processors.items():
        if "single" in name:
            processor = FluxBlendedAttnProcessor2_0(3072, ba_scale=float(scale), num_ref=1)
            processor = processor.to(device="cuda", dtype=dtype)
            blended_attn_procs[name] = processor
        else:
            blended_attn_procs[name] = pipe.transformer.attn_processors[name]
    pipe.transformer.set_attn_processor(blended_attn_procs)

    # Download the pretrained IT-Blender weights for FLUX.
    model_path = hf_hub_download(
        repo_id="WonwoongCho/IT-Blender",
        filename="FLUX/it-blender.bin",
        token=token,
    )
    pretrained_blended_attn_weights = torch.load(model_path, map_location=pipe._execution_device)

    # Remap checkpoint keys to the parameter names expected by the transformer's
    # state dict: the checkpoint prefixes each weight with a global block index,
    # so subtract a fixed offset to recover the single-block index, and keep the
    # k/v projection tag encoded in the key name.
    key_changed_blended_attn_weights = {}
    for key, value in pretrained_blended_attn_weights.items():
        block_idx = int(key.split(".")[0]) - 21
        k_or_v = key.split("_")[2]
        changed_key = f"single_transformer_blocks.{block_idx}.attn.processor.blended_attention_{k_or_v}_proj.weight"
        key_changed_blended_attn_weights[changed_key] = value.to(dtype)
    missing_keys, unexpected_keys = pipe.transformer.load_state_dict(
        key_changed_blended_attn_weights, strict=False
    )

    # Generate an image conditioned on both the text prompt and the reference image.
    out = pipe(
        prompt=text,
        height=512,
        width=512,
        max_sequence_length=256,
        generator=torch.Generator().manual_seed(seed),
        it_blender_image=image_list,
    ).images[0]
    return out
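# Minimal usage sketch (not executed): calling the function above directly,
# outside of Gradio. The asset path and prompt mirror the first entry in
# get_samples() below; any local reference image would work.
#
#   ref = Image.open("assets/0.jpg")
#   result = process_image_and_text(
#       ref, scale=0.6, seed=42,
#       text="A photo of a monster cartoon character, imaginative, creative, design",
#   )
#   result.save("it_blender_output.png")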
def get_samples():
    sample_list = [
        {
            "image": "assets/0.jpg",
            "scale": 0.6,
            "seed": 42,
            "text": "A photo of a monster cartoon character, imaginative, creative, design",
        },
        {
            "image": "assets/1.jpg",
            "scale": 0.6,
            "seed": 42,
            "text": "A photo of an owl cartoon character, imaginative, creative, design",
        },
        {
            "image": "assets/2.jpg",
            "scale": 0.6,
            "seed": 42,
            "text": "A photo of a monster cartoon character, imaginative, creative, design",
        },
        {
            "image": "assets/character1.jpg",
            "scale": 0.6,
            "seed": 42,
            "text": "A photo of a dragon, imaginative, creative, design",
        },
        {
            "image": "assets/character2.jpg",
            "scale": 0.6,
            "seed": 42,
            "text": "A photo of a dragon, imaginative, creative, design",
        },
        {
            "image": "assets/character3.jpg",
            "scale": 0.6,
            "seed": 42,
            "text": "A photo of a dragon, imaginative, creative, design",
        },
        {
            "image": "assets/graphic1.jpg",
            "scale": 0.7,
            "seed": 42,
            "text": "A photo of a woman, imaginative, creative, design",
        },
        {
            "image": "assets/sneakers1.jpg",
            "scale": 0.6,
            "seed": 42,
            "text": "A photo of sneakers, imaginative, creative, design",
        },
        {
            "image": "assets/product1.jpg",
            "scale": 0.8,
            "seed": 42,
            "text": "A photo of a motorcycle, imaginative, creative, design",
        },
        {
            "image": "assets/art1.jpg",
            "scale": 0.8,
            "seed": 42,
            "text": "A photo of Eiffel Tower, imaginative, creative, design",
        },
    ]
    return [
        [
            Image.open(sample["image"]).resize((512, 512)),
            sample["scale"],
            sample["seed"],
            sample["text"],
        ]
        for sample in sample_list
    ]
header = """
# 💡 IT-Blender / FLUX
<div style="text-align: center; display: flex; justify-content: left; gap: 5px;">
<a href="https://arxiv.org/pdf/2506.24085"><img src="https://img.shields.io/badge/ArXiv-Paper-A42C25.svg" alt="arXiv"></a>
<a href="https://imagineforme.github.io/"><img alt="Build" src="https://img.shields.io/badge/Project%20Page-ITBlender-yellow"></a>
<a href="https://github.com/WonwoongCho/IT-Blender"><img src="https://img.shields.io/badge/GitHub-Code-blue.svg?logo=github&" alt="GitHub"></a>
</div>
"""
def create_app():
    with gr.Blocks() as app:
        gr.Markdown(header, elem_id="header")
        with gr.Row(equal_height=False):
            with gr.Column(variant="panel", elem_classes="inputPanel"):
                original_image_input = gr.Image(
                    type="pil", label="Condition image", width=300, elem_id="input"
                )
                scale_input = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    step=0.01,
                    value=0.6,
                    label="Scale (recommended range: 0.5-0.8; higher values give the reference image a stronger effect)",
                )
                seed_input = gr.Number(value=42, label="Seed", precision=0)
                text_input = gr.Textbox(
                    lines=2,
                    label="Text prompt",
                    elem_id="text",
                )
                submit_btn = gr.Button("Run", elem_id="submit_btn")
            with gr.Column(variant="panel", elem_classes="outputPanel"):
                output_image = gr.Image(type="pil", elem_id="output")
        with gr.Row():
            examples = gr.Examples(
                examples=get_samples(),
                inputs=[original_image_input, scale_input, seed_input, text_input],
                label="Examples",
            )
        submit_btn.click(
            fn=process_image_and_text,
            inputs=[original_image_input, scale_input, seed_input, text_input],
            outputs=output_image,
        )
    return app
if __name__ == "__main__":
    demo = create_app()
    demo.launch(debug=True, ssr_mode=False)