# shemayons's picture
# Update app.py
# 4f41a04 verified
import random
import torch
import gradio as gr
from diffusers import (
FluxPipeline,
DPMSolverMultistepScheduler,
DPMSolverSDEScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
HeunDiscreteScheduler,
DDIMScheduler,
)
# -----------------------------------------------------------------------------
# Pipeline loading helpers
# -----------------------------------------------------------------------------
def _load_pipe(hf_token: str | None = None) -> FluxPipeline:
    """Build the FLUX.1-dev pipeline with the Boreal LoRA fused in.

    Every call performs a fresh load; caching the result is the caller's
    responsibility.

    Args:
        hf_token: Optional Hugging Face token if the model is gated/private.
            An empty string is treated the same as ``None``.

    Returns:
        A FluxPipeline with the LoRA fused at scale 0.8 and the memory-saving
        features enabled.
    """
    pipe = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        torch_dtype=torch.float16,
        # `token` is the current keyword for credentials; the older
        # `use_auth_token` spelling is deprecated in diffusers.
        token=hf_token or None,
    )

    # LoRA --------------------------------------------------------------------
    # Load and fuse the adapter first, while the model weights are still
    # ordinary resident tensors, so the fuse operates on the real weights
    # rather than on modules already wrapped by offload hooks.
    pipe.load_lora_weights(
        "kudzueye/boreal-flux-dev-v2", weight_name="boreal-v2.safetensors"
    )
    pipe.fuse_lora(lora_scale=0.8)

    # Memory optimisations ----------------------------------------------------
    pipe.enable_sequential_cpu_offload()
    pipe.enable_attention_slicing()
    return pipe
# Process-wide pipeline cache: populated on first use so that later requests
# reuse the already-loaded model instead of loading it again.
_pipe: FluxPipeline | None = None


def _get_pipe(hf_token: str | None = None) -> FluxPipeline:
    """Return the shared pipeline, loading it on the first call.

    Args:
        hf_token: Token forwarded to ``_load_pipe``. It is only consulted when
            the cache is still empty; later calls return the cached pipeline
            regardless of the token passed.

    Returns:
        The cached FluxPipeline instance.
    """
    global _pipe
    if _pipe is not None:
        return _pipe
    _pipe = _load_pipe(hf_token)
    return _pipe
# -----------------------------------------------------------------------------
# Scheduler mapping
# -----------------------------------------------------------------------------
# Maps the sampler labels shown in the UI to diffusers scheduler classes.
# The keys double as the choices of the "Sampling method" radio control.
# NOTE(review): FLUX pipelines normally run with a flow-matching scheduler;
# confirm that each class below can be built from that scheduler's config and
# is actually accepted by FluxPipeline before relying on the swap.
SCHED_MAP: dict[str, type] = {
    "DPM++ 2M Karras": DPMSolverMultistepScheduler,
    "DPM++ SDE Karras": DPMSolverSDEScheduler,
    "Euler": EulerDiscreteScheduler,
    "Euler a": EulerAncestralDiscreteScheduler,
    "Heun": HeunDiscreteScheduler,
    "DDIM": DDIMScheduler,
}
# -----------------------------------------------------------------------------
# Inference function
# -----------------------------------------------------------------------------
def query(
    prompt: str,
    negative_prompt: str,
    steps: int,
    cfg_scale: float,
    sampler: str,
    seed: int,
    strength: float,  # kept for future img2img support
    hf_token: str,
):
    """Run the generation and return a PIL image + the seed actually used.

    Args:
        prompt: Text describing the desired image.
        negative_prompt: Text describing what should be avoided.
        steps: Number of inference steps; coerced to ``int``.
        cfg_scale: Guidance scale forwarded to the pipeline.
        sampler: Key into ``SCHED_MAP``; unknown keys fall back to
            ``DPMSolverMultistepScheduler``.
        seed: Seed for the random generator; ``-1`` requests a random seed.
            Coerced to ``int``.
        strength: Currently unused.
        hf_token: Hugging Face token; an empty string is treated as absent.

    Returns:
        A ``(image, seed)`` tuple where ``image`` is the first generated image
        and ``seed`` is the seed that was used, as a string.
    """
    pipe = _get_pipe(hf_token or None)

    # Replace scheduler if the user selected a different sampler.
    # NOTE(review): the swap mutates the shared pipeline, so it persists
    # across requests. Also confirm these scheduler classes are compatible
    # with FluxPipeline; FLUX normally uses a flow-matching scheduler.
    scheduler_cls = SCHED_MAP.get(sampler, DPMSolverMultistepScheduler)
    if not isinstance(pipe.scheduler, scheduler_cls):
        pipe.scheduler = scheduler_cls.from_config(pipe.scheduler.config)

    # UI sliders can deliver whole numbers as floats; the step count and the
    # generator seed must be real integers.
    steps = int(steps)
    seed = int(seed)

    # Handle seed
    if seed == -1:
        seed = random.randint(0, 1_000_000_000)

    # Seed a CPU generator: with CPU offload enabled the pipeline's reported
    # device is not a dependable place to create a generator, while a CPU
    # generator is accepted wherever the weights currently live.
    generator = torch.Generator(device="cpu").manual_seed(seed)

    # Run inference
    # NOTE(review): whether FluxPipeline accepts/uses `negative_prompt`
    # depends on the installed diffusers version — confirm.
    with torch.no_grad():
        result = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=steps,
            guidance_scale=cfg_scale,
            generator=generator,
            height=512,
            width=512,
        )
    return result.images[0], str(seed)
