Spaces:
Sleeping
Sleeping
| import io | |
| import os | |
| import base64 | |
| import asyncio | |
| import random | |
| from concurrent.futures import ThreadPoolExecutor | |
| from fastapi import FastAPI, Request | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import HTMLResponse, JSONResponse | |
| from PIL import Image | |
| import torch | |
| from diffusers import DiffusionPipeline | |
# -------------------------------------------------------------
# HuggingFace Token
# -------------------------------------------------------------
# Optional Hugging Face access token read from the environment; None when unset.
HF_TOKEN = os.getenv("HF_TOKEN")
# -------------------------------------------------------------
# Model Settings
# -------------------------------------------------------------
MODEL_REPO = "stabilityai/sdxl-turbo"
device = "cuda" if torch.cuda.is_available() else "cpu"
# float16 weights on GPU, float32 on CPU.
dtype = torch.float16 if device == "cuda" else torch.float32
print(f"Loading {MODEL_REPO} on {device}...")
pipe = DiffusionPipeline.from_pretrained(
    MODEL_REPO,
    torch_dtype=dtype,
    use_safetensors=True,
    # An empty HF_TOKEN is treated exactly like an unset one.
    token=HF_TOKEN or None,
)
pipe.to(device)
if device == "cpu":
    # Best-effort: failing to enable offload must not stop the server starting.
    # NOTE(review): model CPU offload is designed to move submodules onto an
    # accelerator on demand; on a CPU-only host it is unlikely to help —
    # consider removing this branch.
    try:
        pipe.enable_model_cpu_offload()
    except Exception as exc:  # not a bare except: let Ctrl-C / SystemExit through
        print(f"Model CPU offload not enabled: {exc!r}")
print("Model ready.")
# -------------------------------------------------------------
# Automatic Negative Prompt (backend only)
# -------------------------------------------------------------
# Fixed negative prompt sent with every generation; clients cannot change it.
# NOTE(review): diffusers only applies negative_prompt when guidance_scale > 1;
# the API endpoint uses guidance 0.0, so this presumably has no effect — confirm.
AUTO_NEGATIVE_PROMPT = ", ".join((
    "low quality",
    "worst quality",
    "blurry",
    "pixelated",
    "jpeg artifacts",
    "deformed",
    "distorted",
    "bad anatomy",
    "extra fingers",
    "extra limbs",
    "missing fingers",
    "watermark",
    "text",
    "logo",
))
# -------------------------------------------------------------
# Core Generation Function
# -------------------------------------------------------------
def generate_image(prompt, seed, width, height, steps, guidance):
    """Render one image with the module-level ``pipe`` and return it.

    This is a blocking call; ``run_generate`` dispatches it to a worker thread.
    NOTE(review): ``pipe`` is a single shared object, so concurrent calls are
    presumably unsafe — keep callers serialized.

    Args:
        prompt: Text prompt passed to the pipeline.
        seed: Seed for a fresh ``torch.Generator`` created on ``device``.
        width: Output width in pixels.
        height: Output height in pixels.
        steps: Forwarded as ``num_inference_steps``.
        guidance: Forwarded as ``guidance_scale``.

    Returns:
        The first entry of the pipeline result's ``images``.
    """
    rng = torch.Generator(device=device)
    rng.manual_seed(seed)
    pipeline_args = {
        "prompt": prompt,
        "negative_prompt": AUTO_NEGATIVE_PROMPT,
        "guidance_scale": guidance,
        "num_inference_steps": steps,
        "width": width,
        "height": height,
        "generator": rng,
    }
    output = pipe(**pipeline_args)
    return output.images[0]
# -------------------------------------------------------------
# Async Queue
# -------------------------------------------------------------
# A diffusers pipeline is not thread-safe: each call mutates shared state on
# the pipeline and its scheduler (timesteps / step index). All requests share
# the single module-level ``pipe``, so generations must be serialized. One
# worker thread does the work, and the semaphore bounds how many requests
# are handed to it at a time; the others wait on the event loop.
MAX_CONCURRENT_GENERATIONS = 1
executor = ThreadPoolExecutor(max_workers=MAX_CONCURRENT_GENERATIONS)
semaphore = asyncio.Semaphore(MAX_CONCURRENT_GENERATIONS)


async def run_generate(prompt, seed, width, height, steps, guidance):
    """Run ``generate_image`` off the event loop and return its image.

    Waits for a free slot on ``semaphore``, then executes the blocking
    generation on ``executor`` so the event loop keeps serving requests.
    Any exception raised by ``generate_image`` propagates to the caller.
    """
    async with semaphore:
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            executor,
            generate_image,
            prompt,
            seed,
            width,
            height,
            steps,
            guidance,
        )
# -------------------------------------------------------------
# FastAPI App
# -------------------------------------------------------------
app = FastAPI(version="2.0", title="SDXL Turbo Generator")

