import base64
import io
import os
import threading
import traceback
from typing import List

import torch
from diffusers import Flux2KleinPipeline
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
from pydantic import BaseModel, Field
|
|
app = FastAPI()

# Fully permissive CORS: every origin, HTTP method and request header is
# allowed. allow_credentials is not set here, so the middleware default applies.
_cors_settings = {
    "allow_origins": ["*"],
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_settings)
|
|
# Checkpoint locations. They can be overridden through the environment so the
# same server can host another base model or LoRA without code edits; the
# defaults are the previously hard-coded values. The default LoRA path is
# relative to the process working directory.
MODEL_ID = os.environ.get(
    "FLUX_MODEL_ID", "black-forest-labs/FLUX.2-klein-base-4B"
)
LORA_PATH = os.environ.get(
    "FLUX_LORA_PATH", "./sking_v73_flux_4b_000027000.safetensors"
)

# Load once at import time: the shared pipeline is reused by every request.
print("Loading model...")
pipe = Flux2KleinPipeline.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
)
pipe.to("cuda")
pipe.load_lora_weights(LORA_PATH)
print("Model loaded.")
|
|
class GenerateResponse(BaseModel):
    """Response body returned by ``POST /api/generate``."""

    # Required: one base64-encoded PNG string per generated image.
    images_base64: List[str] = Field(...)
|
|
# Serializes access to the single shared pipeline. Inference runs in a worker
# thread (so it no longer blocks the event loop), and this lock keeps GPU
# generations one at a time, as they effectively were before.
_PIPE_LOCK = threading.Lock()


def _run_pipeline(image, prompt, guidance, seed, n_step):
    """Run one blocking 768x768 generation and return the pipeline's images."""
    with _PIPE_LOCK:
        pipeline_output = pipe(
            image=image,
            prompt=prompt,
            height=768,
            width=768,
            num_inference_steps=n_step,
            guidance_scale=guidance,
            num_images_per_prompt=1,
            # A fresh generator seeded per request, on the same device as the pipeline.
            generator=torch.Generator("cuda").manual_seed(seed),
        )
    return pipeline_output.images


def _encode_png_base64(img):
    """Return *img* serialized as PNG and base64-encoded to a str."""
    buffered = io.BytesIO()
    img.save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode()


@app.post("/api/generate", response_model=GenerateResponse)
async def generate(
    prompt: str = Form(...),
    guidance: float = Form(4.0),
    seed: int = Form(42),
    n_step: int = Form(100),
    file: UploadFile = File(...),
):
    """Generate a 768x768 image conditioned on an uploaded image and a prompt.

    Form fields:
        prompt: Text prompt forwarded to the pipeline.
        guidance: Guidance scale.
        seed: Seed for the per-request CUDA generator.
        n_step: Number of inference steps.
        file: Conditioning image upload (any format PIL can open).

    Returns:
        ``{"images_base64": [...]}`` with one base64 PNG per generated image.

    Raises:
        HTTPException(400): the upload cannot be decoded as an image.
        HTTPException(500): the pipeline or PNG encoding failed.
    """
    content = await file.read()
    try:
        # NOTE(review): RGBA is kept from the original; confirm the pipeline
        # handles an alpha channel (RGB may be the safer input).
        image = Image.open(io.BytesIO(content)).convert("RGBA")
    except (OSError, ValueError, Image.DecompressionBombError) as e:
        # A bad upload is a client error, not a server failure.
        raise HTTPException(status_code=400, detail=f"Invalid image: {e}") from e

    try:
        # The pipeline call blocks for the whole generation; running it in the
        # threadpool keeps the event loop free for other requests.
        outputs = await run_in_threadpool(
            _run_pipeline, image, prompt, guidance, seed, n_step
        )
        b64_list = [_encode_png_base64(out) for out in outputs]
    except Exception as e:
        traceback.print_exc()
        print(f"Error: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {"images_base64": b64_list}
|
|
def _serve():
    """Serve ``app`` with uvicorn on all interfaces (0.0.0.0), port 10012."""
    # Imported lazily so uvicorn is only needed when run as a script.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=10012)


if __name__ == "__main__":
    _serve()
|
|