"""Web app comparing two VAE decoders for the FLUX.2 klein 4B pipeline.

Two ``Flux2KleinPipeline`` instances are loaded at import time, one with the
pipeline's default VAE and one with the "small decoder" VAE. Each request runs
the same prompt, seed and settings through both pipelines and returns both
images so they can be compared.
"""

import os
import io
import gc
import uuid
import json
import base64
import random
import zipfile
import threading
import concurrent.futures
from pathlib import Path
from typing import List, Optional

import spaces
import numpy as np
import torch
from PIL import Image
from gradio import Server
from fastapi import Request, UploadFile, File, Form
from fastapi.responses import HTMLResponse, JSONResponse, FileResponse, StreamingResponse
from diffusers import Flux2KleinPipeline, AutoencoderKLFlux2

# --- App Configuration & Directories ---
app = Server()

BASE_DIR = Path(__file__).resolve().parent
STATIC_DIR = BASE_DIR / "static"
OUTPUT_DIR = BASE_DIR / "outputs"
EXAMPLES_DIR = BASE_DIR / "examples"

STATIC_DIR.mkdir(exist_ok=True)
OUTPUT_DIR.mkdir(exist_ok=True)

MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024

dtype = torch.bfloat16
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if torch.cuda.is_available():
    print("current device:", torch.cuda.current_device())
    print("device name:", torch.cuda.get_device_name(torch.cuda.current_device()))
    DEVICE_LABEL = torch.cuda.get_device_name(torch.cuda.current_device()).lower()
else:
    DEVICE_LABEL = str(device).lower()

# --- Model Loading ---
# NOTE(review): each pipeline is moved to `device` and then has model CPU
# offload enabled; confirm against the diffusers docs that this order is
# intended (offloading normally manages device placement itself).
print("Loading 4B Distilled model (Standard VAE)...")
pipe_standard = Flux2KleinPipeline.from_pretrained(
    "black-forest-labs/FLUX.2-klein-4B",
    torch_dtype=dtype,
).to(device)
pipe_standard.enable_model_cpu_offload()

print("Loading Small Decoder VAE...")
vae_small = AutoencoderKLFlux2.from_pretrained(
    "black-forest-labs/FLUX.2-small-decoder",
    torch_dtype=dtype,
).to(device)

print("Loading 4B Distilled model (Small Decoder VAE)...")
pipe_small_decoder = Flux2KleinPipeline.from_pretrained(
    "black-forest-labs/FLUX.2-klein-4B",
    vae=vae_small,
    torch_dtype=dtype,
).to(device)
pipe_small_decoder.enable_model_cpu_offload()

# One lock per pipeline so a given pipeline never runs two requests at once.
pipe_lock_standard = threading.Lock()
pipe_lock_small = threading.Lock()


# --- Utility Functions ---
def _snap_dimension(value) -> int:
    """Round *value* to the nearest multiple of 8, clamped to [256, MAX_IMAGE_SIZE]."""
    return max(256, min(MAX_IMAGE_SIZE, round(int(value) / 8) * 8))


def _safe_file(base_dir: Path, filename: str) -> Optional[Path]:
    """Return ``base_dir / filename`` if it is a plain file name of an existing file.

    Names containing directory components are rejected, and directories are
    not treated as downloadable files. Returns ``None`` when rejected.
    """
    name = Path(filename).name
    if not name or name != filename:
        return None
    candidate = base_dir / name
    return candidate if candidate.is_file() else None


def calc_dimensions(pil_img: Image.Image):
    """Compute an output (width, height) that keeps the image's aspect ratio.

    The longer side is set to ``MAX_IMAGE_SIZE``; both sides are then rounded
    to a multiple of 8 and clamped to [256, MAX_IMAGE_SIZE].
    """
    iw, ih = pil_img.size
    aspect = iw / ih
    if aspect >= 1:
        new_width = MAX_IMAGE_SIZE
        new_height = int(round(MAX_IMAGE_SIZE / aspect))
    else:
        new_height = MAX_IMAGE_SIZE
        new_width = int(round(MAX_IMAGE_SIZE * aspect))
    return _snap_dimension(new_width), _snap_dimension(new_height)


