# text2img / app.py — SDXL-Turbo text-to-image service
# (Hugging Face Space by whitepeacock; revision f96382c verified, "Update app.py")
import io
import os
import base64
import asyncio
import random
from concurrent.futures import ThreadPoolExecutor
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse, JSONResponse
from PIL import Image
import torch
from diffusers import DiffusionPipeline
# -------------------------------------------------------------
# HuggingFace Token
# -------------------------------------------------------------
HF_TOKEN = os.getenv("HF_TOKEN")  # optional; only needed for gated/private repos

# -------------------------------------------------------------
# Model Settings
# -------------------------------------------------------------
MODEL_REPO = "stabilityai/sdxl-turbo"
device = "cuda" if torch.cuda.is_available() else "cpu"
# fp16 only on GPU: half-precision inference on CPU is unsupported/slow
# for most ops.  (Condition tied to `device` so the two can't diverge.)
dtype = torch.float16 if device == "cuda" else torch.float32
print(f"Loading {MODEL_REPO} on {device}...")
pipe = DiffusionPipeline.from_pretrained(
    MODEL_REPO,
    torch_dtype=dtype,
    use_safetensors=True,
    token=HF_TOKEN if HF_TOKEN else None,
)
pipe.to(device)
if device == "cpu":
    # Best-effort memory optimization; requires `accelerate`, so failure is
    # non-fatal.  NOTE(review): enable_model_cpu_offload() is intended for
    # CUDA machines that offload idle submodules to CPU — on a pure-CPU box
    # it is a no-op at best; confirm this call is actually wanted here.
    try:
        pipe.enable_model_cpu_offload()
    except Exception:
        # Was a bare `except:` — never swallow SystemExit/KeyboardInterrupt.
        pass
print("Model ready.")

# -------------------------------------------------------------
# Automatic Negative Prompt (backend only)
# -------------------------------------------------------------
# Appended server-side to every generation; clients never see or control it.
AUTO_NEGATIVE_PROMPT = (
    "low quality, worst quality, blurry, pixelated, jpeg artifacts, "
    "deformed, distorted, bad anatomy, extra fingers, extra limbs, "
    "missing fingers, watermark, text, logo"
)
# -------------------------------------------------------------
# Core Generation Function
# -------------------------------------------------------------
def generate_image(prompt, seed, width, height, steps, guidance):
    """Run one synchronous diffusion pass and return the resulting PIL image.

    The backend-wide AUTO_NEGATIVE_PROMPT is always applied, and the seed
    is routed through a device-local torch.Generator so identical inputs
    reproduce the same image.
    """
    rng = torch.Generator(device=device).manual_seed(seed)
    output = pipe(
        prompt=prompt,
        negative_prompt=AUTO_NEGATIVE_PROMPT,
        width=width,
        height=height,
        num_inference_steps=steps,
        guidance_scale=guidance,
        generator=rng,
    )
    return output.images[0]
# -------------------------------------------------------------
# Async Queue
# -------------------------------------------------------------
# At most 2 generations in flight: the semaphore bounds admission, the
# thread pool keeps the blocking pipeline call off the event loop.
executor = ThreadPoolExecutor(max_workers=2)
semaphore = asyncio.Semaphore(2)


async def run_generate(prompt, seed, width, height, steps, guidance):
    """Await one image generation, serialized through the shared queue."""
    job = (prompt, seed, width, height, steps, guidance)
    async with semaphore:
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(executor, generate_image, *job)
# -------------------------------------------------------------
# FastAPI App
# -------------------------------------------------------------
app = FastAPI(title="SDXL Turbo Generator", version="2.0")

