Spaces:
Paused
Paused
import random
import threading

import numpy as np
import torch
import uvicorn
from diffusers import DiffusionPipeline
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse
from pydantic import BaseModel
# Runtime configuration, evaluated once at import time.
dtype = torch.bfloat16
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Load FLUX.1-schnell by its Hub id in bfloat16, then move the whole
# pipeline onto the selected device.
pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    torch_dtype=dtype,
)
pipe = pipe.to(device)

# Largest seed handed out: the upper bound of a signed 32-bit integer.
MAX_SEED = np.iinfo(np.int32).max
# Upper limit, in pixels, for a requested width or height.
MAX_IMAGE_SIZE = 2048
class InferenceRequest(BaseModel):
    """Request payload for ``infer``.

    Only types and defaults are declared here; range checks on width and
    height, and seed selection, happen inside ``infer``.
    """

    prompt: str  # text prompt forwarded to the pipeline
    seed: int = 42  # seed used when randomize_seed is False
    randomize_seed: bool = False  # True -> infer draws a seed in [0, MAX_SEED]
    width: int = 1024  # pixels; infer rejects values outside 256..MAX_IMAGE_SIZE
    height: int = 1024  # pixels; same bounds as width
    num_inference_steps: int = 4  # forwarded unchanged to the pipeline call
class InferenceResponse(BaseModel):
    """Result returned by ``infer``."""

    image: str  # base64 text of the PNG produced by image_to_base64
    seed: int  # seed actually used, so a randomized result can be reproduced
app = FastAPI()

# Smallest width/height, in pixels, accepted by ``infer``.
MIN_IMAGE_SIZE = 256

# Generation runs on worker threads (see ``infer``). A diffusers pipeline is
# not documented as thread-safe; its scheduler presumably keeps per-run
# state. So this lock allows only one request inside ``pipe(...)`` at a time.
_pipe_lock = threading.Lock()


def _generate_base64(request: InferenceRequest, seed: int) -> str:
    """Run the pipeline for *request* with *seed*; return a base64 PNG string.

    Blocking: call it from a worker thread, never directly on the event loop.
    """
    generator = torch.Generator().manual_seed(seed)
    with _pipe_lock:
        image = pipe(
            prompt=request.prompt,
            width=request.width,
            height=request.height,
            num_inference_steps=request.num_inference_steps,
            generator=generator,
            guidance_scale=0.0,
        ).images[0]
    # PNG encoding does not touch the shared pipeline, so do it unlocked.
    return image_to_base64(image)


# NOTE(review): nothing else in this file registered ``infer`` as a route.
# The "/infer" path matches the function name; confirm it matches clients.
@app.post("/infer", response_model=InferenceResponse)
async def infer(request: InferenceRequest) -> InferenceResponse:
    """Generate one image from *request* and return it with the seed used.

    Args:
        request: Prompt, image size, step count and seed options.

    Returns:
        InferenceResponse holding the base64-encoded PNG and the seed that was
        actually used (a fresh random one when ``randomize_seed`` is set).

    Raises:
        HTTPException: 400 when width or height lies outside
            [MIN_IMAGE_SIZE, MAX_IMAGE_SIZE], or num_inference_steps < 1.
    """
    sizes_ok = (
        MIN_IMAGE_SIZE <= request.width <= MAX_IMAGE_SIZE
        and MIN_IMAGE_SIZE <= request.height <= MAX_IMAGE_SIZE
    )
    if not sizes_ok:
        raise HTTPException(
            status_code=400,
            detail=(
                f"Width and height must be between {MIN_IMAGE_SIZE} "
                f"and {MAX_IMAGE_SIZE}."
            ),
        )
    if request.num_inference_steps < 1:
        raise HTTPException(
            status_code=400,
            detail="num_inference_steps must be at least 1.",
        )

    seed = random.randint(0, MAX_SEED) if request.randomize_seed else request.seed

    # The pipeline call blocks for the whole generation. Run inline in this
    # coroutine, it would stall every other request on the event loop, so
    # hand it to FastAPI's thread pool instead.
    image_base64 = await run_in_threadpool(_generate_base64, request, seed)
    return InferenceResponse(image=image_base64, seed=seed)
def image_to_base64(image):
    """Encode *image* as PNG and return those bytes as a base64 string.

    Args:
        image: An object with a PIL-style ``save(fp, format=...)`` method;
            presumably the ``PIL.Image.Image`` taken from the pipeline output.

    Returns:
        Base64 encoding of the PNG bytes, decoded to ``str``.
    """
    import base64
    import io

    # Encode entirely in memory; only ``image.save`` is needed, so PIL itself
    # does not have to be imported here.
    with io.BytesIO() as buffer:
        image.save(buffer, format="PNG")
        png_bytes = buffer.getvalue()
    return base64.b64encode(png_bytes).decode("utf-8")
if __name__ == "__main__":
    # Listen on every interface; 7860 is presumably the port the hosting
    # Space forwards traffic to.
    server_options = {"host": "0.0.0.0", "port": 7860}
    uvicorn.run(app, **server_options)