Spaces:
Sleeping
Sleeping
| # server.py | |
| from fastapi import FastAPI, HTTPException | |
| from pydantic import BaseModel | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from diffusers import StableDiffusionPipeline | |
| import torch | |
| import base64 | |
| from io import BytesIO | |
| from PIL import Image | |
| import logging | |
| import os | |
# Application setup: logging, Hugging Face credentials, and the FastAPI app.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("uvicorn")

# Token for gated/private model downloads; None when the env var is unset.
hf_token = os.getenv("HF_TOKEN")

app = FastAPI()
# Load the model (Stable Diffusion v1.5) once at startup so every request
# reuses the same pipeline instance.
# NOTE(review): fp16 weights generally require a CUDA device, and no
# .to("cuda") call is visible in this chunk -- confirm the deployment target.
logger.info("Loading model...")
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
    cache_dir="/tmp/huggingface",
    # `use_auth_token` is deprecated (and removed in newer diffusers
    # releases); `token` is the supported keyword.
    token=hf_token,
)
class PromptRequest(BaseModel):
    """Request body for image generation: a single text prompt."""

    prompt: str  # text prompt forwarded verbatim to the diffusion pipeline
# NOTE(review): no route decorator (e.g. @app.post("/generate")) is visible
# in this chunk -- if this is meant to be an HTTP endpoint, confirm it is
# registered somewhere in the full file.
async def generate_image(data: PromptRequest):
    """Generate an image for ``data.prompt`` and return it as base64 PNG.

    Returns:
        dict: ``{"image": <base64-encoded PNG string>}``

    Raises:
        HTTPException: 500 with the underlying error message on any failure.
    """
    try:
        logger.info("Received Request. Generating Image...")
        # NOTE(review): this call blocks the event loop for the whole model
        # run; consider run_in_executor / a worker thread if concurrency
        # matters.
        image = pipe(data.prompt).images[0]
        buffered = BytesIO()
        image.save(buffered, format="PNG")
        img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
        # Fix: this log line was placed after `return` in the original and
        # was therefore unreachable.
        logger.info("Done")
        return {"image": img_str}
    except Exception as e:
        # Surface any failure as a 500, chaining the original cause.
        raise HTTPException(status_code=500, detail=str(e)) from e