test / app.py
keertan2610's picture
Update app.py
a92130d verified
# app.py - FastAPI server exposing a Stable Diffusion text-to-image endpoint
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from fastapi.middleware.cors import CORSMiddleware
from diffusers import StableDiffusionPipeline
import torch
import base64
from io import BytesIO
from PIL import Image
import logging
import os
app = FastAPI()

logging.basicConfig(level=logging.INFO)
# Messages are emitted under the "uvicorn" logger name.
logger = logging.getLogger("uvicorn")

# Hugging Face access token read from the environment; None when HF_TOKEN is unset.
hf_token = os.getenv("HF_TOKEN")

# Half precision is only practical on a CUDA device. The pipeline used to be
# loaded as float16 and left on the CPU, where fp16 inference fails or is
# extremely slow, so choose device and dtype together.
_device = "cuda" if torch.cuda.is_available() else "cpu"
_dtype = torch.float16 if _device == "cuda" else torch.float32

# Load the model (Stable Diffusion)
logger.info("Loading model...")
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=_dtype,
    cache_dir="/tmp/huggingface",
    # NOTE(review): newer diffusers releases prefer `token=`; kept as-is to
    # match whatever version this app is pinned to - confirm before changing.
    use_auth_token=hf_token,
)
# Place the weights on the selected device so inference runs where the dtype is supported.
pipe = pipe.to(_device)
# Request body model for the POST /generate endpoint.
class PromptRequest(BaseModel):
    # Text prompt that is passed to the Stable Diffusion pipeline.
    prompt: str
@app.post("/generate")
async def generate_image(data: PromptRequest):
    """Generate an image from a text prompt and return it as base64 PNG.

    Args:
        data: Request body carrying the text prompt.

    Returns:
        A dict ``{"image": <str>}`` whose value is the PNG encoding of the
        first generated image, base64-encoded and decoded to a UTF-8 string.

    Raises:
        HTTPException: Status 500 with the error text as ``detail`` when
            generation or encoding fails.
    """
    logger.info("Received Request. Generating Image...")
    try:
        # NOTE: this is a synchronous call inside an async handler, so the
        # event loop is blocked until the pipeline returns.
        image = pipe(data.prompt).images[0]
        buffered = BytesIO()
        image.save(buffered, format="PNG")
        img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
    except Exception as e:
        # Record the full traceback server-side; the client only gets the message.
        logger.exception("Image generation failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
    # Previously placed after `return`, so it never ran.
    logger.info("Done")
    return {"image": img_str}