flux-lora / app.py
multimodalart's picture
multimodalart HF Staff
Update app.py
701c4ad verified
raw
history blame
4.13 kB
import gradio as gr
import json
import logging
import torch
from PIL import Image
import spaces
from diffusers import DiffusionPipeline
# Load the LoRA catalogue. Entries are read elsewhere in this file through the
# keys "image", "title", "repo", "trigger_word" and, optionally, "weights".
# JSON text is UTF-8, so decode it explicitly instead of depending on the
# platform's locale default.
with open("loras.json", "r", encoding="utf-8") as f:
    loras = json.load(f)

# Initialize the base model once at import time; every request reuses `pipe`.
base_model = "black-forest-labs/FLUX.1-dev"
pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=torch.bfloat16)
pipe.to("cuda")
def update_selection(evt: gr.SelectData):
    """Handle a click on a gallery tile.

    Looks up the clicked entry in ``loras`` by ``evt.index`` and returns a
    3-tuple matching the outputs wired to ``gallery.select``:

    1. a ``gr.update`` that swaps the prompt box placeholder so it names the
       chosen LoRA's title,
    2. a Markdown string linking to the LoRA's Hugging Face repository,
    3. the selected index itself.
    """
    entry = loras[evt.index]
    placeholder_text = f"Type a prompt for {entry['title']}"
    repo_id = entry["repo"]
    selection_markdown = f"### Selected: [{repo_id}](https://huggingface.co/{repo_id}) ✨"
    return gr.update(placeholder=placeholder_text), selection_markdown, evt.index
@spaces.GPU(duration=90)
def run_lora(prompt, cfg_scale, steps, selected_index, seed, width, height, lora_scale, progress=gr.Progress(track_tqdm=True)):
    """Generate one image with the selected LoRA applied to the shared pipeline.

    Args:
        prompt: User prompt; the LoRA's trigger word is appended to it.
        cfg_scale: Forwarded to the pipeline as ``guidance_scale``.
        steps: Forwarded to the pipeline as ``num_inference_steps``.
        selected_index: Index into ``loras`` of the chosen LoRA, or ``None``
            when nothing has been selected yet.
        seed: Seed for the CUDA random generator handed to the pipeline.
        width: Requested image width.
        height: Requested image height.
        lora_scale: LoRA strength, forwarded as the ``scale`` entry of
            ``joint_attention_kwargs``.
        progress: Gradio progress object created with ``track_tqdm=True``;
            it is not referenced in the body.

    Returns:
        The first image of the pipeline's output.

    Raises:
        gr.Error: If ``selected_index`` is ``None``.
    """
    if selected_index is None:
        raise gr.Error("You must select a LoRA before proceeding.")

    selected_lora = loras[selected_index]
    lora_path = selected_lora["repo"]
    trigger_word = selected_lora["trigger_word"]

    # Load LoRA weights; a catalogue entry may pin a specific weight file.
    if "weights" in selected_lora:
        pipe.load_lora_weights(lora_path, weight_name=selected_lora["weights"])
    else:
        pipe.load_lora_weights(lora_path)

    try:
        # Set random seed for reproducibility
        generator = torch.Generator(device="cuda").manual_seed(seed)

        # Generate image. The FLUX pipeline reads the LoRA strength from
        # joint_attention_kwargs["scale"]; without it the UI slider had no
        # effect on the result.
        image = pipe(
            prompt=f"{prompt} {trigger_word}",
            num_inference_steps=steps,
            guidance_scale=cfg_scale,
            width=width,
            height=height,
            generator=generator,
            joint_attention_kwargs={"scale": lora_scale},
        ).images[0]
    finally:
        # `pipe` is shared by every request, so always remove the LoRA again,
        # even when generation fails; otherwise it would remain loaded for
        # the next call.
        pipe.unload_lora_weights()

    return image
# Build the Gradio interface. Statement order inside the context managers
# defines the on-screen layout, so it is kept as-is.
with gr.Blocks(theme=gr.themes.Soft()) as app:
    gr.Markdown("# FLUX.1 LoRA the Explorer")

    # Index of the LoRA picked in the gallery; None until one is selected.
    selected_index = gr.State(None)

    # Top row: prompt box beside the generate button.
    with gr.Row():
        with gr.Column(scale=3):
            prompt = gr.Textbox(
                label="Prompt",
                lines=1,
                placeholder="Type a prompt after selecting a LoRA",
            )
        with gr.Column(scale=1):
            generate_button = gr.Button("Generate", variant="primary")

    # Middle row: LoRA picker on the left, generated image on the right.
    with gr.Row():
        with gr.Column(scale=1):
            selected_info = gr.Markdown("")
            # One (image, caption) pair per catalogue entry.
            gallery_items = [(entry["image"], entry["title"]) for entry in loras]
            gallery = gr.Gallery(
                gallery_items,
                label="LoRA Gallery",
                allow_preview=False,
                columns=2,
            )
        with gr.Column(scale=2):
            result = gr.Image(label="Generated Image")

    # Bottom row: generation settings, two sliders per line.
    with gr.Row():
        with gr.Column():
            with gr.Row():
                cfg_scale = gr.Slider(
                    label="CFG Scale", minimum=1, maximum=20, step=0.5, value=3.5
                )
                steps = gr.Slider(
                    label="Steps", minimum=1, maximum=100, step=1, value=30
                )
            with gr.Row():
                width = gr.Slider(
                    label="Width", minimum=256, maximum=1536, step=64, value=1024
                )
                height = gr.Slider(
                    label="Height", minimum=256, maximum=1536, step=64, value=1024
                )
            with gr.Row():
                seed = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=2**32 - 1,
                    step=1,
                    value=0,
                    randomize=True,
                )
                lora_scale = gr.Slider(
                    label="LoRA Scale", minimum=0, maximum=1, step=0.01, value=1
                )

    # Clicking a gallery tile updates the prompt placeholder, the info text
    # and the stored selection index.
    gallery.select(
        fn=update_selection,
        outputs=[prompt, selected_info, selected_index],
    )

    # The generate button runs inference and shows the resulting image.
    generate_button.click(
        fn=run_lora,
        inputs=[prompt, cfg_scale, steps, selected_index, seed, width, height, lora_scale],
        outputs=[result],
    )

# Enable the request queue, then start the server.
app.queue().launch()