# FLUX.2_klein_4B / app.py
# Origin: yingzhac-research, commit 655e5a5 ("Use diffusers git build for Flux2 pipeline")
import math
import random
from threading import Lock
import gradio as gr
import spaces
import torch
try:
from diffusers import Flux2KleinPipeline
except ImportError:
from diffusers.pipelines.flux2 import Flux2KleinPipeline
from PIL import Image
# Hugging Face repo id of the distilled FLUX.2 klein 4B checkpoint.
MODEL_ID = "black-forest-labs/FLUX.2-klein-4B"
# Seeds are drawn from the unsigned 32-bit range.
MAX_SEED = 2**32 - 1
# Upper bound on output area (megapixels) so generation stays within memory.
MAX_MEGAPIXELS = 4.0
# Output sides are snapped down to a multiple of this
# (presumably the pipeline's latent patch granularity — TODO confirm).
SIZE_MULTIPLE = 16
# Smallest allowed output side, in pixels.
MIN_SIDE = 256
# Lazily-initialized global pipeline; guarded by pipe_lock in _load_pipe.
pipe = None
pipe_lock = Lock()
def _round_to_multiple(value, multiple):
return int(value // multiple) * multiple
def _clamp_size(width, height):
    """Snap (width, height) to the model's size grid and pixel budget.

    Each side is floored to a multiple of SIZE_MULTIPLE and kept at least
    MIN_SIDE. If the resulting area still exceeds MAX_MEGAPIXELS, both
    sides are scaled down proportionally (preserving aspect ratio) and
    snapped to the grid again. Returns the final (width, height) as ints.
    """
    snapped = [
        max(MIN_SIDE, _round_to_multiple(side, SIZE_MULTIPLE))
        for side in (width, height)
    ]
    budget = int(MAX_MEGAPIXELS * 1_000_000)
    area = snapped[0] * snapped[1]
    if area > budget:
        # Uniform shrink factor that brings the area under the budget.
        factor = math.sqrt(budget / float(area))
        snapped = [
            max(MIN_SIDE, _round_to_multiple(side * factor, SIZE_MULTIPLE))
            for side in snapped
        ]
    return int(snapped[0]), int(snapped[1])
def _load_pipe():
    """Return the global Flux2KleinPipeline, loading it on first use.

    Uses double-checked locking: the fast path reads the global without
    the lock; the lock is only taken when the pipeline has not been
    created yet, so concurrent first calls load the model exactly once.
    """
    global pipe
    # Fast path: pipeline already initialized, no locking needed.
    if pipe is not None:
        return pipe
    with pipe_lock:
        # Re-check under the lock — another thread may have won the race.
        if pipe is None:
            print("Loading FLUX.2 klein 4B pipeline...")
            pipe = Flux2KleinPipeline.from_pretrained(
                MODEL_ID,
                torch_dtype=torch.bfloat16,
            ).to("cuda")
            print("Pipeline loaded.")
    return pipe
@spaces.GPU(duration=120)
def generate_image(
    prompt,
    input_image,
    width,
    height,
    num_inference_steps,
    guidance_scale,
    seed,
    randomize_seed,
    progress=gr.Progress(track_tqdm=True),
):
    """Run the FLUX.2 klein pipeline for text-to-image or image editing.

    Args:
        prompt: Text description of the desired image.
        input_image: Optional PIL reference image. When given, the output
            size is derived from it and the width/height sliders are ignored.
        width, height: Requested output size (text-to-image mode only).
        num_inference_steps: Denoising steps (distilled model; ~4 typical).
        guidance_scale: Classifier-free guidance strength.
        seed: RNG seed; ignored when randomize_seed is True.
        randomize_seed: Draw a fresh random seed instead of using `seed`.
        progress: Gradio progress tracker mirroring the pipeline's tqdm bars.

    Returns:
        Tuple of (generated PIL image, seed actually used, "WxH" size string).
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    width = int(width)
    height = int(height)
    if input_image is not None:
        # Editing mode: follow the reference image's dimensions (clamped to
        # the pixel budget) and resize the reference to match exactly.
        width, height = _clamp_size(*input_image.size)
        # Normalize mode first so RGBA/palette/grayscale uploads don't
        # break the pipeline, which expects 3-channel RGB input.
        input_image = input_image.convert("RGB").resize(
            (width, height), resample=Image.LANCZOS
        )
    else:
        width, height = _clamp_size(width, height)
    generator = torch.Generator(device="cuda").manual_seed(int(seed))
    pipe = _load_pipe()
    pipe_kwargs = {
        "prompt": prompt,
        "height": height,
        "width": width,
        "guidance_scale": float(guidance_scale),
        "num_inference_steps": int(num_inference_steps),
        "generator": generator,
    }
    if input_image is not None:
        pipe_kwargs["image"] = input_image
    # inference_mode skips autograd bookkeeping during generation.
    with torch.inference_mode():
        image = pipe(**pipe_kwargs).images[0]
    return image, int(seed), f"{width}x{height}"
def suggest_dimensions_from_image(input_image):
    """Return slider defaults: the clamped size of the upload, else 1024x1024."""
    if input_image is None:
        return 1024, 1024
    return _clamp_size(*input_image.size)
# Gradio UI: prompt + optional reference image on the left, output on the right.
with gr.Blocks(title="FLUX.2 [klein] 4B") as demo:
    gr.Markdown(
        """
        # FLUX.2 [klein] 4B
        Generate images with FLUX.2 klein 4B. Optionally upload a reference image for editing.
        """
    )
    with gr.Row():
        with gr.Column(scale=1):
            prompt = gr.Textbox(
                label="Prompt",
                placeholder="Describe the image you want...",
                lines=3,
            )
            input_image = gr.Image(
                label="Input image (optional)",
                type="pil",
            )
            with gr.Row():
                width = gr.Slider(
                    minimum=MIN_SIDE,
                    maximum=2048,
                    value=1024,
                    step=SIZE_MULTIPLE,
                    label="Width",
                )
                height = gr.Slider(
                    minimum=MIN_SIDE,
                    maximum=2048,
                    value=1024,
                    step=SIZE_MULTIPLE,
                    label="Height",
                )
            num_inference_steps = gr.Slider(
                minimum=1,
                maximum=12,
                value=4,
                step=1,
                label="Inference steps",
                info="Distilled model; 4 steps is a good default",
            )
            guidance_scale = gr.Slider(
                minimum=1.0,
                maximum=4.0,
                value=1.0,
                step=0.1,
                label="Guidance scale",
            )
            with gr.Row():
                seed = gr.Number(
                    label="Seed",
                    value=42,
                    precision=0,
                )
                randomize_seed = gr.Checkbox(
                    label="Randomize seed",
                    value=False,
                )
            generate_btn = gr.Button("Generate", variant="primary")
        with gr.Column(scale=1):
            output_image = gr.Image(label="Output")
            used_seed = gr.Number(label="Seed used", precision=0, interactive=False)
            final_size = gr.Textbox(label="Final size", interactive=False)
    gr.Examples(
        examples=[
            ["A cinematic portrait of a snow leopard, rim lighting, ultra detailed"],
            ["A cozy cabin by a lake at sunrise, mist, soft light"],
            ["Futuristic city skyline with flying cars, dusk, neon glow"],
        ],
        inputs=[prompt],
        cache_examples=False,
    )
    # When a reference image is uploaded, sync the sliders to its clamped size.
    input_image.change(
        fn=suggest_dimensions_from_image,
        inputs=[input_image],
        outputs=[width, height],
    )
    # Button click and Enter-in-prompt trigger the same generation; share
    # one wiring definition instead of duplicating the component lists.
    generation_inputs = [
        prompt,
        input_image,
        width,
        height,
        num_inference_steps,
        guidance_scale,
        seed,
        randomize_seed,
    ]
    generation_outputs = [output_image, used_seed, final_size]
    generate_btn.click(
        fn=generate_image,
        inputs=generation_inputs,
        outputs=generation_outputs,
    )
    prompt.submit(
        fn=generate_image,
        inputs=generation_inputs,
        outputs=generation_outputs,
    )
if __name__ == "__main__":
    demo.launch()