import asyncio
import os
import random
from io import BytesIO
from pathlib import Path
from typing import Any

import torch
import uvicorn
from diffusers import DiffusionPipeline
from fastapi import FastAPI, File, Form, UploadFile
from fastapi.responses import JSONResponse, Response
from PIL import Image, UnidentifiedImageError
|
|
# Hub repo id used when no local snapshot is present (gated model; license
# acceptance required on the Hub — see _load_pipeline's error message).
MODEL_ID = os.getenv("MODEL_ID", "black-forest-labs/FLUX.2-klein-9b-fp8")
# Preferred source: a pre-downloaded snapshot on local disk.
LOCAL_MODEL_DIR = os.getenv("LOCAL_MODEL_DIR", "./models/flux2-klein-9b-fp8")
# Hugging Face read token; only needed when falling back to a Hub download.
HF_TOKEN = os.getenv("HF_TOKEN", "")
# Use GPU when available, otherwise CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


app = FastAPI(title="img2img-fastapi-space", version="0.1.0")
# Lazily-initialized pipeline singleton; populated by _load_pipeline().
_PIPE: Any = None
|
|
|
|
def _load_pipeline() -> Any:
    """Return the process-wide diffusion pipeline, loading it on first use.

    Resolution order: a local snapshot at LOCAL_MODEL_DIR wins; otherwise the
    model is fetched from the Hugging Face Hub, which requires HF_TOKEN
    because the model repo is gated.

    Raises:
        RuntimeError: no local snapshot exists and HF_TOKEN is unset.
    """
    global _PIPE
    if _PIPE is None:
        model_path = Path(LOCAL_MODEL_DIR)
        # bf16 only makes sense on GPU; fall back to fp32 on CPU.
        dtype = torch.bfloat16 if DEVICE == "cuda" else torch.float32
        device_map = "cuda" if DEVICE == "cuda" else "cpu"

        if model_path.exists():
            # Offline path: never touch the network for a local snapshot.
            _PIPE = DiffusionPipeline.from_pretrained(
                str(model_path),
                torch_dtype=dtype,
                device_map=device_map,
                local_files_only=True,
            )
        elif HF_TOKEN:
            # Hub path: authenticated download of the gated model.
            _PIPE = DiffusionPipeline.from_pretrained(
                MODEL_ID,
                torch_dtype=dtype,
                device_map=device_map,
                token=HF_TOKEN,
            )
        else:
            raise RuntimeError(
                f"Model not found locally at {model_path} and HF_TOKEN not set. "
                "Please either:\n"
                "1. Pre-download model and set LOCAL_MODEL_DIR, or\n"
                "2. Set HF_TOKEN environment variable with a read token "
                "(https://huggingface.co/settings/tokens)\n"
                "Note: black-forest-labs/FLUX.2-klein-9b-fp8 is a gated model - "
                "you must first accept the license at https://huggingface.co/black-forest-labs/FLUX.2-klein-9b-fp8"
            )
    return _PIPE
|
|
|
|
@app.get("/health")
def health() -> dict[str, Any]:
    """Liveness probe reporting the configured model and compute device."""
    status: dict[str, Any] = {"ok": True}
    status["model"] = MODEL_ID
    status["localModelDir"] = LOCAL_MODEL_DIR
    status["device"] = DEVICE
    return status
|
|
|
|
@app.post("/generate")
async def generate(
    init_image: UploadFile = File(...),
    prompt: str = Form(...),
    negative_prompt: str = Form(""),
    strength: float = Form(0.65),
    guidance_scale: float = Form(7.0),
    num_steps: int = Form(30),
    seed: int | None = Form(None),
    face_image: UploadFile | None = File(None),
) -> Response:
    """Run img2img on the uploaded image and return the result as a PNG.

    Returns a 400 JSON error when init_image is not decodable. strength is
    clamped to [0, 1], guidance_scale to >= 0, num_steps to >= 1. A random
    seed is drawn when none is supplied.
    """
    # Accepted for forward API compatibility; not used by this pipeline yet.
    _ = face_image

    try:
        raw = await init_image.read()
        image = Image.open(BytesIO(raw)).convert("RGB")
    except (UnidentifiedImageError, OSError, ValueError):
        return JSONResponse(status_code=400, content={"error": "init_image is not a valid image"})

    # First call may download/load a multi-GB model; keep that — and the
    # inference below — off the event loop so other requests (e.g. /health)
    # stay responsive while this blocking work runs in a worker thread.
    pipe = await asyncio.to_thread(_load_pipeline)

    if seed is None:
        seed = random.randint(0, 2_147_483_647)
    # DEVICE is already "cuda" or "cpu", so it can be used directly.
    generator = torch.Generator(device=DEVICE).manual_seed(int(seed))

    result = await asyncio.to_thread(
        pipe,
        image=image,
        prompt=prompt,
        # Diffusers treats None as "no negative prompt"; map "" to None.
        negative_prompt=negative_prompt or None,
        strength=float(max(0.0, min(1.0, strength))),
        guidance_scale=float(max(0.0, guidance_scale)),
        num_inference_steps=int(max(1, num_steps)),
        generator=generator,
    )
    out = result.images[0].convert("RGB")
    buf = BytesIO()
    out.save(buf, format="PNG")
    return Response(content=buf.getvalue(), media_type="image/png")
|
|
|
|
if __name__ == "__main__":
    # Bind all interfaces on port 7860 — presumably the Hugging Face Spaces
    # convention, given the app title; confirm if deploying elsewhere.
    uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=False)
|
|
|
|