# Open CORS for the public demo API.  The CORS spec forbids combining
# credentialed requests with a wildcard origin (browsers reject it), and
# this API uses no cookies or auth, so credentials are disabled instead of
# the original allow_credentials=True.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# -------------------------------------------------------------
# UI
# -------------------------------------------------------------
@app.get("/", response_class=HTMLResponse)
def home():
    """Serve the single-page UI: a prompt box whose JS POSTs to /api/generate.

    The inline script expects the endpoint's JSON to carry ``status``,
    ``message`` (on error), and ``image_base64`` / ``seed`` (on success).
    NOTE(review): the fetch call has no try/catch, so a network failure
    leaves the status stuck at "Generating..." — consider hardening.
    """
    return """
<!doctype html>
<html>
<head>
<meta charset="utf-8"/>
<title>SDXL Turbo</title>
<style>
body {
font-family: Arial;
max-width: 900px;
margin: 30px auto;
}
textarea {
width: 100%;
padding: 12px;
margin-bottom: 10px;
font-size: 15px;
}
button {
padding: 12px 18px;
background: black;
color: white;
border: none;
cursor: pointer;
font-size: 15px;
}
#status {
margin-top: 12px;
}
#output {
margin-top: 20px;
width: 100%;
height: 432px;
border: 1px solid #ddd;
border-radius: 10px;
display: flex;
align-items: center;
justify-content: center;
background: #fafafa;
}
#output img {
max-width: 100%;
max-height: 100%;
border-radius: 8px;
}
</style>
</head>
<body>
<h1>SDXL Turbo</h1>
<textarea id="prompt" placeholder="Enter prompt"></textarea>
<button onclick="send()">Generate</button>
<div id="status"></div>
<div id="output">
<span id="placeholder">Image will appear here</span>
<img id="result" style="display:none;" />
</div>
<script>
async function send() {
const prompt = document.getElementById("prompt").value;
const status = document.getElementById("status");
const img = document.getElementById("result");
const placeholder = document.getElementById("placeholder");
status.innerText = "Generating...";
img.style.display = "none";
placeholder.style.display = "block";
const res = await fetch("/api/generate", {
method: "POST",
headers: {"Content-Type": "application/json"},
body: JSON.stringify({ prompt })
});
const data = await res.json();
if (data.status !== "success") {
status.innerText = "Error: " + data.message;
return;
}
img.src = "data:image/png;base64," + data.image_base64;
img.style.display = "block";
placeholder.style.display = "none";
status.innerText = "Done (seed " + data.seed + ")";
}
</script>
</body>
</html>
"""
# -------------------------------------------------------------
# API Endpoint
# -------------------------------------------------------------
@app.post("/api/generate")
async def api_generate(request: Request):
    """Generate one image from a JSON body of the form {"prompt": str}.

    Returns ``{"status": "success", "image_base64", "seed", "width",
    "height"}`` on success, or ``{"status": "error", "message"}`` with
    400 (bad input) / 500 (pipeline failure).  Size, steps and guidance
    are fixed server-side; the random seed is echoed back so a result
    can be referenced.
    """
    try:
        body = await request.json()
    except Exception:  # was a bare except: — only malformed bodies belong here
        return JSONResponse(
            {"status": "error", "message": "Invalid JSON"}, status_code=400
        )

    # Tolerate a non-dict body or non-string "prompt" explicitly instead of
    # letting .get()/.strip() raise and masquerade as "Invalid JSON".
    prompt = body.get("prompt", "") if isinstance(body, dict) else ""
    if not isinstance(prompt, str) or not prompt.strip():
        return JSONResponse(
            {"status": "error", "message": "Prompt required"}, status_code=400
        )
    prompt = prompt.strip()

    # Fixed generation settings (SDXL-Turbo style: few steps, no guidance).
    width = 768
    height = 432
    steps = 2
    guidance = 0.0
    seed = random.randint(0, 2**31 - 1)

    try:
        img = await run_generate(prompt, seed, width, height, steps, guidance)
        # Encode the PIL image as base64 PNG for the JSON payload.
        buf = io.BytesIO()
        img.save(buf, format="PNG")
        b64 = base64.b64encode(buf.getvalue()).decode()
        return JSONResponse({
            "status": "success",
            "image_base64": b64,
            "seed": seed,
            "width": width,
            "height": height,
        })
    except Exception as e:
        # NOTE(review): str(e) may expose internal details; acceptable for a
        # demo but consider a generic message in production.
        return JSONResponse({"status": "error", "message": str(e)}, status_code=500)
# -------------------------------------------------------------
# Local run
# -------------------------------------------------------------
if __name__ == "__main__":
    # Port 7860 is the conventional Hugging Face Spaces port; bind all
    # interfaces so a container's mapped port is reachable.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)