def parse_and_resize_images(image_paths: List[str], width: int, height: int):
    """Load each path as RGB and resize it to ``(width, height)``.

    Images that fail to load are skipped with a printed message. Returns the
    list of resized images, or ``None`` when no path yields a usable image.
    """
    if not image_paths:
        return None
    resized = []
    for path in image_paths:
        try:
            # The context manager closes the underlying file; convert()
            # returns an independent, fully loaded image.
            with Image.open(path) as src:
                img = src.convert("RGB")
            resized.append(img.resize((width, height), Image.LANCZOS))
        except Exception as e:
            print(f"Skipping invalid image: {e}")
    return resized if resized else None


def run_pipeline(pipe, lock, kwargs, seed):
    """Run *pipe* with *kwargs* under *lock*, seeded with *seed*; return the first image."""
    with lock:
        gen = torch.Generator(device="cpu").manual_seed(seed)
        result = pipe(**kwargs, generator=gen).images[0]
    return result


def save_image(img: Image.Image, prefix: str = "output") -> str:
    """Save *img* as a PNG in ``OUTPUT_DIR`` under a unique name and return that file name."""
    filename = f"{prefix}_{uuid.uuid4().hex}.png"
    path = OUTPUT_DIR / filename
    img.save(path, format="PNG")
    return filename


