File size: 1,986 Bytes
98b76cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import torch
import traceback
import os
from PIL import Image
import io
import base64
from fastapi import FastAPI, HTTPException, File, Form, UploadFile
from pydantic import BaseModel, Field
from diffusers import Flux2KleinPipeline
from fastapi.middleware.cors import CORSMiddleware
from typing import List

app = FastAPI()

# Accept cross-origin requests from any origin, with any method and header.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Base model id and LoRA weights file. Both can be overridden through the
# environment so the server can be pointed at other weights (or started from a
# different working directory) without editing this file; the defaults are the
# values this server has always used.
MODEL_ID = os.environ.get("FLUX_MODEL_ID", "black-forest-labs/FLUX.2-klein-base-4B")
LORA_PATH = os.environ.get(
    "FLUX_LORA_PATH", "./sking_v73_flux_4b_000027000.safetensors"
)

# Loaded once at import time and shared by every request handler.
print("Loading model...")
pipe = Flux2KleinPipeline.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
)
pipe.to("cuda")
pipe.load_lora_weights(LORA_PATH)

print("Model loaded.")

class GenerateResponse(BaseModel):
    # Response body of the generate endpoint: one base64-encoded image string
    # per image returned by the pipeline.
    images_base64: List[str]

def _encode_png_base64(image: Image.Image) -> str:
    """Serialize a PIL image as PNG and return it as a base64 ASCII string."""
    buffered = io.BytesIO()
    image.save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode()


@app.post("/api/generate", response_model=GenerateResponse)
async def generate(
    prompt: str = Form(...),
    guidance: float = Form(4.0),
    seed: int = Form(42),
    n_step: int = Form(100, ge=1),
    file: UploadFile = File(...),
):
    """Run the LoRA-adapted pipeline on one uploaded image.

    Form fields:
        prompt: Text prompt forwarded to the pipeline.
        guidance: Passed to the pipeline as ``guidance_scale``.
        seed: Seeds a CUDA ``torch.Generator`` so output is reproducible.
        n_step: Passed as ``num_inference_steps``; must be >= 1 (FastAPI
            rejects smaller values with a 422 before this function runs).
        file: The input image, decoded with Pillow and converted to RGBA.

    Returns:
        ``{"images_base64": [...]}``: each pipeline output image encoded as
        PNG and then base64, 768x768 as requested from the pipeline.

    Raises:
        HTTPException(400): the upload is empty or cannot be decoded as an image.
        HTTPException(500): the pipeline call or PNG encoding failed.
    """
    content = await file.read()
    if not content:
        raise HTTPException(status_code=400, detail="Uploaded file is empty.")

    # Decode before touching the GPU so bad uploads are reported as client
    # errors rather than as internal server errors.
    try:
        # NOTE(review): RGBA is kept from the original code; confirm the
        # pipeline's image preprocessing handles (or drops) the alpha channel.
        init_image = Image.open(io.BytesIO(content)).convert("RGBA")
    except (OSError, ValueError, Image.DecompressionBombError) as e:
        raise HTTPException(
            status_code=400, detail=f"Could not decode uploaded image: {e}"
        ) from e

    try:
        # NOTE(review): `pipe(...)` is synchronous, so it blocks the event loop
        # for the whole generation; as a side effect requests are processed one
        # at a time. If this is moved to a worker thread, serialize access to
        # the shared `pipe` explicitly.
        pipeline_output = pipe(
            image=init_image,
            prompt=prompt,
            height=768,
            width=768,
            num_inference_steps=n_step,
            guidance_scale=guidance,
            num_images_per_prompt=1,
            generator=torch.Generator("cuda").manual_seed(seed),
        )
        b64_list = [_encode_png_base64(out) for out in pipeline_output.images]
    except Exception as e:
        # Top-level request boundary: log the full traceback server-side and
        # report the failure to the client as a 500.
        traceback.print_exc()
        print(f"Error: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {"images_base64": b64_list}

if __name__ == "__main__":
    # uvicorn is imported lazily: it is only needed when this module is run
    # directly as a script, serving `app` on all interfaces at port 10012.
    from uvicorn import run as serve

    serve(app, port=10012, host="0.0.0.0")