import os

import gradio as gr
import numpy as np
import spaces
import torch
from PIL import Image
from huggingface_hub import hf_hub_download

from diffusers import FluxPipeline
from src.attention_processor import FluxBlendedAttnProcessor2_0
from src.utils_sample import set_seed, resize_and_add_margin

dtype = torch.bfloat16
# FLUX.1-dev is a gated model; an HF_TOKEN with access is required to download it.
token = os.environ.get("HF_TOKEN")

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=dtype,
    token=token,
)
pipe = pipe.to("cuda")
@spaces.GPU  # Request ZeroGPU hardware for the duration of each call.
def process_image_and_text(image, scale, seed, text):
    set_seed(seed)
    # Square-pad and resize the condition image to the working resolution.
    image = resize_and_add_margin(image, target_size=512)
    image_list = [image]

    # Dynamically set attention processors using the user-specified scale.
    # Only the single-stream transformer blocks get the blended-attention
    # processor; all other blocks keep their original processors.
    blended_attn_procs = {}
    for name, _ in pipe.transformer.attn_processors.items():
        if "single" in name:
            # 3072 is the FLUX.1 transformer hidden size.
            processor = FluxBlendedAttnProcessor2_0(3072, ba_scale=float(scale), num_ref=1)
            processor = processor.to(device="cuda", dtype=dtype)
            blended_attn_procs[name] = processor
        else:
            blended_attn_procs[name] = pipe.transformer.attn_processors[name]
    pipe.transformer.set_attn_processor(blended_attn_procs)
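
    # Rough sketch of what the blended processor is assumed to do (illustrative
    # only; see src/attention_processor.py for the real implementation): the
    # reference image's tokens go through dedicated K/V projections (the
    # blended_attention_{k,v}_proj weights loaded below), and their attention
    # output is mixed into the base output with weight ba_scale, roughly:
    #     out = attn(q, kv_base) + ba_scale * attn(q, kv_ref)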
    # Fetch the pretrained IT-Blender attention weights. hf_hub_download caches
    # the file locally, so only the first call hits the network.
    model_path = hf_hub_download(
        repo_id="WonwoongCho/IT-Blender",
        filename="FLUX/it-blender.bin",
        token=token,
    )
    pretrained_blended_attn_weights = torch.load(model_path, map_location=pipe._execution_device)

    # Remap checkpoint keys onto the parameter names of the blended-attention
    # processors installed above. The offset of 21 aligns the checkpoint's block
    # numbering with single_transformer_blocks starting at index 0.
    key_changed_blended_attn_weights = {}
    for key, value in pretrained_blended_attn_weights.items():
        block_idx = int(key.split(".")[0]) - 21
        k_or_v = key.split("_")[2]  # "k" or "v"
        changed_key = f"single_transformer_blocks.{block_idx}.attn.processor.blended_attention_{k_or_v}_proj.weight"
        key_changed_blended_attn_weights[changed_key] = value.to(dtype)
    missing_keys, unexpected_keys = pipe.transformer.load_state_dict(
        key_changed_blended_attn_weights, strict=False
    )
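
    # Worked example of the key remapping (the checkpoint key format is an
    # assumption inferred from the split() logic above): a key such as
    # "21.blended_attention_k_proj.weight" gives block_idx = 21 - 21 = 0 and
    # k_or_v = "k", so it is loaded into
    # "single_transformer_blocks.0.attn.processor.blended_attention_k_proj.weight".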
    # `it_blender_image` is not a standard FluxPipeline argument; it is presumably
    # handled by the patched pipeline shipped with this Space.
    out = pipe(
        prompt=text,
        height=512,
        width=512,
        max_sequence_length=256,
        generator=torch.Generator().manual_seed(seed),
        it_blender_image=image_list,
    ).images[0]
    return out
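
# Usage sketch (hypothetical paths, run outside the Gradio UI):
#     img = Image.open("assets/0.jpg")
#     result = process_image_and_text(
#         img, scale=0.6, seed=42,
#         text="A photo of a monster cartoon character, imaginative, creative, design",
#     )
#     result.save("blended.png")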
def get_samples():
    sample_list = [
        {
            "image": "assets/0.jpg",
            "scale": 0.6,
            "seed": 42,
            "text": "A photo of a monster cartoon character, imaginative, creative, design",
        },
        {
            "image": "assets/1.jpg",
            "scale": 0.6,
            "seed": 42,
            "text": "A photo of an owl cartoon character, imaginative, creative, design",
        },
        {
            "image": "assets/2.jpg",
            "scale": 0.6,
            "seed": 42,
            "text": "A photo of a monster cartoon character, imaginative, creative, design",
        },
        {
            "image": "assets/character1.jpg",
            "scale": 0.6,
            "seed": 42,
            "text": "A photo of a dragon, imaginative, creative, design",
        },
        {
            "image": "assets/character2.jpg",
            "scale": 0.6,
            "seed": 42,
            "text": "A photo of a dragon, imaginative, creative, design",
        },
        {
            "image": "assets/character3.jpg",
            "scale": 0.6,
            "seed": 42,
            "text": "A photo of a dragon, imaginative, creative, design",
        },
        {
            "image": "assets/graphic1.jpg",
            "scale": 0.7,
            "seed": 42,
            "text": "A photo of a woman, imaginative, creative, design",
        },
        {
            "image": "assets/sneakers1.jpg",
            "scale": 0.6,
            "seed": 42,
            "text": "A photo of sneakers, imaginative, creative, design",
        },
        {
            "image": "assets/product1.jpg",
            "scale": 0.8,
            "seed": 42,
            "text": "A photo of a motorcycle, imaginative, creative, design",
        },
        {
            "image": "assets/art1.jpg",
            "scale": 0.8,
            "seed": 42,
            "text": "A photo of Eiffel Tower, imaginative, creative, design",
        },
    ]
    return [
        [
            Image.open(sample["image"]).resize((512, 512)),
            sample["scale"],
            sample["seed"],
            sample["text"],
        ]
        for sample in sample_list
    ]
header = """
# 💡 IT-Blender / FLUX
<div style="text-align: center; display: flex; justify-content: left; gap: 5px;">
<a href="https://arxiv.org/pdf/2506.24085"><img src="https://img.shields.io/badge/ArXiv-Paper-A42C25.svg" alt="arXiv"></a>
<a href="https://imagineforme.github.io/"><img alt="Build" src="https://img.shields.io/badge/Project%20Page-ITBlender-yellow"></a>
<a href="https://github.com/WonwoongCho/IT-Blender"><img src="https://img.shields.io/badge/GitHub-Code-blue.svg?logo=github&" alt="GitHub"></a>
</div>
"""
def create_app():
    with gr.Blocks() as app:
        gr.Markdown(header, elem_id="header")
        with gr.Row(equal_height=False):
            with gr.Column(variant="panel", elem_classes="inputPanel"):
                original_image_input = gr.Image(
                    type="pil", label="Condition image", width=300, elem_id="input"
                )
                scale_input = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    step=0.01,
                    value=0.6,
                    label="Scale (recommended range: 0.5-0.8; higher values strengthen the effect of the reference image)",
                )
                seed_input = gr.Number(value=42, label="Seed", precision=0)
                text_input = gr.Textbox(
                    lines=2,
                    label="Text prompt",
                    elem_id="text",
                )
                submit_btn = gr.Button("Run", elem_id="submit_btn")

            with gr.Column(variant="panel", elem_classes="outputPanel"):
                output_image = gr.Image(type="pil", elem_id="output")

        with gr.Row():
            examples = gr.Examples(
                examples=get_samples(),
                inputs=[original_image_input, scale_input, seed_input, text_input],
                label="Examples",
            )

        submit_btn.click(
            fn=process_image_and_text,
            inputs=[original_image_input, scale_input, seed_input, text_input],
            outputs=output_image,
        )

    return app


if __name__ == "__main__":
    demo = create_app()
    demo.launch(debug=True, ssr_mode=False)