# Provenance (copied from the Hugging Face Space file view):
# author: stpete2 — commit 61a06ba "Update app.py" (verified)
import base64
import io
import logging
import os
import threading
from typing import Optional

import torch
import uvicorn
from diffusers import StableDiffusionPipeline
from fastapi import FastAPI, Response
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field
# Redirect Hugging Face cache directories to writable locations under /tmp.
# NOTE(review): `diffusers` (and, through it, huggingface_hub) is imported
# above this point; if those libraries resolve HF_HOME / cache paths at import
# time, these assignments will not change their defaults — consider moving
# this block above the imports. The model load below passes `cache_dir`
# explicitly, so the pipeline weights are unaffected either way.
os.environ.update(
    {
        "TRANSFORMERS_CACHE": "/tmp/transformers_cache",
        "HF_HOME": "/tmp/hf_home",
        "HF_DATASETS_CACHE": "/tmp/hf_datasets_cache",
    }
)
# Initialize FastAPI app
app = FastAPI(title="Stable Diffusion API")
# Define input model
class TextToImageRequest(BaseModel):
    """JSON body accepted by ``POST /generate``.

    Each field is forwarded to the Stable Diffusion pipeline under the same
    name. Numeric fields are range-checked here so that invalid values are
    rejected with a 422 validation error, instead of failing inside the
    pipeline and surfacing as a generic 500.
    """

    # Text describing the desired image.
    prompt: str
    # Optional text describing what the image should avoid.
    negative_prompt: Optional[str] = None
    # Number of denoising steps; zero or negative values are meaningless.
    num_inference_steps: Optional[int] = Field(50, ge=1)
    # Guidance strength, passed straight through to the pipeline.
    guidance_scale: Optional[float] = 7.5
    # Output size in pixels; the Stable Diffusion pipeline rejects sizes
    # that are not divisible by 8, so enforce that at the API boundary.
    height: Optional[int] = Field(512, gt=0, multiple_of=8)
    width: Optional[int] = Field(512, gt=0, multiple_of=8)
    # Optional RNG seed; when set, the handler builds a seeded torch.Generator.
    seed: Optional[int] = None
# Load the Stable Diffusion pipeline once, at import time, so every request
# reuses the same in-memory model.
model_id = "CompVis/stable-diffusion-v1-4"

# Prefer the GPU when one is visible to torch; otherwise run on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# HF_TOKEN is optional: os.environ.get returns None when it is unset.
# Weights are cached under /tmp because that location is writable.
pipe = StableDiffusionPipeline.from_pretrained(
    model_id,
    cache_dir="/tmp/diffusers_cache",
    token=os.environ.get("HF_TOKEN"),
).to(device)
@app.get("/")
def read_root():
    """Liveness endpoint: report that the service is up and where to POST."""
    message = "Stable Diffusion API is running. Use POST /generate endpoint."
    return {"message": message}
# Guards the shared module-level pipeline. The original handler called the
# pipeline directly inside an ``async def``, which blocked the event loop and
# therefore ran generations strictly one at a time. Generation now runs in
# worker threads, so this lock preserves that one-call-at-a-time behavior.
# The pipeline object is not assumed to be safe for concurrent calls.
_pipe_lock = threading.Lock()
_logger = logging.getLogger(__name__)


def _render_png_base64(request: TextToImageRequest) -> str:
    """Generate one image for *request* and return it as a base64-encoded PNG.

    This is blocking: it runs the full diffusion loop. Call it from a worker
    thread, never directly on the event loop.
    """
    # A seeded generator on the pipeline's device makes results reproducible.
    generator = None
    if request.seed is not None:
        generator = torch.Generator(device=device).manual_seed(request.seed)

    with _pipe_lock:
        image = pipe(
            prompt=request.prompt,
            negative_prompt=request.negative_prompt,
            num_inference_steps=request.num_inference_steps,
            guidance_scale=request.guidance_scale,
            height=request.height,
            width=request.width,
            generator=generator,
        ).images[0]

    buffer = io.BytesIO()
    image.save(buffer, format="PNG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")


@app.post("/generate")
async def generate_image(request: TextToImageRequest):
    """Generate an image from a text prompt.

    On success, returns JSON with ``status: "success"``, the base64-encoded
    PNG under ``image``, and an echo of the parameters used.

    On any failure, the exception is logged with its traceback. The handler
    then returns HTTP 500 with ``status: "error"`` and the exception message.
    """
    try:
        # Offload the long, blocking generation so the event loop keeps
        # serving other requests (e.g. ``GET /``) while it runs.
        img_str = await run_in_threadpool(_render_png_base64, request)
        return JSONResponse({
            "status": "success",
            "image": img_str,
            "parameters": {
                "prompt": request.prompt,
                "negative_prompt": request.negative_prompt,
                "steps": request.num_inference_steps,
                "guidance_scale": request.guidance_scale,
                "dimensions": f"{request.width}x{request.height}",
                "seed": request.seed
            }
        })
    except Exception as e:
        # Top-level request boundary: record the traceback server-side
        # instead of discarding it, then keep the original error response.
        _logger.exception("Image generation failed")
        return JSONResponse(
            status_code=500,
            content={"status": "error", "message": str(e)},
        )
# Local entry point only (`python app.py`); per the original note, this is
# not necessary when running on Spaces.
if __name__ == "__main__":
    bind_host = "0.0.0.0"  # listen on all interfaces
    bind_port = 7860
    uvicorn.run(app, host=bind_host, port=bind_port)