File size: 3,595 Bytes
d222bfb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import os
import random
from io import BytesIO
from pathlib import Path
from typing import Any

import torch
import uvicorn
from diffusers import DiffusionPipeline
from fastapi import FastAPI, File, Form, UploadFile
from fastapi.responses import JSONResponse, Response
from PIL import Image, UnidentifiedImageError

# Runtime configuration, read once from the environment at import time.
MODEL_ID = os.environ.get("MODEL_ID", "black-forest-labs/FLUX.2-klein-9b-fp8")
LOCAL_MODEL_DIR = os.environ.get("LOCAL_MODEL_DIR", "./models/flux2-klein-9b-fp8")
HF_TOKEN = os.environ.get("HF_TOKEN", "")

# Use the GPU whenever torch reports one; otherwise fall back to the CPU.
DEVICE = "cpu" if not torch.cuda.is_available() else "cuda"

app = FastAPI(
    title="img2img-fastapi-space",
    version="0.1.0",
)

# Process-wide pipeline cache; stays None until _load_pipeline() fills it.
_PIPE: Any = None


def _load_pipeline() -> Any:
    """Return the shared diffusion pipeline, loading it on first use.

    The pipeline is cached in the module-level ``_PIPE``, so only the first
    successful call pays the loading cost. If ``LOCAL_MODEL_DIR`` exists it is
    loaded with ``local_files_only=True``; otherwise ``MODEL_ID`` is loaded
    with ``HF_TOKEN`` passed as the access token.

    Returns:
        The object returned by ``DiffusionPipeline.from_pretrained``.

    Raises:
        RuntimeError: If ``LOCAL_MODEL_DIR`` does not exist and ``HF_TOKEN``
            is empty.
    """
    global _PIPE
    if _PIPE is not None:
        return _PIPE

    model_path = Path(LOCAL_MODEL_DIR)
    source_kwargs: dict[str, Any]
    if model_path.exists():
        # Prefer the pre-downloaded copy and forbid any Hub lookups for it.
        source = str(model_path)
        source_kwargs = {"local_files_only": True}
    elif HF_TOKEN:
        # Fall back to the Hub, authenticating with the configured token.
        source = MODEL_ID
        source_kwargs = {"token": HF_TOKEN}
    else:
        # Report the model that is actually configured: MODEL_ID can be
        # overridden through the environment, so a hard-coded repository name
        # here would send the operator to the wrong license page.
        raise RuntimeError(
            f"Model not found locally at {model_path} and HF_TOKEN not set. "
            "Please either:\n"
            "1. Pre-download model and set LOCAL_MODEL_DIR, or\n"
            "2. Set HF_TOKEN environment variable with a read token "
            "(https://huggingface.co/settings/tokens)\n"
            f"Note: if {MODEL_ID} is a gated model, "
            f"you must first accept the license at https://huggingface.co/{MODEL_ID}"
        )

    # bfloat16 on the GPU, float32 on the CPU.
    dtype = torch.bfloat16 if DEVICE == "cuda" else torch.float32
    _PIPE = DiffusionPipeline.from_pretrained(
        source,
        torch_dtype=dtype,
        # DEVICE is always "cuda" or "cpu", so it doubles as the device map.
        device_map=DEVICE,
        **source_kwargs,
    )
    return _PIPE


@app.get("/health")
def health() -> dict[str, Any]:
    """Report liveness together with the active model configuration.

    Only module-level settings are echoed back; the pipeline is not touched,
    so this endpoint never triggers a model load.
    """
    status: dict[str, Any] = {"ok": True}
    status["model"] = MODEL_ID
    status["localModelDir"] = LOCAL_MODEL_DIR
    status["device"] = DEVICE
    return status


# Inclusive seed range accepted by torch.Generator.manual_seed; values outside
# it make torch raise a RuntimeError.
_MIN_SEED = -0x8000_0000_0000_0000
_MAX_SEED = 0xFFFF_FFFF_FFFF_FFFF


