Spaces:
Running
on
Zero
Running
on
Zero
| import gradio as gr | |
| import torch | |
| import numpy as np | |
| import spaces | |
| from PIL import Image | |
| from huggingface_hub import hf_hub_download | |
| from diffusers import FluxPipeline | |
| from src.attention_processor import FluxBlendedAttnProcessor2_0 | |
| from src.utils_sample import set_seed, resize_and_add_margin | |
| import os | |
# Model setup runs once at import time so every request reuses the same
# pipeline object.
dtype = torch.bfloat16

# Access token for the FLUX.1-dev repository; None when HF_TOKEN is not set.
token = os.getenv("HF_TOKEN")

# Load the FLUX pipeline in bfloat16 and move it to the GPU in one step.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=dtype, token=token
).to("cuda")
def process_image_and_text(image, text, seed, scale):
    """Generate one 512x512 image from a condition image and a text prompt.

    The attention processors of ``pipe.transformer`` whose name contains
    ``"single"`` are replaced by fresh ``FluxBlendedAttnProcessor2_0``
    instances built with the requested ``scale``; all other processors are
    kept as they are. The pipeline is then run with the prepared condition
    image passed as ``it_blender_image``.

    Args:
        image: Condition image from the UI (``None`` when nothing was uploaded).
        text: Text prompt forwarded to the pipeline.
        seed: Integer seed used for ``set_seed`` and the torch generator.
        scale: Value passed (as ``float``) to the processors' ``ba_scale``.

    Returns:
        The first image of the pipeline output.

    Raises:
        gr.Error: If no condition image was provided.
    """
    # NOTE(review): the file imports `spaces` but nothing visible here uses it;
    # confirm whether this handler is meant to be wrapped in `@spaces.GPU`.
    if image is None:
        # Surface a readable message in the UI instead of failing inside the
        # resize helper with an unrelated error.
        raise gr.Error("Please upload a condition image before running.")

    set_seed(seed)
    condition = resize_and_add_margin(image, target_size=512)

    # Read the processor mapping once: the loop below only needs a snapshot,
    # and re-reading `attn_processors` for every entry repeats the lookup.
    current_procs = pipe.transformer.attn_processors

    # NOTE(review): new processors are constructed on every call and no weight
    # loading is visible in this file; confirm trained weights are restored
    # elsewhere (e.g. inside the processor or the pipeline).
    blended_attn_procs = {}
    for name, existing in current_procs.items():
        if "single" in name:
            # 3072 is the first positional argument of the processor,
            # presumably the transformer hidden size; confirm against
            # src.attention_processor.
            blended = FluxBlendedAttnProcessor2_0(
                3072, ba_scale=float(scale), num_ref=1
            )
            blended_attn_procs[name] = blended.to(device="cuda", dtype=dtype)
        else:
            blended_attn_procs[name] = existing
    pipe.transformer.set_attn_processor(blended_attn_procs)

    result = pipe(
        prompt=text,
        height=512,
        width=512,
        max_sequence_length=256,
        generator=torch.Generator().manual_seed(seed),
        it_blender_image=[condition],
    )
    return result.images[0]
def get_samples():
    """Return the rows displayed in the demo's example gallery.

    Each row is a list ``[image, scale, seed, text]``; the image is opened
    from the ``assets/`` folder and resized to 512x512.
    """
    dragon_prompt = "A photo of a dragon, imaginative, creative, design"

    # (image path, scale, seed, prompt) for every bundled example.
    specs = (
        (
            "assets/0.jpg",
            0.6,
            42,
            "A photo of a monster cartoon character, imaginative, creative, design",
        ),
        (
            "assets/1.jpg",
            0.6,
            42,
            "A photo of an owl cartoon character, imaginative, creative, design",
        ),
        ("assets/2.jpg", 0.6, 42, dragon_prompt),
        ("assets/character1.jpg", 0.6, 42, dragon_prompt),
        ("assets/character2.jpg", 0.6, 42, dragon_prompt),
        ("assets/character3.jpg", 0.6, 42, dragon_prompt),
        (
            "assets/graphic1.jpg",
            0.7,
            42,
            "A photo of a woman, imaginative, creative, design",
        ),
        (
            "assets/product1.jpg",
            0.8,
            42,
            "A photo of a motorcycle, imaginative, creative, design",
        ),
    )

    rows = []
    for path, scale, seed, text in specs:
        preview = Image.open(path).resize((512, 512))
        rows.append([preview, scale, seed, text])
    return rows
# Markdown/HTML banner rendered at the top of the page by create_app():
# a title followed by badge links to the paper, project page and code.
# NOTE(review): confirm the arXiv link below points at the IT-Blender paper.
header = """
# 💡 IT-Blender / FLUX
<div style="text-align: center; display: flex; justify-content: left; gap: 5px;">
<a href="https://arxiv.org/abs/2411.15098"><img src="https://img.shields.io/badge/ArXiv-Paper-A42C25.svg" alt="arXiv"></a>
<a href="https://imagineforme.github.io/"><img alt="Build" src="https://img.shields.io/badge/Project%20Page-ITBlender-yellow"></a>
<a href="https://github.com/WonwoongCho/IT-Blender"><img src="https://img.shields.io/badge/GitHub-Code-blue.svg?logo=github&" alt="GitHub"></a>
</div>
"""
def create_app():
    """Build and return the Gradio Blocks UI for the demo.

    Layout: the header banner, a row with an input panel (condition image,
    scale slider, seed, prompt, Run button) next to an output panel, and a
    row of clickable examples. Pressing Run calls
    ``process_image_and_text`` and shows its result in the output image.

    Returns:
        The configured ``gr.Blocks`` app (not yet launched).
    """
    with gr.Blocks() as app:
        gr.Markdown(header, elem_id="header")

        with gr.Row(equal_height=False):
            with gr.Column(variant="panel", elem_classes="inputPanel"):
                original_image = gr.Image(
                    type="pil", label="Condition Image", width=300, elem_id="input"
                )
                scale = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    step=0.01,
                    value=0.6,
                    label="Guidance Scale",
                )
                seed = gr.Number(value=42, label="seed", precision=0)
                text = gr.Textbox(
                    lines=2,
                    label="Text Prompt",
                    value="A photo of a dragon, imaginative, creative, design",
                    elem_id="text",
                )
                submit_btn = gr.Button("Run", elem_id="submit_btn")

            with gr.Column(variant="panel", elem_classes="outputPanel"):
                output_image = gr.Image(type="pil", elem_id="output")

        with gr.Row():
            # Example rows come from get_samples() as
            # [image, scale, seed, text], so the components must be listed in
            # that same order; otherwise clicking an example puts the scale
            # into the prompt box and the prompt into the slider.
            gr.Examples(
                examples=get_samples(),
                inputs=[original_image, scale, seed, text],
                label="Examples",
            )

        # The handler signature is (image, text, seed, scale), which is a
        # different order from the example rows above.
        submit_btn.click(
            fn=process_image_and_text,
            inputs=[original_image, text, seed, scale],
            outputs=output_image,
        )

    return app
if __name__ == "__main__":
    # Script entry point: build the UI, then serve it with debug output
    # enabled and server-side rendering turned off.
    demo = create_app()
    demo.launch(ssr_mode=False, debug=True)