import os

# Set cache directories to writable locations.
# These must be set before importing diffusers/transformers, since the
# Hugging Face libraries read HF_HOME and the cache paths at import time.
os.environ["TRANSFORMERS_CACHE"] = "/tmp/transformers_cache"
os.environ["HF_HOME"] = "/tmp/hf_home"
os.environ["HF_DATASETS_CACHE"] = "/tmp/hf_datasets_cache"

import io
import base64
from typing import Optional

import torch
import uvicorn
from diffusers import StableDiffusionPipeline
from fastapi import FastAPI
from fastapi.responses import JSONResponse
from pydantic import BaseModel

# Initialize FastAPI app
app = FastAPI(title="Stable Diffusion API")

# Define input model
class TextToImageRequest(BaseModel):
    prompt: str
    negative_prompt: Optional[str] = None
    num_inference_steps: Optional[int] = 50
    guidance_scale: Optional[float] = 7.5
    height: Optional[int] = 512
    width: Optional[int] = 512
    seed: Optional[int] = None

# Load the model (loaded once when the Space starts)
model_id = "CompVis/stable-diffusion-v1-4"

# Check if CUDA is available
device = "cuda" if torch.cuda.is_available() else "cpu"

pipe = StableDiffusionPipeline.from_pretrained(
    model_id,
    cache_dir="/tmp/diffusers_cache",
    token=os.environ.get("HF_TOKEN"),  # Use token from environment variable
)
pipe = pipe.to(device)

@app.get("/")
def read_root():
    return {"message": "Stable Diffusion API is running. Use POST /generate endpoint."}

@app.post("/generate")
async def generate_image(request: TextToImageRequest):
    try:
        # Set seed if provided, for reproducible generations
        if request.seed is not None:
            generator = torch.Generator(device=device).manual_seed(request.seed)
        else:
            generator = None

        # Generate image
        image = pipe(
            prompt=request.prompt,
            negative_prompt=request.negative_prompt,
            num_inference_steps=request.num_inference_steps,
            guidance_scale=request.guidance_scale,
            height=request.height,
            width=request.width,
            generator=generator,
        ).images[0]

        # Convert to base64-encoded PNG
        buffer = io.BytesIO()
        image.save(buffer, format="PNG")
        img_str = base64.b64encode(buffer.getvalue()).decode("utf-8")

        return JSONResponse({
            "status": "success",
            "image": img_str,
            "parameters": {
                "prompt": request.prompt,
                "negative_prompt": request.negative_prompt,
                "steps": request.num_inference_steps,
                "guidance_scale": request.guidance_scale,
                "dimensions": f"{request.width}x{request.height}",
                "seed": request.seed,
            },
        })
    except Exception as e:
        return JSONResponse(
            status_code=500,
            content={"status": "error", "message": str(e)},
        )

# For local testing; not necessary in Spaces
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
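
# --- Example client usage (a minimal sketch, not part of the server) ---
# Run this from a separate process while the server is up. It assumes the
# default port 7860 above and the `requests` library; the prompt, filename,
# and parameter values are illustrative only.
#
#   import base64
#   import requests
#
#   resp = requests.post(
#       "http://localhost:7860/generate",
#       json={
#           "prompt": "a watercolor painting of a lighthouse",
#           "num_inference_steps": 30,
#           "seed": 42,
#       },
#   )
#   resp.raise_for_status()
#   payload = resp.json()
#
#   # The image is returned as a base64-encoded PNG string; decode and save it.
#   with open("output.png", "wb") as f:
#       f.write(base64.b64decode(payload["image"]))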