| import gradio as gr |
| import requests |
| from PIL import Image |
| import io |
| import os |
| from fal_client import submit |
|
|
def set_fal_key(api_key):
    """Write *api_key* into the process environment as ``FAL_KEY``.

    The value is stored process-wide (``os.environ``), so it stays in effect
    for every later call made by this process.

    Returns:
        A fixed confirmation message.
    """
    env_name = "FAL_KEY"
    os.environ[env_name] = api_key
    confirmation = "FAL API key set successfully!"
    return confirmation
|
|
def generate_image(
    api_key,
    model,
    prompt,
    image_size,
    num_inference_steps,
    guidance_scale,
    num_images,
    safety_tolerance,
    enable_safety_checker,
    seed,
):
    """Generate images with a FLUX.1 model hosted on fal.ai.

    Args:
        api_key: fal.ai API key. When blank, it is NOT written to the
            environment, so a ``FAL_KEY`` already configured on the server
            (if any) keeps being used instead of being cleared.
        model: "Flux Pro" or "Flux Dev"; any other value selects Flux Schnell.
        prompt: Text prompt.
        image_size: Image-size preset name, e.g. "landscape_4_3".
        num_inference_steps: Number of inference steps (coerced to int).
        guidance_scale: Guidance scale; sent only for Flux Pro and Flux Dev.
        num_images: Number of images to request (coerced to int).
        safety_tolerance: Safety tolerance; sent only for Flux Pro.
        enable_safety_checker: Sent only for Flux Dev and Flux Schnell.
        seed: Fixed seed. ``None`` or any negative value (conventionally -1)
            omits the seed from the request.

    Returns:
        A list of PIL images. If anything fails during submission or download,
        the exception is logged and a single 512x512 black placeholder is
        returned so the gallery still renders.
    """
    import logging

    # NOTE(review): the key is stored in the process environment, so
    # concurrent users of one server share whichever key was set last.
    # Consider a per-request client if this app is multi-user.
    if api_key:
        set_fal_key(api_key)

    arguments = {
        "prompt": prompt,
        "image_size": image_size,
        # Gradio sliders/numbers may deliver floats; the endpoints presumably
        # validate these as integers — send ints to be safe.
        "num_inference_steps": int(num_inference_steps),
        "num_images": int(num_images),
    }

    # Each endpoint accepts a different subset of the tuning parameters.
    if model == "Flux Pro":
        arguments["guidance_scale"] = guidance_scale
        arguments["safety_tolerance"] = safety_tolerance
        fal_model = "fal-ai/flux-pro"
    elif model == "Flux Dev":
        arguments["guidance_scale"] = guidance_scale
        arguments["enable_safety_checker"] = enable_safety_checker
        fal_model = "fal-ai/flux/dev"
    else:
        arguments["enable_safety_checker"] = enable_safety_checker
        fal_model = "fal-ai/flux/schnell"

    # A cleared gr.Number yields None; without this guard it was sent as null.
    if seed is not None and int(seed) >= 0:
        arguments["seed"] = int(seed)

    try:
        handler = submit(fal_model, arguments=arguments)
        result = handler.get()
        images = []
        for img_info in result["images"]:
            # Bounded wait, and fail loudly on HTTP errors instead of handing
            # an error page to PIL.
            img_response = requests.get(img_info["url"], timeout=60)
            img_response.raise_for_status()
            images.append(Image.open(io.BytesIO(img_response.content)))
        return images
    except Exception:
        # Deliberate best-effort UI fallback, but no longer silent.
        logging.getLogger(__name__).exception(
            "Image generation with %s failed", fal_model
        )
        return [Image.new('RGB', (512, 512), color='black')]
|
|
def update_visible_components(model):
    """Build UI updates for the model-specific controls when *model* changes.

    The four returned updates target, in order: the inference-steps slider,
    the guidance-scale slider, the safety-tolerance dropdown and the
    safety-checker checkbox. In the per-model tuples below, ``None`` means
    "hide this control"; any other entry shows the control and resets it to
    that value.
    """
    if model == "Flux Pro":
        defaults = (28, 3.5, "2", None)
    elif model == "Flux Dev":
        defaults = (28, 3.5, None, True)
    else:
        # Flux Schnell: fewer steps, no guidance scale, no tolerance level.
        defaults = (4, None, None, True)

    return [
        gr.update(visible=False) if default is None
        else gr.update(visible=True, value=default)
        for default in defaults
    ]
|
|
# Build the Gradio UI: inputs on the left, result gallery on the right.
with gr.Blocks(theme='bethecloud/storj_theme') as demo:
    gr.HTML("""
    <h1 align="center">FLUX.1 Image Generation</h1>
    <p align="center">
    <a href="https://blackforestlabs.ai/" target="_blank">[Black Forest Labs]</a>
    <a href="https://blackforestlabs.ai/announcing-black-forest-labs/" target="_blank">[Blog]</a>
    <a href="https://fal.ai/models/fal-ai/flux-pro" target="_blank">[FLUX.1 [pro] Model FAL]</a>
    <a href="https://fal.ai/dashboard/keys" target="_blank">[GET YOUR API KEY HERE]</a>
    </p>
    """)

    with gr.Row():
        with gr.Column(scale=1):
            api_key = gr.Textbox(type="password", label="FAL API Key")
            model = gr.Dropdown(
                label="Model",
                choices=["Flux Pro", "Flux Dev", "Flux Schnell"],
                value="Flux Pro",
            )
            prompt = gr.Textbox(label="Prompt", lines=3, placeholder="Add your prompt here")
            image_size = gr.Dropdown(
                choices=["square_hd", "square", "portrait_4_3", "portrait_16_9", "landscape_4_3", "landscape_16_9"],
                label="Image Size",
                value="landscape_4_3",
            )

            with gr.Accordion("Advanced settings", open=False):
                num_inference_steps = gr.Slider(1, 100, 28, step=1, label="Number of Inference Steps")
                guidance_scale = gr.Slider(0, 20, 3.5, step=0.1, label="Guidance Scale")
                num_images = gr.Slider(1, 10, 1, step=1, label="Number of Images")
                safety_tolerance = gr.Dropdown(choices=["1", "2", "3", "4", "5", "6"], label="Safety Tolerance", value="2")
                # Hidden initially: the default model ("Flux Pro") does not use
                # this setting, and model.change does not fire on page load.
                enable_safety_checker = gr.Checkbox(label="Enable Safety Checker", value=True, visible=False)
                # precision=0 makes Gradio return an int rather than a float;
                # -1 means "no fixed seed".
                seed = gr.Number(label="Seed", value=-1, precision=0)

            generate_btn = gr.Button("Generate Image")

        with gr.Column(scale=1):
            output_gallery = gr.Gallery(label="Generated Images", elem_id="gallery", show_label=False)

    # Show/hide and reset the controls that only apply to some models.
    model.change(
        update_visible_components,
        inputs=[model],
        outputs=[num_inference_steps, guidance_scale, safety_tolerance, enable_safety_checker],
    )

    generate_btn.click(
        fn=generate_image,
        inputs=[
            api_key, model, prompt, image_size, num_inference_steps,
            guidance_scale, num_images, safety_tolerance, enable_safety_checker, seed,
        ],
        outputs=[output_gallery],
    )


# Only start the server when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()