# --- Inference Function ---
@spaces.GPU(duration=120)
def infer(
    prompt: str,
    image_paths: Optional[List[str]] = None,
    seed: int = 42,
    randomize_seed: bool = False,
    width: int = 1024,
    height: int = 1024,
    num_inference_steps: int = 4,
    guidance_scale: float = 1.0,
):
    """Generate one image with each pipeline from identical inputs.

    When input images are given, the output size is derived from the first
    image via ``calc_dimensions`` and every input is resized to it; if that
    processing fails, the error is printed and generation proceeds without
    input images. Width and height are always snapped to a multiple of 8
    within [256, MAX_IMAGE_SIZE].

    Returns:
        ``(standard_image, small_decoder_image, seed)`` where *seed* is the
        seed actually used.

    Raises:
        ValueError: if *prompt* is empty or whitespace only.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    if not prompt or not prompt.strip():
        raise ValueError("Please enter a prompt.")

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    image_list = None
    if image_paths:
        try:
            with Image.open(image_paths[0]) as first:
                width, height = calc_dimensions(first.convert("RGB"))
            image_list = parse_and_resize_images(image_paths, width, height)
        except Exception as e:
            # Best effort: fall back to generation without input images.
            print(f"Error processing upload: {e}")

    width = _snap_dimension(width)
    height = _snap_dimension(height)

    shared_kwargs = dict(
        prompt=prompt,
        height=height,
        width=width,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
    )
    if image_list is not None:
        shared_kwargs["image"] = image_list

    # Submit both pipelines to a two-worker pool; each takes its own lock.
    with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
        future_std = executor.submit(
            run_pipeline, pipe_standard, pipe_lock_standard, shared_kwargs, seed
        )
        future_small = executor.submit(
            run_pipeline, pipe_small_decoder, pipe_lock_small, shared_kwargs, seed
        )
        concurrent.futures.wait(
            [future_std, future_small],
            return_when=concurrent.futures.ALL_COMPLETED,
        )
        out_standard = future_std.result()
        out_small = future_small.result()

    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return out_standard, out_small, seed


# --- FastAPI Endpoints ---
def get_example_items():
    """Return the example entries (image URLs plus prompt) offered by the UI."""
    return [
        {
            "urls": ["/example-file/I1.jpg", "/example-file/I2.jpg"],
            "prompt": "Make her wear these glasses in Image 2."
        },
        {
            "urls": ["/example-file/1.jpg"],
            "prompt": "Change the weather to stormy."
        },
        {
            "urls": ["/example-file/2.jpg"],
            "prompt": "Transform the scene into a snowy winter day while preserving the original subject identity, framing, and composition."
        },
        {
            "urls": ["/example-file/3.jpg"],
            "prompt": "Relight the image with soft golden sunset lighting while keeping all structures and subject details consistent."
        },
        {
            "urls": ["/example-file/4.jpg"],
            "prompt": "Make the texture high-resolution."
        }
    ]


@app.get("/example-file/{filename}")
async def example_file(filename: str):
    """Serve a file from ``EXAMPLES_DIR``; respond 404 if it is not a plain existing file."""
    path = _safe_file(EXAMPLES_DIR, filename)
    if path is None:
        return JSONResponse({"error": "Example not found"}, status_code=404)
    return FileResponse(path)


@app.get("/download/{filename}")
async def download_file(filename: str):
    """Serve a generated PNG from ``OUTPUT_DIR``; respond 404 if it is not a plain existing file."""
    path = _safe_file(OUTPUT_DIR, filename)
    if path is None:
        return JSONResponse({"error": "File not found"}, status_code=404)
    return FileResponse(path, filename=filename, media_type="image/png")


@app.get("/api/download-zip")
async def download_zip(std: str, small: str):
    """Packages both generated images into a single ZIP file and streams it."""
    # Only the final path component is used, so callers cannot leave OUTPUT_DIR.
    std_name = Path(std).name
    small_name = Path(small).name

    std_path = OUTPUT_DIR / std_name
    small_path = OUTPUT_DIR / small_name

    # is_file() rather than exists(): an empty name would otherwise resolve to
    # OUTPUT_DIR itself and be accepted.
    if not std_path.is_file() or not small_path.is_file():
        return JSONResponse({"error": "Generated files not found"}, status_code=404)

    memory_file = io.BytesIO()
    with zipfile.ZipFile(memory_file, 'w', zipfile.ZIP_DEFLATED) as zf:
        zf.write(std_path, arcname=f"Standard_Decoder_{std_name}")
        zf.write(small_path, arcname=f"Small_Decoder_{small_name}")
    memory_file.seek(0)

    return StreamingResponse(
        memory_file,
        media_type="application/zip",
        headers={"Content-Disposition": f"attachment; filename=Flux2_Comparison_{uuid.uuid4().hex[:6]}.zip"}
    )


@app.post("/api/compare")
async def compare_images(
    prompt: str = Form(...),
    seed: str = Form("0"),
    randomize_seed: str = Form("true"),
    width: str = Form("1024"),
    height: str = Form("1024"),
    steps: str = Form("4"),
    guidance: str = Form("1.0"),
    images: Optional[List[UploadFile]] = File(None),
):
    """Run both decoders on the submitted form and return URLs of the two results.

    Uploaded images are written to temporary files in ``OUTPUT_DIR`` for the
    duration of the request and removed afterwards. Any failure is reported
    as a JSON body with ``success: False`` and HTTP status 500.
    """
    temp_paths = []
    try:
        image_paths = []
        if images:
            for upload in images:
                if not upload.filename:
                    continue
                suffix = Path(upload.filename).suffix or ".png"
                temp_path = OUTPUT_DIR / f"upload_{uuid.uuid4().hex}{suffix}"
                # Register for cleanup before writing so a failed write does
                # not leave a partial file behind.
                temp_paths.append(str(temp_path))
                content = await upload.read()
                with open(temp_path, "wb") as f:
                    f.write(content)
                image_paths.append(str(temp_path))

        result_std, result_small, used_seed = infer(
            prompt=prompt,
            image_paths=image_paths,
            seed=int(seed),
            randomize_seed=(randomize_seed.lower() == "true"),
            width=int(width),
            height=int(height),
            num_inference_steps=int(steps),
            guidance_scale=float(guidance),
        )

        std_filename = save_image(result_std, prefix="std")
        small_filename = save_image(result_small, prefix="small")

        return JSONResponse({
            "success": True,
            "seed": used_seed,
            "std_url": f"/download/{std_filename}",
            "small_url": f"/download/{small_filename}",
            "std_filename": std_filename,
            "small_filename": small_filename,
            "device": DEVICE_LABEL,
        })
    except Exception as e:
        return JSONResponse({"success": False, "error": str(e)}, status_code=500)
    finally:
        for p in temp_paths:
            # Best-effort cleanup: an error here must not replace the response.
            try:
                if os.path.exists(p):
                    os.remove(p)
            except OSError as cleanup_error:
                print(f"Could not remove temporary upload {p}: {cleanup_error}")


# --- Frontend ---
@app.get("/", response_class=HTMLResponse)
async def homepage(request: Request):
    """Return the page body, with the device label interpolated into the template."""
    examples = get_example_items()
    # NOTE(review): examples_json is not interpolated anywhere in the template
    # text visible here; confirm whether the template is meant to embed it.
    examples_json = json.dumps(examples)
    return f""" Flux.2-4B-Decoder-Comparator
Flux.2-4B VAE Decoder Comparator

Standard vs. Small Decoder

Upload an image, enter a prompt, and use the slider to compare outputs in real-time.

Settings
Click or Drag & Drop images here
Execution Log
[{DEVICE_LABEL}]System Ready...
Comparison View
Results will appear here
Standard Decoder Small Decoder
Standard Decoder
Small Decoder
Running both decoders...

Examples

"""


app.launch()