import base64
import io
import os
import secrets
import zipfile
from contextlib import asynccontextmanager

import torch
import torchvision.transforms as T
from fastapi import Depends, FastAPI, Header, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse, StreamingResponse
from PIL import Image
from pydantic import BaseModel, Field
from torchvision.utils import save_image

from .model import LDMPipeline

# Process-wide pipeline instance; populated by `lifespan` on startup and
# reset to None on shutdown.
pipeline = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create the LDMPipeline on startup and release it on shutdown."""
    global pipeline
    pipeline = LDMPipeline()
    try:
        yield
    finally:
        # Rebind instead of `del`: the module-level name stays defined, so
        # later readers see None rather than hitting a NameError.
        pipeline = None


app = FastAPI(title="LDM Image Generation API", lifespan=lifespan)

app.add_middleware(
    CORSMiddleware,
    allow_origins=[
        "http://localhost:3000",
        "http://localhost:5173",
        "https://huggingface.co",
        "https://rohan3-flickr8k-frontend.hf.space",
    ],
    allow_methods=["*"],
    allow_headers=["*"],
)


# Request body for POST /generate. Bounds are enforced by pydantic before the
# handler runs.
class GenerateRequest(BaseModel):
    caption: str = Field(..., example="a white dog running in snow")
    num_images: int = Field(4, ge=1, le=8)
    num_steps: int = Field(30, ge=10, le=100)
    guidance_scale: float = Field(5, ge=1.0, le=20.0)
    seed: int = Field(42)
    eta: float = Field(0, ge=0.0, le=1.0)


def tensor_to_pil(img_tensor: torch.Tensor) -> Image.Image:
    """Convert an image tensor to a PIL image.

    Values are clamped to [0, 1], scaled to 0-255 and cast to uint8. The
    first axis is moved to the end before conversion (presumably
    channels-first -> channels-last; confirm against LDMPipeline's output).
    """
    img = img_tensor.clamp(0, 1)
    img = (img * 255).byte().permute(1, 2, 0).cpu().numpy()
    return Image.fromarray(img)


@app.get("/health")
def health():
    """Report service status and the device the pipeline is using."""
    if pipeline is None:
        # Without this guard the attribute access below fails with an
        # AttributeError and the client only sees an opaque 500.
        raise HTTPException(status_code=503, detail="Model not loaded")
    return {"status": "ok", "device": str(pipeline.device)}


# Response body for POST /generate: base64-encoded PNGs plus their count.
class GenerateResponse(BaseModel):
    images: list[str]
    num_generated: int


def tensor_to_base64(img_tensor: torch.Tensor) -> str:
    """Encode an image tensor as a base64 string of its PNG bytes.

    The tensor is converted with `tensor_to_pil`, so the same clamping and
    axis handling apply.
    """
    buf = io.BytesIO()
    tensor_to_pil(img_tensor).save(buf, format="PNG")
    return base64.b64encode(buf.getvalue()).decode("utf-8")


# Shared secret expected in the `X-API-Key` request header. When the variable
# is unset or empty, every keyed request is rejected.
API_KEY = os.getenv("API_KEY")


async def verify_key(request: Request, x_api_key: str = Header(...)):
    """Dependency that rejects requests whose X-API-Key header is wrong.

    OPTIONS requests are let through without a key check.

    Raises:
        HTTPException: 403 when the supplied key does not match API_KEY, or
            when no API_KEY is configured on the server.
    """
    if request.method == "OPTIONS":
        return
    # Compare as bytes: compare_digest only accepts ASCII-only str, and the
    # constant-time comparison avoids leaking key contents through timing.
    if not API_KEY or not secrets.compare_digest(
        x_api_key.encode("utf-8"), API_KEY.encode("utf-8")
    ):
        raise HTTPException(status_code=403, detail="Invalid API key")
@app.post("/generate", response_model=GenerateResponse)
async def generate(req: GenerateRequest, _=Depends(verify_key)):
    """Generate images for a caption and return them as base64-encoded PNGs."""
    # Fail with an explicit 503 rather than a 500 carrying an AttributeError
    # message when the pipeline has not been created yet.
    if pipeline is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    # NOTE(review): pipeline.generate is called synchronously inside an async
    # handler, so the event loop is blocked for the whole generation. Moving
    # it to a worker thread would keep other endpoints responsive, but would
    # also allow overlapping generations -- confirm the pipeline tolerates
    # that before changing it.
    try:
        images = pipeline.generate(
            caption=req.caption,
            num_images=req.num_images,
            num_steps=req.num_steps,
            guidance_scale=req.guidance_scale,
            seed=req.seed,
            eta=req.eta,
        )
        b64_images = [tensor_to_base64(img) for img in images]
        return GenerateResponse(images=b64_images, num_generated=len(b64_images))
    except Exception as e:
        # Any failure in generation or encoding is reported as a 500; chain
        # the cause so the original traceback is kept for server-side logs.
        raise HTTPException(status_code=500, detail=str(e)) from e