# Permissive CORS: any origin, method and header is accepted.
# NOTE(review): wildcard origins combined with allow_credentials=True makes
# Starlette echo the caller's Origin on credentialed requests — confirm this
# is intended, since the bundled UI is served from the same origin.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
# -------------------------------------------------------------
# UI
# -------------------------------------------------------------
@app.get("/", response_class=HTMLResponse)
def home():
    """Serve the single-page UI that posts a prompt to ``/api/generate``.

    The page disables the Generate button while a request is in flight. It
    reports network failures and non-JSON responses (e.g. a proxy error
    page) in the status line instead of leaving it stuck on "Generating...".
    """
    return """
<!doctype html>
<html>
<head>
  <meta charset="utf-8"/>
  <meta name="viewport" content="width=device-width, initial-scale=1"/>
  <title>SDXL Turbo</title>
  <style>
    body {
      font-family: Arial, sans-serif;
      max-width: 900px;
      margin: 30px auto;
    }
    textarea {
      width: 100%;
      padding: 12px;
      margin-bottom: 10px;
      font-size: 15px;
    }
    button {
      padding: 12px 18px;
      background: black;
      color: white;
      border: none;
      cursor: pointer;
      font-size: 15px;
    }
    button:disabled {
      opacity: 0.5;
      cursor: default;
    }
    #status {
      margin-top: 12px;
    }
    #output {
      margin-top: 20px;
      width: 100%;
      height: 432px;
      border: 1px solid #ddd;
      border-radius: 10px;
      display: flex;
      align-items: center;
      justify-content: center;
      background: #fafafa;
    }
    #output img {
      max-width: 100%;
      max-height: 100%;
      border-radius: 8px;
    }
  </style>
</head>
<body>
  <h1>SDXL Turbo</h1>
  <textarea id="prompt" placeholder="Enter prompt"></textarea>
  <button id="generate" onclick="send()">Generate</button>
  <div id="status"></div>
  <div id="output">
    <span id="placeholder">Image will appear here</span>
    <img id="result" style="display:none;" />
  </div>
  <script>
    async function send() {
      const prompt = document.getElementById("prompt").value;
      const button = document.getElementById("generate");
      const status = document.getElementById("status");
      const img = document.getElementById("result");
      const placeholder = document.getElementById("placeholder");

      status.innerText = "Generating...";
      img.style.display = "none";
      placeholder.style.display = "block";
      button.disabled = true;

      try {
        const res = await fetch("/api/generate", {
          method: "POST",
          headers: {"Content-Type": "application/json"},
          body: JSON.stringify({ prompt })
        });
        let data;
        try {
          data = await res.json();
        } catch (parseErr) {
          status.innerText = "Error: server returned HTTP " + res.status;
          return;
        }
        if (data.status !== "success") {
          status.innerText = "Error: " + (data.message || ("HTTP " + res.status));
          return;
        }
        img.src = "data:image/png;base64," + data.image_base64;
        img.style.display = "block";
        placeholder.style.display = "none";
        status.innerText = "Done (seed " + data.seed + ")";
      } catch (err) {
        status.innerText = "Error: " + err.message;
      } finally {
        button.disabled = false;
      }
    }
  </script>
</body>
</html>
"""
# -------------------------------------------------------------
# API Endpoint
# -------------------------------------------------------------
@app.post("/api/generate")
async def api_generate(request: Request):
    """Generate one image for the JSON body ``{"prompt": "..."}``.

    Size, step count and guidance are fixed server-side (768x432, 2 steps,
    guidance 0.0). A random seed is drawn per request and returned.

    Returns:
        JSONResponse. On success: ``status="success"`` plus ``image_base64``
        (base64-encoded PNG), ``seed``, ``width`` and ``height``.
        400 with ``status="error"`` when the body is not a JSON object, the
        prompt is present but not a string, or the prompt is empty/blank.
        500 with the exception text when generation or encoding fails.
    """
    try:
        body = await request.json()
    except ValueError:
        # json.JSONDecodeError and UnicodeDecodeError both subclass ValueError.
        return JSONResponse({"status": "error", "message": "Invalid JSON"}, 400)
    if not isinstance(body, dict):
        # e.g. a JSON list or string: there is no "prompt" field to read.
        return JSONResponse({"status": "error", "message": "Invalid JSON"}, 400)

    raw_prompt = body.get("prompt")
    if raw_prompt is not None and not isinstance(raw_prompt, str):
        return JSONResponse(
            {"status": "error", "message": "Prompt must be a string"}, 400
        )
    # A missing or null prompt is reported the same way as an empty one.
    prompt = (raw_prompt or "").strip()
    if not prompt:
        return JSONResponse({"status": "error", "message": "Prompt required"}, 400)

    width = 768
    height = 432
    steps = 2
    guidance = 0.0
    # Not security-sensitive; the seed is returned to the client with the image.
    seed = random.randint(0, 2**31 - 1)

    try:
        img = await run_generate(prompt, seed, width, height, steps, guidance)
        buf = io.BytesIO()
        img.save(buf, format="PNG")
        b64 = base64.b64encode(buf.getvalue()).decode()
    except Exception as e:
        # Request boundary: surface the failure to the client and keep a
        # server-side trace instead of swallowing it silently.
        print(f"Generation failed: {e!r}")
        return JSONResponse({"status": "error", "message": str(e)}, 500)

    return JSONResponse({
        "status": "success",
        "image_base64": b64,
        "seed": seed,
        "width": width,
        "height": height,
    })
# -------------------------------------------------------------
# Local run
# -------------------------------------------------------------
if __name__ == "__main__":
    # Direct-execution entry point: serve the app with uvicorn on all
    # interfaces, port 7860.
    import uvicorn

    server_options = {"host": "0.0.0.0", "port": 7860}
    uvicorn.run(app, **server_options)