Spaces:
Running on Zero
"""
FLUX.2 Klein 9B

Gradio Space demo: loads the FLUX.2 Klein 9B pipeline once at import time
and serves a text/image-to-image generation UI.
"""
import gradio as gr
import numpy as np
import random
import spaces
import torch
import traceback
from diffusers import Flux2KleinPipeline
from PIL import Image

# bfloat16 halves memory vs float32 while keeping the float32 exponent range.
dtype = torch.bfloat16
# Upper bound for randomized seeds (max signed 32-bit int).
MAX_SEED = np.iinfo(np.int32).max
REPO_ID = "black-forest-labs/FLUX.2-klein-9B"

print("Loading...")
# Module-level side effect: downloads/loads the model and moves it to the GPU.
# NOTE(review): on a ZeroGPU Space this runs at startup — confirm the runtime
# allows CUDA initialization outside a @spaces.GPU-decorated call.
pipe = Flux2KleinPipeline.from_pretrained(REPO_ID, torch_dtype=dtype).to("cuda")
print("Model loaded!")
@spaces.GPU  # Required on ZeroGPU ("Running on Zero") Spaces: allocates a GPU for the duration of the call. `spaces` was imported but never used — without this decorator CUDA calls fail on a ZeroGPU worker.
def infer(
    prompt: str,
    input_images=None,
    mask_image=None,  # accepted for API compatibility; not forwarded to the pipeline
    seed: int = 42,
    randomize_seed: bool = True,
    width: int = 1024,
    height: int = 1024,
    num_inference_steps: int = 4,
    guidance_scale: float = 1.0,
):
    """Generate an image with FLUX.2 Klein 9B.

    Args:
        prompt: Text prompt for generation.
        input_images: Optional list of reference images; gr.Gallery yields
            either PIL images or (image, caption) tuples.
        mask_image: Unused placeholder kept so the UI/API signature is stable.
        seed: RNG seed; ignored when randomize_seed is True.
        randomize_seed: Replace `seed` with a fresh random value.
        width, height: Output resolution in pixels.
        num_inference_steps: Denoising steps (Klein is tuned for few steps).
        guidance_scale: Classifier-free guidance strength.

    Returns:
        (PIL.Image.Image, int): the generated image and the seed actually used.

    Raises:
        gr.Error: wraps any underlying exception so it surfaces in the UI.
    """
    try:
        # Values may arrive as strings when called through the HTTP API;
        # coerce defensively before use.
        if isinstance(seed, str):
            seed = int(seed)
        if isinstance(randomize_seed, str):
            randomize_seed = randomize_seed.lower() == "true"
        width = int(float(width))
        height = int(float(height))
        num_inference_steps = int(float(num_inference_steps))
        if isinstance(guidance_scale, str):
            guidance_scale = float(guidance_scale)

        if randomize_seed:
            seed = random.randint(0, MAX_SEED)
        # Seed on the CUDA device so sampling is reproducible for a given seed.
        generator = torch.Generator(device="cuda").manual_seed(seed)

        pipe_kwargs = {
            "prompt": prompt,
            "height": height,
            "width": width,
            "num_inference_steps": num_inference_steps,
            "guidance_scale": guidance_scale,
            "generator": generator,
        }

        print(f"input_images type: {type(input_images)}, value: {input_images}")
        if input_images is not None and len(input_images) > 0:
            # gr.Gallery items can be (image, caption) tuples; unwrap to images.
            imgs = []
            for item in input_images:
                if isinstance(item, tuple):
                    imgs.append(item[0])
                elif isinstance(item, Image.Image):
                    imgs.append(item)
                else:
                    # Unknown shape — pass through and let the pipeline decide.
                    print(f"Unknown item type: {type(item)}, value: {item}")
                    imgs.append(item)
            pipe_kwargs["image"] = imgs

        print(f"pipe_kwargs keys: {list(pipe_kwargs.keys())}")
        print(f"image count: {len(pipe_kwargs.get('image', []))}")
        result_image = pipe(**pipe_kwargs).images[0]
        return result_image, seed
    except Exception as e:
        # Boundary handler: log the full traceback server-side, then re-raise
        # as gr.Error so the message is shown in the Gradio UI.
        tb = traceback.format_exc()
        print(f"ERROR: {e}")
        print(tb)
        raise gr.Error(f"{type(e).__name__}: {e}")
# UI definition. Component creation order below must match the `inputs=` list
# of the click handler — keep them in sync when editing.
with gr.Blocks() as demo:
    gr.Markdown("# FLUX.2 Klein 9B")
    with gr.Row():
        prompt = gr.Text(label="Prompt", value="clean background, no watermark")
        run_btn = gr.Button("Run")
    # Gallery delivers PIL images (or (image, caption) tuples) to `infer`.
    input_images = gr.Gallery(label="Input Image(s)", type="pil")
    # NOTE(review): the mask is collected but `infer` never forwards it to the
    # pipeline — confirm whether inpainting support is intended.
    mask_image = gr.Image(type="pil", label="Mask")
    seed = gr.Number(label="Seed", value=42)
    randomize_seed = gr.Checkbox(label="Random seed", value=True)
    width = gr.Number(label="Width", value=1024)
    height = gr.Number(label="Height", value=1024)
    steps = gr.Number(label="Steps", value=4)
    guidance = gr.Number(label="Guidance", value=1.0)
    result = gr.Image(label="Result")
    # `seed` appears in both inputs and outputs so the UI reflects the seed
    # actually used when randomization is on.
    run_btn.click(
        infer,
        inputs=[prompt, input_images, mask_image, seed, randomize_seed, width, height, steps, guidance],
        outputs=[result, seed],
        api_name="generate",  # exposes POST /api/generate for programmatic use
    )
demo.launch()