# Sking / infer.py
# EntropyDrop
# infer.py
# 98b76cc
import base64
import io
import os
import threading
import traceback
from typing import List

import torch
from diffusers import Flux2KleinPipeline
from fastapi import FastAPI, HTTPException, File, Form, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
from pydantic import BaseModel, Field
app = FastAPI()

# Fully permissive CORS policy: every origin, HTTP method and request header
# is allowed to reach this API from a browser.
_CORS_OPTIONS = {
    "allow_origins": ["*"],
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)
# Base model and LoRA weights used by the service. Either can be overridden
# through the environment; when the variables are unset, the defaults below
# are used.
MODEL_ID = os.environ.get(
    "SKING_MODEL_ID", "black-forest-labs/FLUX.2-klein-base-4B"
)
# A relative path here is interpreted relative to the process's current
# working directory, not this file's directory.
LORA_PATH = os.environ.get(
    "SKING_LORA_PATH", "./sking_v73_flux_4b_000027000.safetensors"
)

# Load the pipeline once at import time in bfloat16, move it to the GPU and
# apply the LoRA weights; every request reuses this single `pipe` instance.
print("Loading model...")
pipe = Flux2KleinPipeline.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
)
pipe.to("cuda")
pipe.load_lora_weights(LORA_PATH)
print("Model loaded.")
# Response body of POST /api/generate.
# (Documented with comments rather than a docstring: pydantic would publish a
# class docstring as the model's description in the OpenAPI/JSON schema.)
class GenerateResponse(BaseModel):
    images_base64: List[str]  # one base64 string per output image; `generate` encodes them as PNG
# Inference runs in a worker thread (see `generate`), so concurrent requests
# could otherwise drive the single shared `pipe` at the same time. This lock
# keeps pipeline calls strictly one at a time.
_PIPE_LOCK = threading.Lock()


def _decode_upload(content: bytes) -> Image.Image:
    """Decode uploaded bytes into an RGBA PIL image.

    Raises:
        HTTPException: 400 if the bytes cannot be decoded as an image
            (including an empty upload).
    """
    try:
        # Image.open is lazy; convert() forces the full decode so corrupt or
        # truncated data fails here instead of inside the pipeline.
        # NOTE(review): the pipeline receives an RGBA image; confirm how
        # Flux2KleinPipeline's preprocessing treats the alpha channel.
        return Image.open(io.BytesIO(content)).convert("RGBA")
    except Exception as err:  # Pillow raises several exception types for bad input
        raise HTTPException(
            status_code=400, detail=f"Invalid image upload: {err}"
        ) from err


def _encode_png_base64(image: Image.Image) -> str:
    """Serialize *image* as PNG and return it base64-encoded as a str."""
    buffered = io.BytesIO()
    image.save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode()


def _run_pipeline(image: Image.Image, guidance: float, seed: int, n_step: int):
    """Run one blocking inference on the shared pipeline; call from a worker thread."""
    with _PIPE_LOCK:
        return pipe(
            image=image,
            # NOTE(review): the request's `prompt` form field is NOT forwarded;
            # an empty prompt is always used. Confirm this is intentional.
            prompt="",
            height=768,
            width=768,
            num_inference_steps=n_step,
            guidance_scale=guidance,
            num_images_per_prompt=1,
            generator=torch.Generator("cuda").manual_seed(seed),
        )


@app.post("/api/generate", response_model=GenerateResponse)
async def generate(
    prompt: str = Form(...),
    guidance: float = Form(4.0),
    seed: int = Form(42),
    n_step: int = Form(100),
    file: UploadFile = File(...),
):
    """Run the LoRA-adapted pipeline on an uploaded image at 768x768.

    Form fields:
        prompt: required, but currently ignored (an empty prompt is used).
        guidance: passed to the pipeline as ``guidance_scale``.
        seed: seeds the CUDA ``torch.Generator`` handed to the pipeline.
        n_step: number of inference steps; must be >= 1.
        file: the input image.

    Returns:
        ``{"images_base64": [...]}``, where each entry is a base64-encoded PNG.

    Raises:
        HTTPException: 400 for ``n_step < 1`` or an undecodable image;
            500 if inference or PNG encoding fails.
    """
    if n_step < 1:
        raise HTTPException(status_code=400, detail="n_step must be >= 1")

    # Decode outside the try below so a bad upload is reported as a client
    # error (400) instead of an unhandled server error.
    image = _decode_upload(await file.read())

    try:
        # The pipeline call is blocking GPU work; running it in the threadpool
        # keeps the event loop free to serve other requests meanwhile.
        pipeline_output = await run_in_threadpool(
            _run_pipeline, image, guidance, seed, n_step
        )
        images_b64 = [_encode_png_base64(img) for img in pipeline_output.images]
    except Exception as e:
        traceback.print_exc()
        print(f"Error: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"images_base64": images_b64}
if __name__ == "__main__":
    # uvicorn is only needed when this file is executed directly as a script.
    from uvicorn import run as serve

    # Bind every network interface on the service's fixed port.
    serve(app, port=10012, host="0.0.0.0")