klein-9B / app.py
fdsgsfjsfg's picture
Add error handling to expose exceptions in API responses
d0e80d9 verified
"""
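FLUX.2 Klein 9B: Gradio app that runs the black-forest-labs/FLUX.2-klein-9B
diffusers pipeline and exposes it as the "generate" API endpoint.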
"""
import gradio as gr
import numpy as np
import random
import spaces
import torch
import traceback
from diffusers import Flux2KleinPipeline
from PIL import Image
# --- Generation constants -------------------------------------------------
dtype = torch.bfloat16  # precision the pipeline weights are loaded in
MAX_SEED = np.iinfo(np.int32).max  # upper bound for randomly drawn seeds
REPO_ID = "black-forest-labs/FLUX.2-klein-9B"

# --- Pipeline --------------------------------------------------------------
# Loaded once at import time and moved to the GPU; every request reuses it.
print("Loading...")
pipe = Flux2KleinPipeline.from_pretrained(
    REPO_ID,
    torch_dtype=dtype,
)
pipe = pipe.to("cuda")
print("Model loaded!")
@spaces.GPU(duration=85)
def infer(
    prompt: str,
    input_images=None,
    mask_image=None,
    seed: int = 42,
    randomize_seed: bool = True,
    width: int = 1024,
    height: int = 1024,
    num_inference_steps: int = 4,
    guidance_scale: float = 1.0,
):
    """Run one FLUX.2 Klein generation and return ``(image, seed_used)``.

    Args:
        prompt: Text prompt passed to the pipeline.
        input_images: Optional sequence of reference images. Items may be
            ``(image, caption)`` tuples (the first element is used) or
            ``PIL.Image.Image`` objects; any other item is logged and
            forwarded to the pipeline unchanged.
        mask_image: Accepted for UI/API compatibility; currently unused.
        seed: Seed used when ``randomize_seed`` is false. Accepts an int, a
            float (``gr.Number`` yields floats such as ``42.0``) or a numeric
            string such as ``"42"`` / ``"42.0"``.
        randomize_seed: When truthy, a fresh seed in ``[0, MAX_SEED]`` is
            drawn and ``seed`` is ignored. A string is treated as true only
            when it equals ``"true"`` (case-insensitive).
        width: Output width in pixels; coerced with ``int(float(x))``.
        height: Output height in pixels; coerced with ``int(float(x))``.
        num_inference_steps: Denoising steps; coerced with ``int(float(x))``.
        guidance_scale: Guidance scale; numeric strings are converted to float.

    Returns:
        Tuple of the first generated image and the integer seed actually used.

    Raises:
        gr.Error: Wraps any exception (type name + message) so the failure is
            visible to UI and API callers; the traceback is printed to the log.
    """
    try:
        # API clients may send booleans as strings.
        if isinstance(randomize_seed, str):
            randomize_seed = randomize_seed.strip().lower() == "true"
        width = int(float(width))
        height = int(float(height))
        num_inference_steps = int(float(num_inference_steps))
        if isinstance(guidance_scale, str):
            guidance_scale = float(guidance_scale)

        if randomize_seed:
            seed = random.randint(0, MAX_SEED)
        elif not isinstance(seed, int):
            # gr.Number returns floats (42.0) and API clients may send
            # strings such as "42.0"; torch.Generator.manual_seed only accepts
            # ints. Real ints are left as-is so very large seeds are not
            # rounded through float.
            seed = int(float(seed))
        generator = torch.Generator(device="cuda").manual_seed(seed)

        pipe_kwargs = {
            "prompt": prompt,
            "height": height,
            "width": width,
            "num_inference_steps": num_inference_steps,
            "guidance_scale": guidance_scale,
            "generator": generator,
        }

        print(f"input_images type: {type(input_images)}, value: {input_images}")
        if input_images is not None and len(input_images) > 0:
            imgs = []
            for item in input_images:
                if isinstance(item, tuple):
                    # gr.Gallery delivers (image, caption) pairs.
                    imgs.append(item[0])
                elif isinstance(item, Image.Image):
                    imgs.append(item)
                else:
                    print(f"Unknown item type: {type(item)}, value: {item}")
                    imgs.append(item)
            pipe_kwargs["image"] = imgs

        print(f"pipe_kwargs keys: {list(pipe_kwargs.keys())}")
        print(f"image count: {len(pipe_kwargs.get('image', []))}")
        result_image = pipe(**pipe_kwargs).images[0]
        return result_image, seed
    except Exception as e:
        # Top-level boundary: log the full traceback, then surface a concise
        # message to the caller instead of Gradio's generic error.
        tb = traceback.format_exc()
        print(f"ERROR: {e}")
        print(tb)
        raise gr.Error(f"{type(e).__name__}: {e}") from e
# Gradio UI; the Run button's handler is also published as the "generate" API.
with gr.Blocks() as demo:
    gr.Markdown("# FLUX.2 Klein 9B")
    with gr.Row():
        prompt = gr.Text(label="Prompt", value="clean background, no watermark")
        run_btn = gr.Button("Run")
    input_images = gr.Gallery(label="Input Image(s)", type="pil")
    # NOTE: passed to infer(), which does not currently use it.
    mask_image = gr.Image(type="pil", label="Mask")
    # precision=0 makes gr.Number round and return ints. Without it these
    # values arrive as floats (e.g. 42.0), which torch's manual_seed rejects.
    seed = gr.Number(label="Seed", value=42, precision=0)
    randomize_seed = gr.Checkbox(label="Random seed", value=True)
    width = gr.Number(label="Width", value=1024, precision=0)
    height = gr.Number(label="Height", value=1024, precision=0)
    steps = gr.Number(label="Steps", value=4, precision=0)
    guidance = gr.Number(label="Guidance", value=1.0)
    result = gr.Image(label="Result")
    run_btn.click(
        infer,
        inputs=[prompt, input_images, mask_image, seed, randomize_seed, width, height, steps, guidance],
        # The seed actually used (possibly randomized) is written back to the Seed box.
        outputs=[result, seed],
        api_name="generate",
    )
demo.launch()