Spaces:
Sleeping
Sleeping
import asyncio
import base64
import io
import os
import secrets
import threading
import zipfile
from contextlib import asynccontextmanager

import torch
import torchvision.transforms as T
from fastapi import Depends, FastAPI, Header, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse, StreamingResponse
from PIL import Image
from pydantic import BaseModel, Field
from torchvision.utils import save_image

from .model import LDMPipeline
# Model pipeline shared by the request handlers. It is created when the app
# starts (see ``lifespan``) and reset to ``None`` when the app shuts down.
pipeline = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create the LDM pipeline at startup and release it at shutdown.

    Passed to ``FastAPI(lifespan=...)``. Code before ``yield`` runs once before
    the app serves requests; code after it runs once at shutdown.

    Args:
        app: The FastAPI application (unused; required by the lifespan protocol).
    """
    global pipeline
    pipeline = LDMPipeline()
    try:
        yield
    finally:
        # Rebind instead of ``del``: deleting the module global would make any
        # later read of ``pipeline`` raise NameError rather than see "not loaded".
        pipeline = None
# Browser origins allowed to call this API cross-origin:
# - two localhost development front-ends (ports 3000 and 5173);
# - the Hugging Face hub;
# - the deployed frontend Space.
_ALLOWED_ORIGINS = [
    "http://localhost:3000",
    "http://localhost:5173",
    "https://huggingface.co",
    "https://rohan3-flickr8k-frontend.hf.space",
]

app = FastAPI(
    title="LDM Image Generation API",
    lifespan=lifespan,  # builds the pipeline at startup, drops it at shutdown
)

# CORS: any method and any request header, but only from the origins above.
app.add_middleware(
    CORSMiddleware,
    allow_origins=_ALLOWED_ORIGINS,
    allow_headers=["*"],
    allow_methods=["*"],
)
class GenerateRequest(BaseModel):
    """Parameters for a single text-to-image generation request."""

    # Text prompt the images are generated from. ``examples`` (a list) replaces
    # the ``example=`` extra kwarg, which Pydantic v2 deprecates.
    caption: str = Field(..., examples=["a white dog running in snow"])
    # Number of images to generate, 1 to 8 inclusive.
    num_images: int = Field(4, ge=1, le=8)
    # Number of sampling steps, 10 to 100 inclusive.
    num_steps: int = Field(30, ge=10, le=100)
    # Guidance scale forwarded to the pipeline, 1.0 to 20.0
    # (presumably classifier-free guidance strength; confirm in LDMPipeline).
    # Float default: Pydantic does not validate defaults, so an int default
    # would reach the pipeline as an int.
    guidance_scale: float = Field(5.0, ge=1.0, le=20.0)
    # Seed forwarded to the pipeline for reproducible sampling.
    seed: int = Field(42)
    # Sampler ``eta`` forwarded to the pipeline, 0.0 to 1.0
    # (presumably DDIM stochasticity; confirm in LDMPipeline).
    eta: float = Field(0.0, ge=0.0, le=1.0)
def tensor_to_pil(img_tensor: torch.Tensor) -> Image.Image:
    """Convert one channel-first ``(C, H, W)`` image tensor to a PIL image.

    Values are clamped to ``[0, 1]`` and quantized to 8 bits with
    round-to-nearest (the ``mul(255).add(0.5)`` convention used by
    ``torchvision.utils.save_image``). Plain truncation would bias every pixel
    down by up to one level.

    Args:
        img_tensor: Rank-3 channel-first image tensor with values nominally in
            [0, 1]. A single-channel tensor yields a grayscale image.

    Returns:
        The image as a ``PIL.Image.Image`` of size ``W x H``.
    """
    pixels = (
        img_tensor.float()  # fp32 so 255.5 is exact (bf16 would round it to 256)
        .clamp(0, 1)
        .mul(255)
        .add(0.5)
        .to(torch.uint8)
        .permute(1, 2, 0)
        .cpu()
        .numpy()
    )
    if pixels.shape[-1] == 1:
        # Image.fromarray cannot handle an (H, W, 1) uint8 array; drop the axis.
        pixels = pixels[..., 0]
    return Image.fromarray(pixels)
def health():
    """Report liveness and the device the loaded pipeline runs on.

    Returns:
        ``{"status": "ok", "device": str(pipeline.device)}``.

    Raises:
        HTTPException: 503 when the pipeline is not loaded (``pipeline is None``).
            This replaces an opaque AttributeError surfacing as a 500.
    """
    if pipeline is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    return {"status": "ok", "device": str(pipeline.device)}
class GenerateResponse(BaseModel):
    # One base64-encoded PNG string per generated image.
    images: list[str] = Field(...)
    # How many images were returned; the handler sets it to len(images).
    num_generated: int = Field(...)
def tensor_to_base64(img_tensor: torch.Tensor) -> str:
    """Encode one channel-first image tensor as a base64 PNG string.

    Pixel conversion is delegated to ``tensor_to_pil``, so both helpers always
    quantize identically. Previously each kept its own copy of the conversion.

    Args:
        img_tensor: Channel-first image tensor, as accepted by ``tensor_to_pil``.

    Returns:
        The PNG bytes, base64-encoded and decoded to ``str``.
    """
    with io.BytesIO() as buf:
        tensor_to_pil(img_tensor).save(buf, format="PNG")
        png_bytes = buf.getvalue()
    return base64.b64encode(png_bytes).decode("utf-8")
# Shared secret clients must present. If the environment variable is unset or
# empty, every request is rejected (fail closed).
API_KEY = os.getenv("API_KEY")


async def verify_key(request: Request, x_api_key: str = Header(...)):
    """FastAPI dependency that rejects requests lacking the correct API key.

    Args:
        request: The incoming request; OPTIONS requests are let through.
        x_api_key: Value of the ``x-api-key`` request header. The header is
            required, so FastAPI rejects requests without it before this runs.

    Raises:
        HTTPException: 403 if no key is configured or the key does not match.
    """
    if request.method == "OPTIONS":
        # NOTE(review): CORS preflights are normally answered by CORSMiddleware
        # before routing; this branch looks like a safety net only.
        return
    # compare_digest runs in time independent of where the inputs differ, so
    # the key cannot be recovered byte-by-byte via response timing. Bytes are
    # used because compare_digest rejects non-ASCII str arguments.
    if not API_KEY or not secrets.compare_digest(
        x_api_key.encode("utf-8"), API_KEY.encode("utf-8")
    ):
        raise HTTPException(status_code=403, detail="Invalid API key")
async def generate(req: GenerateRequest, _=Depends(verify_key)):
    """Generate images for ``req.caption`` and return them as base64 PNGs.

    Args:
        req: Validated generation parameters.
        _: Result of the ``verify_key`` dependency; present only to enforce auth.

    Returns:
        ``GenerateResponse`` with one base64-encoded PNG per generated image.

    Raises:
        HTTPException: 503 if the pipeline is not loaded; 500 carrying the
            error message if generation or encoding fails.
    """
    model = pipeline  # snapshot so a concurrent shutdown cannot swap it mid-call
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    try:
        # model.generate is a synchronous call. Awaiting it in a worker thread
        # keeps the event loop free to serve other requests meanwhile, instead
        # of freezing the whole server for the duration of the call.
        b64_images = await asyncio.to_thread(_run_generation, model, req)
        return GenerateResponse(images=b64_images, num_generated=len(b64_images))
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e


# Runs one generation at a time. When the call ran directly on the event loop
# it was implicitly serialized. The lock keeps that guarantee now that it runs
# in worker threads, so concurrent requests do not contend for the model (nor,
# presumably, for any global RNG state the pipeline seeds per request).
_generation_lock = threading.Lock()


def _run_generation(model, req: GenerateRequest) -> list[str]:
    """Blocking worker: run ``model.generate`` for *req* and base64-encode the results."""
    with _generation_lock:
        images = model.generate(
            caption=req.caption,
            num_images=req.num_images,
            num_steps=req.num_steps,
            guidance_scale=req.guidance_scale,
            seed=req.seed,
            eta=req.eta,
        )
        return [tensor_to_base64(img) for img in images]