# -----------------------------------------------------------------------------
# Gradio UI
# -----------------------------------------------------------------------------
# Stylesheet handed to gr.Blocks. It centres the main column
# (#app-container, max 600px wide) and styles the title row
# (#title-container, #title-text). #title-icon is styled here as well, but no
# element with that id is created in the layout visible in this file.
CSS = """
#app-container {
max-width: 600px;
margin-left: auto;
margin-right: auto;
}
#title-container {
display: flex;
align-items: center;
justify-content: center;
}
#title-icon {
width: 32px;
height: auto;
margin-right: 10px;
}
#title-text {
font-size: 24px;
font-weight: bold;
}
"""
# Build the single-page UI: a title, the prompt box, a collapsed "Advanced
# Settings" accordion, the Run button and the outputs.
# NOTE(review): the nesting below is reconstructed; the source this was taken
# from had lost its indentation — confirm against the deployed layout.
with gr.Blocks(theme="Nymbo/Nymbo_Theme", css=CSS) as app:
    # Page title
    gr.HTML(
        """
<center>
<div id="title-container">
<h1 id="title-text">Text-to-Image Generator App</h1>
</div>
</center>
"""
    )
    with gr.Column(elem_id="app-container"):
        with gr.Row():
            with gr.Column(elem_id="prompt-container"):
                # Main prompt
                with gr.Row():
                    txt_prompt = gr.Textbox(
                        label="Prompt",
                        placeholder="Enter a prompt here",
                        lines=2,
                        elem_id="prompt-text-input",
                    )
                # Optional tuning controls, collapsed by default
                with gr.Row():
                    with gr.Accordion("Advanced Settings", open=False):
                        neg_prompt = gr.Textbox(
                            label="Negative Prompt",
                            placeholder="What should not be in the image",
                            value="(deformed, distorted, disfigured), poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, (mutated hands and fingers), disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation, misspellings, typos",
                            lines=3,
                            elem_id="negative-prompt-text-input",
                        )
                        steps_in = gr.Slider(
                            label="Sampling steps", value=35, minimum=1, maximum=100, step=1
                        )
                        cfg_in = gr.Slider(
                            label="CFG Scale", value=7, minimum=1, maximum=20, step=1
                        )
                        # Choices come straight from the scheduler mapping
                        sampler_in = gr.Radio(
                            label="Sampling method",
                            value="DPM++ 2M Karras",
                            choices=list(SCHED_MAP.keys()),
                        )
                        # Forwarded to query(), which currently ignores it
                        strength_in = gr.Slider(
                            label="Strength", value=0.7, minimum=0, maximum=1, step=0.001
                        )
                        # -1 asks query() to pick a random seed
                        seed_in = gr.Slider(
                            label="Seed", value=-1, minimum=-1, maximum=1_000_000_000, step=1
                        )
                        # Passed to query() as hf_token
                        api_key_in = gr.Textbox(
                            label="Hugging Face API Key (required for private models)",
                            placeholder="Enter your Hugging Face API Key here",
                            type="password",
                            elem_id="api-key",
                        )
        with gr.Row():
            run_button = gr.Button("Run", variant="primary", elem_id="gen-button")
        with gr.Row():
            img_out = gr.Image(type="pil", label="Image Output", elem_id="gallery")
            seed_out = gr.Textbox(label="Seed Used", elem_id="seed-output")
        # The inputs list is ordered to line up with query()'s parameters:
        # prompt, negative_prompt, steps, cfg_scale, sampler, seed, strength,
        # hf_token. The two outputs receive query()'s (image, seed) result.
        run_button.click(
            fn=query,
            inputs=[
                txt_prompt,
                neg_prompt,
                steps_in,
                cfg_in,
                sampler_in,
                seed_in,
                strength_in,
                api_key_in,
            ],
            outputs=[img_out, seed_out],
        )

# Start the Gradio server when the module is executed.
app.launch(show_api=True, share=False)