# app.py — FastAPI image-to-image inference service (FLUX.2 klein 9B fp8).
import os
import random
from io import BytesIO
from pathlib import Path
from typing import Any
import torch
import uvicorn
from diffusers import DiffusionPipeline
from fastapi import FastAPI, File, Form, UploadFile
from fastapi.responses import JSONResponse, Response
from PIL import Image, UnidentifiedImageError
# Hub repo id to download when no local copy exists; overridable via env.
MODEL_ID = os.getenv("MODEL_ID", "black-forest-labs/FLUX.2-klein-9b-fp8")
# Directory checked first for pre-downloaded weights (takes priority over the Hub).
LOCAL_MODEL_DIR = os.getenv("LOCAL_MODEL_DIR", "./models/flux2-klein-9b-fp8")
# Hugging Face read token; only required when falling back to a Hub download
# (the model is gated — see the error message in _load_pipeline).
HF_TOKEN = os.getenv("HF_TOKEN", "")
# Prefer GPU when one is visible to torch; otherwise run on CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
app = FastAPI(title="img2img-fastapi-space", version="0.1.0")
# Lazily-initialized pipeline singleton; populated on first use by _load_pipeline().
_PIPE: Any = None
def _load_pipeline() -> Any:
    """Return the process-wide diffusion pipeline, loading it on first call.

    Loading order: a pre-downloaded copy under ``LOCAL_MODEL_DIR`` wins;
    otherwise the model is fetched from the Hub, which requires ``HF_TOKEN``.

    Returns:
        The cached ``DiffusionPipeline`` instance.

    Raises:
        RuntimeError: if no local model exists and ``HF_TOKEN`` is unset.
    """
    global _PIPE
    if _PIPE is None:
        dtype = torch.bfloat16 if DEVICE == "cuda" else torch.float32
        device_map = "cuda" if DEVICE == "cuda" else "cpu"
        local_dir = Path(LOCAL_MODEL_DIR)
        if local_dir.exists():
            # Local weights: never touch the network.
            _PIPE = DiffusionPipeline.from_pretrained(
                str(local_dir),
                torch_dtype=dtype,
                device_map=device_map,
                local_files_only=True,
            )
        elif HF_TOKEN:
            # Authenticated download from the Hub.
            _PIPE = DiffusionPipeline.from_pretrained(
                MODEL_ID,
                torch_dtype=dtype,
                device_map=device_map,
                token=HF_TOKEN,
            )
        else:
            raise RuntimeError(
                f"Model not found locally at {local_dir} and HF_TOKEN not set. "
                "Please either:\n"
                "1. Pre-download model and set LOCAL_MODEL_DIR, or\n"
                "2. Set HF_TOKEN environment variable with a read token "
                "(https://huggingface.co/settings/tokens)\n"
                "Note: black-forest-labs/FLUX.2-klein-9b-fp8 is a gated model - "
                "you must first accept the license at https://huggingface.co/black-forest-labs/FLUX.2-klein-9b-fp8"
            )
    return _PIPE
@app.get("/health")
def health() -> dict[str, Any]:
return {
"ok": True,
"model": MODEL_ID,
"localModelDir": LOCAL_MODEL_DIR,
"device": DEVICE,
}
@app.post("/generate")
async def generate(
init_image: UploadFile = File(...),
prompt: str = Form(...),
negative_prompt: str = Form(""),
strength: float = Form(0.65),
guidance_scale: float = Form(7.0),
num_steps: int = Form(30),
seed: int | None = Form(None),
face_image: UploadFile | None = File(None),
) -> Response:
_ = face_image # compatibility arg for upstream caller
try:
raw = await init_image.read()
image = Image.open(BytesIO(raw)).convert("RGB")
except (UnidentifiedImageError, OSError, ValueError):
return JSONResponse(status_code=400, content={"error": "init_image is not a valid image"})
pipe = _load_pipeline()
if seed is None:
seed = random.randint(0, 2_147_483_647)
generator = torch.Generator(device=DEVICE if DEVICE == "cuda" else "cpu").manual_seed(int(seed))
result = pipe(
image=image,
prompt=prompt,
negative_prompt=negative_prompt or None,
strength=float(max(0.0, min(1.0, strength))),
guidance_scale=float(max(0.0, guidance_scale)),
num_inference_steps=int(max(1, num_steps)),
generator=generator,
)
out = result.images[0].convert("RGB")
buf = BytesIO()
out.save(buf, format="PNG")
return Response(content=buf.getvalue(), media_type="image/png")
if __name__ == "__main__":
uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=False)