# IT-Blender / app.py
# Last commit by WonwoongCho (ca38081): "added arxiv link and added sneakers example"
# File size: 6.54 kB
import gradio as gr
import torch
import numpy as np
import spaces
from PIL import Image
from huggingface_hub import hf_hub_download
from diffusers import FluxPipeline
from src.attention_processor import FluxBlendedAttnProcessor2_0
from src.utils_sample import set_seed, resize_and_add_margin
import os
# Module-level model setup: these names are read by the functions below.
# All pipeline weights and the blended-attention processors use bfloat16.
dtype = torch.bfloat16

# Hugging Face access token taken from the environment (None when unset).
token = os.environ.get("HF_TOKEN")

# Load the FLUX.1-dev pipeline once at import time and move it to the GPU.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=dtype, token=token
).to("cuda")
def _install_blended_attn_processors(scale):
    """Replace every "single" attention processor of the transformer with a blended one."""
    # Read the processor mapping once and reuse its values, instead of
    # re-reading `pipe.transformer.attn_processors` for each kept entry.
    existing_procs = pipe.transformer.attn_processors
    blended_attn_procs = {}
    for name, existing in existing_procs.items():
        if "single" in name:
            # NOTE(review): 3072 is passed as the first positional argument,
            # presumably the hidden size; confirm against
            # FluxBlendedAttnProcessor2_0.
            processor = FluxBlendedAttnProcessor2_0(3072, ba_scale=float(scale), num_ref=1)
            blended_attn_procs[name] = processor.to(device="cuda", dtype=dtype)
        else:
            # Processors whose name does not contain "single" are kept as-is.
            blended_attn_procs[name] = existing
    pipe.transformer.set_attn_processor(blended_attn_procs)


def _load_blended_attn_weights():
    """Download the IT-Blender checkpoint and load it into the freshly installed processors."""
    import warnings

    model_path = hf_hub_download(
        repo_id="WonwoongCho/IT-Blender",
        filename="FLUX/it-blender.bin",
        token=token,
    )
    # NOTE(review): torch.load unpickles the file. The checkpoint comes from a
    # fixed repository; consider `weights_only=True` if the installed torch
    # version supports it.
    pretrained_blended_attn_weights = torch.load(model_path, map_location=pipe._execution_device)

    # Rename checkpoint keys to the transformer's parameter names: the leading
    # integer of each key, minus 21, becomes the single-block index, and the
    # third "_"-separated token selects the k or v projection.
    key_changed_blended_attn_weights = {}
    for key, value in pretrained_blended_attn_weights.items():
        block_idx = int(key.split(".")[0]) - 21
        k_or_v = key.split("_")[2]
        changed_key = f'single_transformer_blocks.{block_idx}.attn.processor.blended_attention_{k_or_v}_proj.weight'
        key_changed_blended_attn_weights[changed_key] = value.to(dtype)

    # strict=False because only the blended-attention projections are loaded,
    # so missing keys are expected. Unexpected keys mean a remapped name
    # matched no parameter and those weights were not applied.
    _missing_keys, unexpected_keys = pipe.transformer.load_state_dict(
        key_changed_blended_attn_weights, strict=False
    )
    if unexpected_keys:
        warnings.warn(
            f"IT-Blender weights not applied for {len(unexpected_keys)} key(s), "
            f"e.g. {unexpected_keys[0]!r}"
        )


@spaces.GPU
def process_image_and_text(image, scale, seed, text):
    """Generate one 512x512 image from a text prompt blended with a condition image.

    Args:
        image: Condition image from the "Condition image" input (PIL image),
            or None when nothing was uploaded.
        scale: Blending strength forwarded as ``ba_scale`` to the blended
            attention processors (converted with ``float``).
        seed: Integer seed used for ``set_seed`` and the torch generator.
        text: Text prompt passed to the pipeline.

    Returns:
        The first image produced by the pipeline.

    Raises:
        gr.Error: If no condition image was provided.
    """
    # Fail early with a readable UI message instead of an error raised from
    # inside the preprocessing helper.
    if image is None:
        raise gr.Error("Please provide a condition image before running.")

    set_seed(seed)
    image = resize_and_add_margin(image, target_size=512)
    image_list = [image]

    # The processors are rebuilt on every call because the scale is chosen by
    # the user; their weights must therefore be reloaded afterwards.
    _install_blended_attn_processors(scale)
    _load_blended_attn_weights()

    out = pipe(
        prompt=text,
        height=512,
        width=512,
        max_sequence_length=256,
        generator=torch.Generator().manual_seed(seed),
        it_blender_image=image_list,
    ).images[0]
    return out
