"""FLUX.2 Klein 9B Gradio demo.

Loads the FLUX.2 Klein 9B pipeline once at import time and serves a small
Gradio UI (also exposed as the ``generate`` API endpoint) that generates an
image from a prompt, optionally conditioned on one or more input images.
"""
import random
import traceback

import gradio as gr
import numpy as np
import spaces
import torch
from diffusers import Flux2KleinPipeline
from PIL import Image

dtype = torch.bfloat16
MAX_SEED = np.iinfo(np.int32).max
REPO_ID = "black-forest-labs/FLUX.2-klein-9B"

print("Loading...")
pipe = Flux2KleinPipeline.from_pretrained(REPO_ID, torch_dtype=dtype).to("cuda")
print("Model loaded!")


def _to_bool(value) -> bool:
    """Interpret a checkbox value that may arrive as a bool or as a string (API calls)."""
    if isinstance(value, str):
        return value.strip().lower() == "true"
    return bool(value)


def _to_int(value) -> int:
    """Convert an int, a float such as ``42.0`` (what ``gr.Number`` yields), or a numeric string to int."""
    if isinstance(value, int):
        return value
    if isinstance(value, str):
        value = value.strip()
        try:
            # Parse integer strings exactly, without a lossy trip through float.
            return int(value)
        except ValueError:
            pass
    return int(float(value))


def _collect_images(input_images) -> list:
    """Extract the images from a Gallery value.

    Items may be ``(image, caption)`` tuples or bare PIL images. Anything else
    is logged and passed through unchanged, leaving it to the pipeline to
    accept or reject.
    """
    imgs = []
    for item in input_images:
        if isinstance(item, tuple):
            imgs.append(item[0])
        elif isinstance(item, Image.Image):
            imgs.append(item)
        else:
            print(f"Unknown item type: {type(item)}, value: {item}")
            imgs.append(item)
    return imgs


@spaces.GPU(duration=85)
def infer(
    prompt: str,
    input_images=None,
    mask_image=None,
    seed: int = 42,
    randomize_seed: bool = True,
    width: int = 1024,
    height: int = 1024,
    num_inference_steps: int = 4,
    guidance_scale: float = 1.0,
):
    """Run the FLUX.2 Klein pipeline once and return ``(image, seed_used)``.

    Args:
        prompt: Text prompt.
        input_images: Optional Gallery value; its images are passed to the
            pipeline as ``image``.
        mask_image: Accepted for UI/API compatibility but NOT used. It is
            never passed to the pipeline.
        seed: Seed used when ``randomize_seed`` is false. Numbers, floats such
            as ``42.0``, and numeric strings are accepted. ``None`` falls back
            to a random seed.
        randomize_seed: If true (or the string ``"true"``), draw a fresh seed
            in ``[0, MAX_SEED]``.
        width, height, num_inference_steps: Coerced to int.
        guidance_scale: Coerced to float.

    Returns:
        The first generated image and the seed actually used, so the UI can
        display it.

    Raises:
        gr.Error: Wraps any failure, chained to the original exception.
    """
    try:
        randomize_seed = _to_bool(randomize_seed)
        width = _to_int(width)
        height = _to_int(height)
        num_inference_steps = _to_int(num_inference_steps)
        guidance_scale = float(guidance_scale)

        if randomize_seed or seed is None:
            seed = random.randint(0, MAX_SEED)
        else:
            # torch.Generator.manual_seed requires an int. gr.Number returns
            # floats, so this coercion must not be limited to str inputs.
            seed = _to_int(seed)
        generator = torch.Generator(device="cuda").manual_seed(seed)

        pipe_kwargs = {
            "prompt": prompt,
            "height": height,
            "width": width,
            "num_inference_steps": num_inference_steps,
            "guidance_scale": guidance_scale,
            "generator": generator,
        }

        if mask_image is not None:
            # The mask is collected by the UI but never forwarded to the pipeline.
            print("mask_image was provided but is ignored: it is not passed to the pipeline.")

        print(f"input_images type: {type(input_images)}, value: {input_images}")
        if input_images is not None and len(input_images) > 0:
            pipe_kwargs["image"] = _collect_images(input_images)

        print(f"pipe_kwargs keys: {list(pipe_kwargs.keys())}")
        print(f"image count: {len(pipe_kwargs.get('image', []))}")

        result_image = pipe(**pipe_kwargs).images[0]
        return result_image, seed
    except Exception as e:
        print(f"ERROR: {e}")
        print(traceback.format_exc())
        # Surface a readable message in the UI while keeping the original cause.
        raise gr.Error(f"{type(e).__name__}: {e}") from e


with gr.Blocks() as demo:
    gr.Markdown("# FLUX.2 Klein 9B")
    # NOTE(review): the original indentation was lost; the Row is assumed to
    # hold the prompt and Run button. Confirm the intended layout.
    with gr.Row():
        prompt = gr.Text(label="Prompt", value="clean background, no watermark")
        run_btn = gr.Button("Run")
    input_images = gr.Gallery(label="Input Image(s)", type="pil")
    # Currently unused by infer(); kept so the API signature stays stable.
    mask_image = gr.Image(type="pil", label="Mask")
    # precision=0 makes Gradio deliver ints for these integer-valued fields.
    seed = gr.Number(label="Seed", value=42, precision=0)
    randomize_seed = gr.Checkbox(label="Random seed", value=True)
    width = gr.Number(label="Width", value=1024, precision=0)
    height = gr.Number(label="Height", value=1024, precision=0)
    steps = gr.Number(label="Steps", value=4, precision=0)
    guidance = gr.Number(label="Guidance", value=1.0)
    result = gr.Image(label="Result")

    run_btn.click(
        infer,
        inputs=[prompt, input_images, mask_image, seed, randomize_seed, width, height, steps, guidance],
        outputs=[result, seed],
        api_name="generate",
    )

demo.launch()