EntropyDrop commited on
Commit
98b76cc
·
1 Parent(s): 392a048
Files changed (1) hide show
  1. infer.py +71 -0
infer.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import traceback
3
+ import os
4
+ from PIL import Image
5
+ import io
6
+ import base64
7
+ from fastapi import FastAPI, HTTPException, File, Form, UploadFile
8
+ from pydantic import BaseModel, Field
9
+ from diffusers import Flux2KleinPipeline
10
+ from fastapi.middleware.cors import CORSMiddleware
11
+ from typing import List
12
+
13
+ app = FastAPI()
14
+
15
+ app.add_middleware(
16
+ CORSMiddleware,
17
+ allow_origins=["*"],
18
+ allow_methods=["*"],
19
+ allow_headers=["*"],
20
+ )
21
+
22
+ print("Loading model...")
23
+ pipe = Flux2KleinPipeline.from_pretrained(
24
+ "black-forest-labs/FLUX.2-klein-base-4B",
25
+ torch_dtype=torch.bfloat16,
26
+ )
27
+ pipe.to("cuda")
28
+ pipe.load_lora_weights("./sking_v73_flux_4b_000027000.safetensors")
29
+
30
+ print("Model loaded.")
31
+
32
+ class GenerateResponse(BaseModel):
33
+ images_base64: List[str]
34
+
35
+ @app.post("/api/generate", response_model=GenerateResponse)
36
+ async def generate(prompt: str=Form(...), guidance: float=Form(4.0), seed: int=Form(42), n_step: int = Form(100),file=File(...)):
37
+ images = []
38
+ content = await file.read()
39
+ img = Image.open(io.BytesIO(content)).convert("RGBA")
40
+ images.append(img)
41
+ try:
42
+ pipeline_output = pipe(
43
+ image=images[0],
44
+ prompt="",
45
+ height=768,
46
+ width=768,
47
+ num_inference_steps=n_step,
48
+ guidance_scale=guidance,
49
+ num_images_per_prompt=1,
50
+ generator=torch.Generator("cuda").manual_seed(seed)
51
+ )
52
+
53
+ images = pipeline_output.images
54
+
55
+ b64_list = []
56
+ for img in images:
57
+ buffered = io.BytesIO()
58
+ img.save(buffered, format="PNG")
59
+ img_str = base64.b64encode(buffered.getvalue()).decode()
60
+ b64_list.append(img_str)
61
+
62
+ return {"images_base64": b64_list}
63
+
64
+ except Exception as e:
65
+ traceback.print_exc()
66
+ print(f"Error: {e}")
67
+ raise HTTPException(status_code=500, detail=str(e))
68
+
69
+ if __name__ == "__main__":
70
+ import uvicorn
71
+ uvicorn.run(app, host="0.0.0.0", port=10012)