def _accepted_optional_kwargs(pipe: Any, optional: dict[str, Any]) -> dict[str, Any]:
    """Return the entries of ``optional`` that ``pipe`` can be called with.

    If the signature of ``pipe`` cannot be inspected, or it takes ``**kwargs``,
    every entry is kept.
    """
    import inspect  # local import: only this helper needs it

    try:
        params = inspect.signature(pipe).parameters
    except (TypeError, ValueError):
        return dict(optional)
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return dict(optional)
    return {name: value for name, value in optional.items() if name in params}


@app.post("/generate")
async def generate(
    init_image: UploadFile = File(...),
    prompt: str = Form(...),
    negative_prompt: str = Form(""),
    strength: float = Form(0.65),
    guidance_scale: float = Form(7.0),
    num_steps: int = Form(30),
    seed: int | None = Form(None),
    face_image: UploadFile | None = File(None),
) -> Response:
    """Run image-to-image generation and return the result as a PNG.

    Args:
        init_image: Uploaded source image; decoded with Pillow and converted
            to RGB.
        prompt: Text prompt forwarded to the pipeline.
        negative_prompt: Optional negative prompt; an empty string is
            forwarded as ``None``.
        strength: Clamped to ``[0.0, 1.0]`` before use.
        guidance_scale: Clamped to be non-negative.
        num_steps: Number of inference steps; values below 1 become 1.
        seed: Seed for the torch generator; a random one is drawn when it is
            omitted.
        face_image: Accepted for compatibility with the upstream caller and
            otherwise ignored.

    Returns:
        An ``image/png`` response on success. A JSON 400 response when
        ``init_image`` cannot be decoded or ``seed`` is out of range, and a
        JSON 503 response when the pipeline cannot be loaded.
    """
    _ = face_image  # compatibility arg for upstream caller

    try:
        raw = await init_image.read()
        image = Image.open(BytesIO(raw)).convert("RGB")
    except (
        UnidentifiedImageError,
        # Derives directly from Exception, so OSError/ValueError miss it.
        Image.DecompressionBombError,
        OSError,
        ValueError,
    ):
        return JSONResponse(status_code=400, content={"error": "init_image is not a valid image"})

    if seed is None:
        seed = random.randint(0, 2_147_483_647)
    elif not _MIN_SEED <= seed <= _MAX_SEED:
        # Reject early instead of letting manual_seed() fail the request.
        return JSONResponse(
            status_code=400,
            content={"error": f"seed must be between {_MIN_SEED} and {_MAX_SEED}"},
        )

    try:
        pipe = _load_pipeline()
    except RuntimeError as exc:
        # A missing model/token is a service problem, not a client error.
        return JSONResponse(
            status_code=503,
            content={"error": f"model is not available: {exc}"},
        )

    generator = torch.Generator(device=DEVICE).manual_seed(int(seed))

    # NOTE(review): the concrete pipeline class is chosen by from_pretrained,
    # and not every pipeline's __call__ takes negative_prompt / strength.
    # Unsupported ones are dropped rather than failing the request with a
    # TypeError - confirm this is the desired behavior for the configured model.
    optional_kwargs = _accepted_optional_kwargs(
        pipe,
        {
            "negative_prompt": negative_prompt or None,
            "strength": float(max(0.0, min(1.0, strength))),
        },
    )

    # NOTE: this is a synchronous call inside an async handler, so the event
    # loop is busy until it returns.
    result = pipe(
        image=image,
        prompt=prompt,
        guidance_scale=float(max(0.0, guidance_scale)),
        num_inference_steps=int(max(1, num_steps)),
        generator=generator,
        **optional_kwargs,
    )
    out = result.images[0].convert("RGB")
    buf = BytesIO()
    out.save(buf, format="PNG")
    return Response(content=buf.getvalue(), media_type="image/png")


if __name__ == "__main__":
    # Script entry point: serve the "app:app" import target on all interfaces,
    # port 7860, with auto-reload disabled.
    uvicorn.run(
        "app:app",
        host="0.0.0.0",
        port=7860,
        reload=False,
    )