| import os |
| import gradio as gr |
| from gradio_client import Client, handle_file |
| import torch |
| import spaces |
| from diffusers import FluxPipeline |
| from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
# Choose bfloat16 when a CUDA device is visible and fall back to float32 otherwise.
# NOTE(review): nothing in this file appears to read this value; the pipeline
# is loaded in bfloat16 regardless — confirm whether it is still needed.
torch_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
|
|
|
|
| |
| |
|
|
| |
| |
| |
|
|
| |
def load_models(device=None, dtype=torch.bfloat16):
    """Load the LeX-FLUX text-to-image pipeline and place it on a device.

    Args:
        device: Target device string. When ``None`` (the default) it is
            ``"cuda"`` if ``torch.cuda.is_available()`` is true, else ``"cpu"``.
        dtype: Weight dtype forwarded to ``from_pretrained`` as
            ``torch_dtype``. Defaults to ``torch.bfloat16`` on every host.

    Returns:
        The loaded ``FluxPipeline`` instance.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    pipe = FluxPipeline.from_pretrained(
        "X-ART/LeX-FLUX",
        torch_dtype=dtype,
    )
    # Use the detected/requested device instead of a hard-coded "cuda", which
    # raised on machines without a CUDA device.
    pipe.to(device)

    return pipe
|
|
def prompt_enhance(client, image_caption, text_caption):
    """Ask the remote enhancer endpoint to rewrite the two captions.

    Args:
        client: Object exposing ``predict(*inputs, api_name=...)``
            (a ``gradio_client.Client`` in this app).
        image_caption: Description of the visual content.
        text_caption: Description of the text to render.

    Returns:
        A ``(combined_caption, enhanced_caption)`` tuple unpacked from the
        two-element result of ``client.predict``; any other result length
        raises ``ValueError`` during unpacking.
    """
    prediction = client.predict(
        image_caption,
        text_caption,
        api_name="/generate_enhanced_caption",
    )
    combined, enhanced = prediction
    return combined, enhanced
| |
|
|
# Build the LeX-FLUX pipeline once at import time; generate_image reads this
# module-level `pipe`, so every request reuses the same loaded weights.
pipe = load_models()
|
|
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
# NOTE(review): with up to 100 steps plus model CPU offload, 60 s of GPU time
# may be tight for the upper end of the slider — confirm on the target hardware.
@spaces.GPU(duration=60)
def generate_image(enhanced_caption, seed, num_inference_steps, guidance_scale, height=1024, width=1024):
    """Generate one image from ``enhanced_caption`` with LeX-FLUX.

    Args:
        enhanced_caption: Prompt text passed to the pipeline.
        seed: RNG seed. ``0`` means "random": no generator is passed, so
            each call draws fresh noise. Any other value gives a reproducible
            CPU ``torch.Generator``.
        num_inference_steps: Number of denoising steps (cast to ``int``).
        guidance_scale: Guidance value forwarded to the pipeline (cast to
            ``float``).
        height: Output height in pixels (default 1024).
        width: Output width in pixels (default 1024).

    Returns:
        The first generated image as a PIL image (``output_type="pil"``).
    """
    # Kept from the original: offload is (re)enabled on every call, inside the
    # GPU-decorated function.
    pipe.enable_model_cpu_offload()

    # UI sliders may deliver floats; manual_seed and the step count need ints.
    seed = int(seed)
    generator = torch.Generator("cpu").manual_seed(seed) if seed != 0 else None

    # Pass the caller's settings through. Previously steps, guidance and seed
    # were hard-coded, so the UI controls had no effect.
    result = pipe(
        enhanced_caption,
        height=height,
        width=width,
        guidance_scale=float(guidance_scale),
        output_type="pil",
        num_inference_steps=int(num_inference_steps),
        max_sequence_length=512,
        generator=generator,
    )
    return result.images[0]
|
|
| |
def run_pipeline(image_caption, text_caption, seed, num_inference_steps, guidance_scale, enable_enhancer):
    """Turn the user's image/text captions into a generated image.

    With ``enable_enhancer`` truthy, both captions go to the remote
    ``stzhao/LeX-Enhancer`` Space. It returns a combined caption and an
    enhanced caption, and the enhanced one drives generation. Otherwise the
    captions are merged locally with a fixed template, and that single string
    is used both as the combined and as the final caption.

    Returns:
        ``(image, combined_caption, final_caption)``.
    """
    combined = f"{image_caption}, with the text on it: {text_caption}."

    if not enable_enhancer:
        prompt_text = combined
    else:
        # A new client is opened for every request.
        enhancer_client = Client("stzhao/LeX-Enhancer")
        combined, prompt_text = prompt_enhance(enhancer_client, image_caption, text_caption)
        print(f"enhanced caption:\n{prompt_text}")

    generated = generate_image(prompt_text, seed, num_inference_steps, guidance_scale)
    return generated, combined, prompt_text
|
|
| |
# Gradio UI: caption inputs and advanced settings on the left, generated image
# and captions on the right, example inputs underneath.
with gr.Blocks() as demo:

    gr.Markdown("# LeX-Enhancer & LeX-FLUX Demo")
    gr.Markdown("## Project Page: https://zhaoshitian.github.io/lexart/")
    gr.Markdown("Generate enhanced captions from simple image and text descriptions, then create images with LeX-FLUX")

    with gr.Row():
        with gr.Column():
            image_caption = gr.Textbox(
                lines=2,
                label="Image Caption",
                placeholder="Describe the visual content of the image",
                value="A picture of a group of people gathered in front of a world map"
            )
            text_caption = gr.Textbox(
                lines=2,
                label="Text Caption",
                placeholder="Describe any text that should appear in the image",
                value="\"Communicate\" in purple, \"Execute\" in yellow"
            )

            with gr.Accordion("Advanced Settings", open=False):
                enable_enhancer = gr.Checkbox(
                    label="Enable LeX-Enhancer",
                    value=True,
                    info="When enabled, the caption will be enhanced before image generation"
                )
                seed = gr.Slider(
                    minimum=0,
                    maximum=100000,
                    value=0,
                    step=1,
                    label="Seed (0 for random)"
                )
                # Defaults match the settings the pipeline call previously
                # hard-coded (28 steps, guidance 3.5), the usual FLUX-dev
                # values, so out-of-the-box results stay the same.
                num_inference_steps = gr.Slider(
                    minimum=20,
                    maximum=100,
                    value=28,
                    step=1,
                    label="Number of Inference Steps"
                )
                guidance_scale = gr.Slider(
                    minimum=1.0,
                    maximum=10.0,
                    value=3.5,
                    step=0.1,
                    label="Guidance Scale"
                )

            submit_btn = gr.Button("Generate", variant="primary")

        with gr.Column():
            output_image = gr.Image(label="Generated Image")
            combined_caption_box = gr.Textbox(
                label="Combined Caption",
                interactive=False
            )
            # Initial label follows the checkbox's initial value; the .change
            # handler below keeps it in sync afterwards.
            enhanced_caption_box = gr.Textbox(
                label="Enhanced Caption" if enable_enhancer.value else "Final Caption",
                interactive=False,
                lines=5
            )

    examples = [
        ["A modern office workspace", "\"Innovation\" in bold blue letters at the center"],
        ["A beach sunset scene", "\"Relax\" in cursive white text in the corner"],
        ["A futuristic city skyline", "\"The Future is Now\" in neon pink glowing letters"]
    ]
    gr.Examples(
        examples=examples,
        inputs=[image_caption, text_caption],
        label="Example Inputs"
    )

    def update_caption_label(enable_enhancer):
        """Relabel the caption output box to match the enhancer checkbox."""
        return gr.Textbox(label="Enhanced Caption" if enable_enhancer else "Final Caption")

    enable_enhancer.change(
        fn=update_caption_label,
        inputs=enable_enhancer,
        outputs=enhanced_caption_box
    )

    submit_btn.click(
        fn=run_pipeline,
        inputs=[image_caption, text_caption, seed, num_inference_steps, guidance_scale, enable_enhancer],
        outputs=[output_image, combined_caption_box, enhanced_caption_box]
    )
|
|
| |
|
|
# Start the Gradio server only when this file is executed directly.
if __name__ == "__main__":
    # debug=True: per Gradio's launch() documentation, this blocks the main
    # thread and surfaces errors in the console.
    demo.launch(debug=True)