| import argparse |
| import base64 |
| import io |
| import time |
| import torch |
| import uvicorn |
| import gc |
| import asyncio |
| from fastapi import FastAPI, HTTPException |
| from pydantic import BaseModel |
| from diffusers import FluxPipeline |
| from nunchaku import NunchakuFluxTransformer2dModel |
|
|
| |
# Command-line options. They are parsed at import time, so `args` is available
# to every function defined further down in this module.
parser = argparse.ArgumentParser(
    description="Flux Image Generation Server with Nunchaku",
)
parser.add_argument(
    "--host",
    type=str,
    default="0.0.0.0",
    help="Host to bind to",
)
parser.add_argument(
    "--port",
    type=int,
    default=8000,
    help="Port to bind to",
)
parser.add_argument(
    "--model",
    type=str,
    default="black-forest-labs/FLUX.1-dev",
    help="Path or Repo ID of the base model",
)
parser.add_argument(
    "--optimized-model",
    type=str,
    required=True,
    help="Path to the optimized Nunchaku model safetensors file",
)
args = parser.parse_args()


app = FastAPI()


# Shared generation pipeline; stays None until load_model() has populated it.
pipeline = None
# Held by the generation endpoint for the duration of each request, so only
# one request uses the pipeline at a time.
request_lock = asyncio.Lock()
|
|
def load_model():
    """Build the shared FLUX pipeline and publish it in the global ``pipeline``.

    The optimized transformer is loaded from ``args.optimized_model`` and
    plugged into the base pipeline loaded from ``args.model``. The global is
    assigned only after every configuration step has succeeded, so a failed
    load never leaves a half-configured pipeline behind.

    Raises:
        Exception: Any error raised while loading or configuring the models
            is logged and re-raised unchanged.
    """
    global pipeline

    print(f"Loading base model from {args.model}...")
    print(f"Loading optimized transformer from {args.optimized_model}...")

    try:
        transformer = NunchakuFluxTransformer2dModel.from_pretrained(args.optimized_model)

        # Deliberately no `.to("cuda")` here: enable_model_cpu_offload() below
        # manages device placement itself. Moving the whole pipeline to the
        # GPU first only raises peak VRAM usage at startup before the offload
        # hooks move everything back.
        loaded = FluxPipeline.from_pretrained(
            args.model,
            transformer=transformer,
            torch_dtype=torch.bfloat16,
        )

        loaded.transformer.set_attention_backend("flash")
        loaded.enable_model_cpu_offload()
        loaded.enable_vae_tiling()
        loaded.enable_vae_slicing()
    except Exception as e:
        print(f"Error loading model: {e}")
        # Bare raise keeps the original traceback intact.
        raise

    pipeline = loaded
    print("Model loaded successfully!")
|
|
def flush():
    """Run a garbage-collection pass, then empty the CUDA cache."""
    # Order matters: collect Python garbage first, then empty the cache.
    gc.collect()
    torch.cuda.empty_cache()
|
|
# Request body for POST /v1/images/generations.
# (Documented with comments rather than a docstring so the generated API
# schema stays exactly as it was.)
class ImageGenerationRequest(BaseModel):
    # Text prompt handed to the pipeline.
    prompt: str

    # Number of images to generate for the prompt.
    n: int = 1

    # Requested dimensions, written as "<width>x<height>".
    size: str = "1024x1024"

    # The remaining fields are declared here, but the handler in this file
    # does not read them.
    response_format: str = "b64_json"
    quality: str = "standard"
    style: str = "vivid"
|
|
@app.on_event("startup")
async def startup_event():
    """Load the model once when the application starts."""
    # load_model() is synchronous; the startup hook does not finish until it
    # returns (or raises).
    load_model()
|
|
@app.post("/v1/images/generations")
async def generate_image(request: ImageGenerationRequest):
    """Generate images for a prompt and return them as base64-encoded PNGs.

    Requests are processed one at a time. The ``size`` field is parsed as
    ``"<width>x<height>"`` and each dimension is rounded down to a multiple
    of 16; unusable sizes fall back to 1024x1024. The ``response_format``,
    ``quality`` and ``style`` fields are accepted but not used.

    Returns a JSON object with ``created`` (Unix timestamp) and ``data``
    (a list of ``{"b64_json": ...}`` entries).

    Errors: 400 if ``n`` is less than 1; 500 if the model is not loaded or
    generation fails.
    """
    if pipeline is None:
        raise HTTPException(status_code=500, detail="Model not loaded")

    # Reject nonsensical counts up front instead of letting them surface as
    # an opaque pipeline failure.
    if request.n < 1:
        raise HTTPException(status_code=400, detail="n must be at least 1")

    async with request_lock:
        print(f"Received request: {request.prompt}")

        width, height = _resolve_size(request.size)

        # Inference is synchronous and long-running. Running it in a worker
        # thread keeps the event loop free to accept other connections while
        # the lock still serializes use of the pipeline.
        # NOTE(review): if this task is cancelled mid-generation, the lock is
        # released while the worker thread is still running — confirm that is
        # acceptable for this deployment.
        loop = asyncio.get_running_loop()
        try:
            images = await loop.run_in_executor(
                None,
                _render_images,
                request.prompt,
                width,
                height,
                request.n,
            )
        except Exception as e:
            print(f"Error during generation: {e}")
            raise HTTPException(status_code=500, detail=str(e)) from e

    return {
        "created": int(time.time()),
        "data": images,
    }


# Fallback edge length used when the requested size is unusable.
_DEFAULT_SIDE = 1024
# Dimensions are rounded down to a multiple of this value.
# NOTE(review): presumably the FLUX pipeline needs sizes divisible by 16 — confirm.
_SIZE_MULTIPLE = 16


def _resolve_size(size: str) -> tuple:
    """Turn a ``"<width>x<height>"`` string into a usable ``(width, height)``.

    Falls back to the default square size when the string cannot be parsed
    or when either dimension is below one size multiple after rounding down
    (which includes zero and negative values).
    """
    fallback = (_DEFAULT_SIDE, _DEFAULT_SIDE)
    try:
        width, height = map(int, size.split("x"))
    except ValueError:
        return fallback

    # Round each dimension down to the nearest multiple.
    width -= width % _SIZE_MULTIPLE
    height -= height % _SIZE_MULTIPLE

    if width < _SIZE_MULTIPLE or height < _SIZE_MULTIPLE:
        return fallback
    return width, height


def _render_images(prompt: str, width: int, height: int, count: int) -> list:
    """Run the pipeline synchronously and return ``{"b64_json": ...}`` entries.

    Meant to be called from a worker thread. Always calls ``flush()`` before
    returning or raising.
    """
    try:
        generated_images = pipeline(
            prompt,
            height=height,
            width=width,
            num_inference_steps=4,
            guidance_scale=3.5,
            num_images_per_prompt=count,
        ).images

        encoded = []
        for image in generated_images:
            buffer = io.BytesIO()
            image.save(buffer, format="PNG")
            payload = base64.b64encode(buffer.getvalue()).decode("utf-8")
            encoded.append({"b64_json": payload})
        return encoded
    finally:
        flush()
|
|
if __name__ == "__main__":
    # Serve the app on the host/port given on the command line.
    uvicorn.run(
        app,
        host=args.host,
        port=args.port,
    )
|
|