"""FastAPI img2img service wrapping a diffusers pipeline.

Exposes /health and /generate. The model is loaded lazily on first use,
preferring a pre-downloaded local directory (offline) and falling back to
the Hugging Face Hub with an access token for the gated model.
"""

import os
import random
from io import BytesIO
from pathlib import Path
from typing import Any

import torch
import uvicorn
from diffusers import DiffusionPipeline
from fastapi import FastAPI, File, Form, UploadFile
from fastapi.responses import JSONResponse, Response
from PIL import Image, UnidentifiedImageError

# Configuration comes from the environment so the same image runs anywhere.
MODEL_ID = os.getenv("MODEL_ID", "black-forest-labs/FLUX.2-klein-9b-fp8")
LOCAL_MODEL_DIR = os.getenv("LOCAL_MODEL_DIR", "./models/flux2-klein-9b-fp8")
HF_TOKEN = os.getenv("HF_TOKEN", "")
# DEVICE is always exactly "cuda" or "cpu".
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

app = FastAPI(title="img2img-fastapi-space", version="0.1.0")

# Lazily-initialized singleton pipeline; populated by _load_pipeline().
_PIPE: Any = None


def _load_pipeline() -> Any:
    """Return the cached DiffusionPipeline, loading it on first call.

    Tries LOCAL_MODEL_DIR first (offline), then falls back to downloading
    MODEL_ID from the Hugging Face Hub using HF_TOKEN.

    Raises:
        RuntimeError: if no local model exists and HF_TOKEN is unset.
    """
    global _PIPE
    if _PIPE is not None:
        return _PIPE

    model_path = Path(LOCAL_MODEL_DIR)
    # bf16 on GPU; CPU has no bf16 fast path here, so use fp32 there.
    dtype = torch.bfloat16 if DEVICE == "cuda" else torch.float32

    # Prefer a pre-downloaded local copy so the service can run offline.
    if model_path.exists():
        _PIPE = DiffusionPipeline.from_pretrained(
            str(model_path),
            torch_dtype=dtype,
            device_map=DEVICE,  # DEVICE is already "cuda" or "cpu"
            local_files_only=True,
        )
        return _PIPE

    # Fall back to downloading from the Hub; the model is gated, so a
    # read token (and an accepted license) is mandatory.
    if not HF_TOKEN:
        raise RuntimeError(
            f"Model not found locally at {model_path} and HF_TOKEN not set. "
            "Please either:\n"
            "1. Pre-download model and set LOCAL_MODEL_DIR, or\n"
            "2. Set HF_TOKEN environment variable with a read token "
            "(https://huggingface.co/settings/tokens)\n"
            "Note: black-forest-labs/FLUX.2-klein-9b-fp8 is a gated model - "
            "you must first accept the license at "
            "https://huggingface.co/black-forest-labs/FLUX.2-klein-9b-fp8"
        )
    _PIPE = DiffusionPipeline.from_pretrained(
        MODEL_ID,
        torch_dtype=dtype,
        device_map=DEVICE,
        token=HF_TOKEN,
    )
    return _PIPE


@app.get("/health")
def health() -> dict[str, Any]:
    """Liveness probe reporting the configured model and compute device."""
    return {
        "ok": True,
        "model": MODEL_ID,
        "localModelDir": LOCAL_MODEL_DIR,
        "device": DEVICE,
    }


@app.post("/generate")
async def generate(
    init_image: UploadFile = File(...),
    prompt: str = Form(...),
    negative_prompt: str = Form(""),
    strength: float = Form(0.65),
    guidance_scale: float = Form(7.0),
    num_steps: int = Form(30),
    seed: int | None = Form(None),
    face_image: UploadFile | None = File(None),
) -> Response:
    """Run img2img on the uploaded image and return the result as PNG.

    Numeric knobs are clamped to valid ranges. When no seed is supplied a
    random 31-bit seed is drawn so every request is still deterministic
    per-seed. Returns 400 for an unreadable image and 503 when the model
    cannot be loaded.
    """
    _ = face_image  # compatibility arg for upstream caller

    try:
        raw = await init_image.read()
        image = Image.open(BytesIO(raw)).convert("RGB")
    except (UnidentifiedImageError, OSError, ValueError):
        return JSONResponse(status_code=400, content={"error": "init_image is not a valid image"})

    # Surface model-loading problems (missing local model / token) as an
    # actionable 503 instead of an opaque unhandled 500.
    try:
        pipe = _load_pipeline()
    except RuntimeError as err:
        return JSONResponse(status_code=503, content={"error": str(err)})

    if seed is None:
        seed = random.randint(0, 2_147_483_647)
    generator = torch.Generator(device=DEVICE).manual_seed(int(seed))

    result = pipe(
        image=image,
        prompt=prompt,
        negative_prompt=negative_prompt or None,
        strength=float(max(0.0, min(1.0, strength))),  # valid range [0, 1]
        guidance_scale=float(max(0.0, guidance_scale)),
        num_inference_steps=int(max(1, num_steps)),  # at least one step
        generator=generator,
    )

    out = result.images[0].convert("RGB")
    buf = BytesIO()
    out.save(buf, format="PNG")
    return Response(content=buf.getvalue(), media_type="image/png")


if __name__ == "__main__":
    # HF Spaces convention: bind on all interfaces, port 7860.
    uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=False)