def get_samples():
    """Build the example rows shown in the UI.

    Returns:
        A list of ``[image, scale, seed, text]`` rows, where ``image`` is the
        example file opened with PIL and resized to 512x512.
    """
    # One entry per example: (image path, scale, seed, text prompt).
    example_specs = (
        ("assets/0.jpg", 0.6, 42, "A photo of a monster cartoon character, imaginative, creative, design"),
        ("assets/1.jpg", 0.6, 42, "A photo of an owl cartoon character, imaginative, creative, design"),
        ("assets/2.jpg", 0.6, 42, "A photo of a dragon, imaginative, creative, design"),
        ("assets/character1.jpg", 0.6, 42, "A photo of a dragon, imaginative, creative, design"),
        ("assets/character2.jpg", 0.6, 42, "A photo of a dragon, imaginative, creative, design"),
        ("assets/character3.jpg", 0.6, 42, "A photo of a dragon, imaginative, creative, design"),
        ("assets/graphic1.jpg", 0.7, 42, "A photo of a woman, imaginative, creative, design"),
        ("assets/sneakers1.jpg", 0.6, 42, "A photo of sneakers, imaginative, creative, design"),
        ("assets/product1.jpg", 0.8, 42, "A photo of a motorcycle, imaginative, creative, design"),
        ("assets/art1.jpg", 0.8, 42, "A photo of Eiffel Tower, imaginative, creative, design"),
    )

    rows = []
    for path, scale, seed, prompt in example_specs:
        preview = Image.open(path).resize((512, 512))
        rows.append([preview, scale, seed, prompt])
    return rows
# Markdown/HTML shown at the top of the page (passed to gr.Markdown in
# create_app): the title followed by badge links to the arXiv paper, the
# project page and the GitHub repository.
header = """
# 💡 IT-Blender / FLUX
<div style="text-align: center; display: flex; justify-content: left; gap: 5px;">
<a href="https://arxiv.org/pdf/2506.24085"><img src="https://img.shields.io/badge/ArXiv-Paper-A42C25.svg" alt="arXiv"></a>
<a href="https://imagineforme.github.io/"><img alt="Build" src="https://img.shields.io/badge/Project%20Page-ITBlender-yellow"></a>
<a href="https://github.com/WonwoongCho/IT-Blender"><img src="https://img.shields.io/badge/GitHub-Code-blue.svg?logo=github&" alt="GitHub"></a>
</div>
"""
def create_app():
    """Assemble the Gradio Blocks interface and return it without launching.

    Layout: the header, then a row with an input panel (condition image,
    scale, seed, prompt, run button) next to an output panel, then a row of
    examples. Clicking the run button calls ``process_image_and_text``.
    """
    with gr.Blocks() as app:
        gr.Markdown(header, elem_id="header")

        with gr.Row(equal_height=False):
            # Left panel: everything the user provides.
            with gr.Column(variant="panel", elem_classes="inputPanel"):
                condition_image = gr.Image(
                    type="pil",
                    label="Condition image",
                    width=300,
                    elem_id="input",
                )
                blend_scale = gr.Slider(
                    label="Scale (recommended range: 0.5-0.8; the higher, the stronger effect of the reference image)",
                    value=0.6,
                    minimum=0.0,
                    maximum=1.0,
                    step=0.01,
                )
                seed_box = gr.Number(label="Seed", value=42, precision=0)
                prompt_box = gr.Textbox(
                    label="Text prompt",
                    value="A photo of a dragon, imaginative, creative, design",
                    lines=2,
                    elem_id="text",
                )
                run_button = gr.Button("Run", elem_id="submit_btn")

            # Right panel: the generated image.
            with gr.Column(variant="panel", elem_classes="outputPanel"):
                result_image = gr.Image(type="pil", elem_id="output")

        # The examples and the run button feed the same four inputs, in the
        # order expected by process_image_and_text.
        controls = [condition_image, blend_scale, seed_box, prompt_box]

        with gr.Row():
            gr.Examples(
                examples=get_samples(),
                inputs=controls,
                label="Examples",
            )

        run_button.click(
            fn=process_image_and_text,
            inputs=controls,
            outputs=result_image,
        )

    return app
if __name__ == "__main__":
    # Build the UI and serve it only when this file is run as a script.
    demo = create_app()
    demo.launch(ssr_mode=False, debug=True)