File size: 3,955 Bytes
1e103b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import argparse
import base64
import io
import time
import torch
import uvicorn
import gc
import asyncio
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from diffusers import FluxPipeline
from nunchaku import NunchakuFluxTransformer2dModel

# ---------------------------------------------------------------------------
# Command-line configuration. Parsed once at import time; the resulting
# `args` is read by load_model() and by the __main__ entry point.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(
    description="Flux Image Generation Server with Nunchaku",
)
parser.add_argument(
    "--host",
    type=str,
    default="0.0.0.0",
    help="Host to bind to",
)
parser.add_argument(
    "--port",
    type=int,
    default=8000,
    help="Port to bind to",
)
parser.add_argument(
    "--model",
    type=str,
    default="black-forest-labs/FLUX.1-dev",
    help="Path or Repo ID of the base model",
)
parser.add_argument(
    "--optimized-model",
    type=str,
    required=True,
    help="Path to the optimized Nunchaku model safetensors file",
)
args = parser.parse_args()

app = FastAPI()

# ---------------------------------------------------------------------------
# Process-wide state shared with the request handler.
# ---------------------------------------------------------------------------
pipeline = None  # assigned by load_model() during startup
# Held by the generation endpoint so only one generation runs at a time.
# NOTE(review): on Python < 3.10, asyncio.Lock() binds to the event loop that is
# current when it is constructed; creating it at import time (before uvicorn
# starts its own loop) can fail with "attached to a different loop" under
# contention. Consider creating it in the startup hook if older Pythons matter.
request_lock = asyncio.Lock()

def load_model():
    """Load the Nunchaku-optimized Flux transformer and build the pipeline.

    Reads ``args.model`` (base pipeline path or repo id) and
    ``args.optimized_model`` (Nunchaku transformer weights). On success, stores
    the fully configured FluxPipeline in the module-level ``pipeline`` global.

    Raises:
        Exception: anything raised by the loaders or configuration calls is
            printed and re-raised unchanged (original traceback preserved).
            ``pipeline`` is left untouched in that case.
    """
    global pipeline

    print(f"Loading base model from {args.model}...")
    print(f"Loading optimized transformer from {args.optimized_model}...")

    try:
        # Load the optimized transformer.
        transformer = NunchakuFluxTransformer2dModel.from_pretrained(args.optimized_model)

        # Build the pipeline WITHOUT calling .to("cuda"): enable_model_cpu_offload()
        # below takes over device placement (it first moves the pipeline to CPU
        # and then moves each sub-model to the GPU only while it runs). Moving
        # everything to the GPU beforehand only spikes peak VRAM with all
        # components at once and can OOM smaller cards.
        pipe = FluxPipeline.from_pretrained(
            args.model,
            transformer=transformer,
            torch_dtype=torch.bfloat16,
        )

        # NOTE(review): the "flash" backend presumably needs the flash-attn
        # package and a diffusers version exposing set_attention_backend; verify.
        pipe.transformer.set_attention_backend("flash")
        pipe.enable_model_cpu_offload()
        pipe.enable_vae_tiling()
        pipe.enable_vae_slicing()
    except Exception as e:
        print(f"Error loading model: {e}")
        raise

    # Publish only a completely configured pipeline, so a failure part-way
    # through setup never leaves a half-initialised global behind.
    pipeline = pipe
    print("Model loaded successfully!")

def flush():
    """Reclaim memory after a generation.

    Runs a full garbage-collection pass first, so objects that held tensors
    become unreachable. It then asks PyTorch's CUDA caching allocator to release
    its unused cached blocks. Returns nothing.
    """
    for release_step in (gc.collect, torch.cuda.empty_cache):
        release_step()

class ImageGenerationRequest(BaseModel):
    """Request body for ``POST /v1/images/generations``.

    The field names and defaults appear to follow an OpenAI-style images API.
    The generation endpoint in this file reads only ``prompt``, ``n`` and
    ``size``. ``response_format``, ``quality`` and ``style`` are accepted but
    never consulted.
    """
    prompt: str  # text prompt handed to the Flux pipeline
    n: int = 1  # number of images; forwarded as num_images_per_prompt
    size: str = "1024x1024"  # "<width>x<height>"; parsed and rounded by the endpoint
    response_format: str = "b64_json"  # accepted but not read by this server
    quality: str = "standard"  # accepted but not read by this server
    style: str = "vivid"  # accepted but not read by this server

@app.on_event("startup")
async def startup_event():
    """Load the model once when the application starts.

    load_model() is synchronous, so it runs directly on the event loop here.
    Any exception it raises propagates out of this startup hook.
    """
    return load_model()

@app.post("/v1/images/generations")
async def generate_image(request: ImageGenerationRequest):
    """Generate ``request.n`` images for ``request.prompt``.

    ``request.size`` is parsed as ``"<width>x<height>"`` (the ``x`` separator is
    case-insensitive). An unparseable value falls back to 1024x1024. Both
    dimensions are then rounded down to a multiple of 16.

    Returns:
        ``{"created": <unix seconds>, "data": [{"b64_json": <base64 PNG>}, ...]}``

    Raises:
        HTTPException(500): the pipeline is not loaded, or generation/encoding
            failed. ``detail`` carries the underlying error message.
        HTTPException(400): the requested size rounds down to a zero or
            negative dimension.
    """
    if pipeline is None:
        raise HTTPException(status_code=500, detail="Model not loaded")

    print(f"Received request: {request.prompt}")

    # Parse "WIDTHxHEIGHT"; anything malformed falls back to the default size.
    try:
        width, height = map(int, request.size.lower().split("x"))
    except ValueError:
        width, height = 1024, 1024

    # Flux expects dimensions divisible by 16, so round both down.
    width = (width // 16) * 16
    height = (height // 16) * 16
    if width <= 0 or height <= 0:
        # Reject up front instead of letting the pipeline fail with an opaque 500.
        raise HTTPException(
            status_code=400,
            detail=f"Invalid size {request.size!r}: both dimensions must be at least 16",
        )

    def _generate_and_encode():
        # Synchronous, long-running GPU work. It is executed in a worker thread
        # (see below) so it does not stall the event loop, and with it every
        # other connection, for the whole generation.
        try:
            generated_images = pipeline(
                request.prompt,
                height=height,
                width=width,
                # NOTE(review): 4 steps is the usual setting for FLUX.1-schnell.
                # FLUX.1-dev (the --model default) is normally run with many more
                # steps. Confirm this is intentional for the chosen checkpoint.
                num_inference_steps=4,
                guidance_scale=3.5,  # value used by the Nunchaku example
                num_images_per_prompt=request.n,
            ).images

            encoded = []
            for image in generated_images:
                buffered = io.BytesIO()
                image.save(buffered, format="PNG")
                img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
                encoded.append({"b64_json": img_str})
            return encoded
        finally:
            # Release memory after every attempt, successful or not.
            flush()

    # The lock keeps generations strictly one at a time; the await lets the event
    # loop keep serving other connections meanwhile. Caveat: if this coroutine is
    # cancelled, the worker thread keeps running but the lock is released.
    async with request_lock:
        try:
            images = await asyncio.get_running_loop().run_in_executor(
                None, _generate_and_encode
            )
        except Exception as e:
            print(f"Error during generation: {e}")
            raise HTTPException(status_code=500, detail=str(e)) from e

    return {
        "created": int(time.time()),
        "data": images,
    }

if __name__ == "__main__":
    # Serve the FastAPI app in-process with uvicorn. The bind address comes
    # from the --host / --port command-line options.
    uvicorn.run(app, port=args.port, host=args.host)