import os import io import gc import uuid import json import base64 import random import zipfile import threading from pathlib import Path from typing import List, Optional import spaces import numpy as np import torch from PIL import Image from gradio import Server from fastapi import Request, UploadFile, File, Form from fastapi.responses import HTMLResponse, JSONResponse, FileResponse, StreamingResponse from diffusers import Flux2KleinPipeline, AutoencoderKLFlux2 # --- App Configuration & Directories --- app = Server() BASE_DIR = Path(__file__).resolve().parent STATIC_DIR = BASE_DIR / "static" OUTPUT_DIR = BASE_DIR / "outputs" EXAMPLES_DIR = BASE_DIR / "examples" STATIC_DIR.mkdir(exist_ok=True) OUTPUT_DIR.mkdir(exist_ok=True) MAX_SEED = np.iinfo(np.int32).max MAX_IMAGE_SIZE = 1024 dtype = torch.bfloat16 device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if torch.cuda.is_available(): print("current device:", torch.cuda.current_device()) print("device name:", torch.cuda.get_device_name(torch.cuda.current_device())) DEVICE_LABEL = torch.cuda.get_device_name(torch.cuda.current_device()).lower() else: DEVICE_LABEL = str(device).lower() # --- Model Loading --- print("Loading Small Decoder VAE...") vae_small = AutoencoderKLFlux2.from_pretrained( "black-forest-labs/FLUX.2-small-decoder", torch_dtype=dtype, ).to(device) print("Loading 4B Distilled model (Small Decoder VAE)...") pipe = Flux2KleinPipeline.from_pretrained( "black-forest-labs/FLUX.2-klein-4B", vae=vae_small, torch_dtype=dtype, ).to(device) pipe.enable_model_cpu_offload() pipe_lock = threading.Lock() # --- Utility Functions --- def calc_dimensions(pil_img: Image.Image): iw, ih = pil_img.size aspect = iw / ih if aspect >= 1: new_width = 1024 new_height = int(round(1024 / aspect)) else: new_height = 1024 new_width = int(round(1024 * aspect)) new_width = max(256, min(1024, round(new_width / 8) * 8)) new_height = max(256, min(1024, round(new_height / 8) * 8)) return new_width, new_height def parse_and_resize_images(image_paths: List[str], width: int, height: int): if not image_paths: return None resized = [] for path in image_paths: try: img = Image.open(path).convert("RGB") resized.append(img.resize((width, height), Image.LANCZOS)) except Exception as e: print(f"Skipping invalid image: {e}") return resized if resized else None def run_pipeline(kwargs, seed): with pipe_lock: gen = torch.Generator(device="cpu").manual_seed(seed) result = pipe(**kwargs, generator=gen).images[0] return result def save_image(img: Image.Image, prefix: str = "output") -> str: filename = f"{prefix}_{uuid.uuid4().hex}.png" path = OUTPUT_DIR / filename img.save(path, format="PNG") return filename # --- Inference Function --- @spaces.GPU(duration=120) def infer( prompt: str, image_paths: List[str] = None, seed: int = 42, randomize_seed: bool = False, width: int = 1024, height: int = 1024, num_inference_steps: int = 4, guidance_scale: float = 1.0, ): gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() if not prompt or not prompt.strip(): raise ValueError("Please enter a prompt.") if randomize_seed: seed = random.randint(0, MAX_SEED) image_list = None if image_paths and len(image_paths) > 0: try: first_pil = Image.open(image_paths[0]).convert("RGB") width, height = calc_dimensions(first_pil) image_list = parse_and_resize_images(image_paths, width, height) except Exception as e: print(f"Error processing upload: {e}") width = max(256, min(MAX_IMAGE_SIZE, round(int(width) / 8) * 8)) height = max(256, min(MAX_IMAGE_SIZE, round(int(height) / 8) * 8)) kwargs = dict( prompt=prompt, height=height, width=width, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, ) if image_list is not None: kwargs["image"] = image_list result = run_pipeline(kwargs, seed) gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() return result, seed # --- FastAPI Endpoints --- def get_example_items(): return [ { "urls": ["/example-file/I1.jpg", "/example-file/I2.jpg"], "prompt": "Make her wear these glasses in Image 2." }, { "urls": ["/example-file/1.jpg"], "prompt": "Change the weather to stormy." }, { "urls": ["/example-file/2.jpg"], "prompt": "Transform the scene into a snowy winter day while preserving the original subject identity, framing, and composition." }, { "urls": ["/example-file/3.jpg"], "prompt": "Relight the image with soft golden sunset lighting while keeping all structures and subject details consistent." }, { "urls": ["/example-file/4.jpg"], "prompt": "Make the texture high-resolution." } ] @app.get("/example-file/{filename}") async def example_file(filename: str): path = EXAMPLES_DIR / filename if not path.exists(): return JSONResponse({"error": "Example not found"}, status_code=404) return FileResponse(path) @app.get("/download/{filename}") async def download_file(filename: str): path = OUTPUT_DIR / filename if not path.exists(): return JSONResponse({"error": "File not found"}, status_code=404) return FileResponse(path, filename=filename, media_type="image/png") @app.post("/api/generate") async def generate_image( prompt: str = Form(...), seed: str = Form("0"), randomize_seed: str = Form("true"), width: str = Form("1024"), height: str = Form("1024"), steps: str = Form("4"), guidance: str = Form("1.0"), images: Optional[List[UploadFile]] = File(None), ): temp_paths = [] try: image_paths = [] if images: for upload in images: if not upload.filename: continue suffix = Path(upload.filename).suffix or ".png" temp_path = OUTPUT_DIR / f"upload_{uuid.uuid4().hex}{suffix}" content = await upload.read() with open(temp_path, "wb") as f: f.write(content) temp_paths.append(str(temp_path)) image_paths.append(str(temp_path)) result, used_seed = infer( prompt=prompt, image_paths=image_paths, seed=int(seed), randomize_seed=(randomize_seed.lower() == "true"), width=int(width), height=int(height), num_inference_steps=int(steps), guidance_scale=float(guidance), ) filename = save_image(result, prefix="output") return JSONResponse({ "success": True, "seed": used_seed, "url": f"/download/{filename}", "filename": filename, "device": DEVICE_LABEL, }) except Exception as e: return JSONResponse({"success": False, "error": str(e)}, status_code=500) finally: for p in temp_paths: if os.path.exists(p): os.remove(p) # --- Frontend --- @app.get("/", response_class=HTMLResponse) async def homepage(request: Request): examples = get_example_items() examples_json = json.dumps(examples) return f"""
Upload an image (optional) and enter a prompt to generate with the 4B distilled model.