# flickr8k-backend / app/main.py
# Last commit by Rohan3 (a625e96): "Updated: VAE, UNet, config, text embeddings, model and main"
import asyncio
import base64
import io
import os
import secrets
import threading
import zipfile
from contextlib import asynccontextmanager

import torch
import torchvision.transforms as T
from fastapi import Depends, FastAPI, Header, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse, StreamingResponse
from PIL import Image
from pydantic import BaseModel, Field
from torchvision.utils import save_image

from .model import LDMPipeline
# Module-level handle to the loaded model: assigned by `lifespan` and read by
# the route handlers. None whenever no model is loaded.
pipeline = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the LDM pipeline once at application startup and drop it at shutdown.

    Code before ``yield`` runs on startup and code after it on shutdown.
    """
    global pipeline
    pipeline = LDMPipeline()
    try:
        yield
    finally:
        # Rebind to None instead of `del`: deleting the global removes the
        # module attribute, so any late reference would raise NameError rather
        # than see a clear "not loaded" state.
        pipeline = None
# Browser origins allowed to call this API: local frontends on ports 3000 and
# 5173 (presumably dev servers), the Hugging Face site, and the hosted
# frontend Space.
_CORS_ORIGINS = [
    "http://localhost:3000",
    "http://localhost:5173",
    "https://huggingface.co",
    "https://rohan3-flickr8k-frontend.hf.space",
]

app = FastAPI(title="LDM Image Generation API", lifespan=lifespan)

# Any method and any request header are permitted from the origins above.
app.add_middleware(
    CORSMiddleware,
    allow_origins=_CORS_ORIGINS,
    allow_methods=["*"],
    allow_headers=["*"],
)
class GenerateRequest(BaseModel):
    """Request body for ``POST /generate``.

    Every field is forwarded unchanged to ``pipeline.generate``. The bounds
    below are enforced by Pydantic validation, so out-of-range values are
    rejected before the endpoint body runs.
    """

    # Text prompt the images are generated from. `examples` (a list) replaces
    # the deprecated `example=` extra keyword.
    caption: str = Field(..., examples=["a white dog running in snow"])
    # Number of images returned in one response.
    num_images: int = Field(4, ge=1, le=8)
    # Presumably the number of denoising steps — confirm against LDMPipeline.
    num_steps: int = Field(30, ge=10, le=100)
    # Presumably the classifier-free guidance weight — confirm against LDMPipeline.
    # Float default so the value has the declared type even when omitted.
    guidance_scale: float = Field(5.0, ge=1.0, le=20.0)
    # Seed forwarded to the pipeline for reproducible sampling.
    seed: int = Field(42)
    # Presumably the DDIM eta (0 = deterministic sampling) — confirm against LDMPipeline.
    eta: float = Field(0.0, ge=0.0, le=1.0)
def tensor_to_pil(img_tensor: torch.Tensor) -> Image.Image:
    """Convert a channels-first float image tensor into a PIL image.

    Args:
        img_tensor: 3-D tensor laid out as (C, H, W). Values are expected in
            [0, 1]; anything outside that range is clamped.

    Returns:
        A PIL image built from the 8-bit pixel data. Single-channel input
        produces a 2-D (grayscale) image; otherwise the array is (H, W, C).
    """
    # detach() lets .numpy() work even if the tensor still tracks gradients.
    img = img_tensor.detach().clamp(0, 1)
    # Round to the nearest 8-bit level (+0.5 then truncate, as
    # torchvision.utils.save_image does); a bare .byte() truncates and biases
    # every pixel down by up to one level.
    img = img.mul(255).add(0.5).clamp(0, 255).byte()
    arr = img.permute(1, 2, 0).cpu().numpy()
    if arr.shape[-1] == 1:
        # PIL cannot build an image from an (H, W, 1) array; drop the channel axis.
        arr = arr[:, :, 0]
    return Image.fromarray(arr)
@app.get("/health")
def health():
    """Report service status together with ``pipeline.device``.

    Returns:
        ``{"status": "ok", "device": <str(pipeline.device)>}`` once the
        pipeline is loaded.

    Raises:
        HTTPException: 503 while ``pipeline`` is None (no model loaded). This
            replaces an unhandled AttributeError, which surfaced as an opaque 500.
    """
    if pipeline is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    return {"status": "ok", "device": str(pipeline.device)}
class GenerateResponse(BaseModel):
    """Response body returned by ``POST /generate``."""

    # One entry per generated image: PNG bytes encoded as a base64 string
    # (produced by ``tensor_to_base64``).
    images: list[str]
    # Number of entries in ``images``.
    num_generated: int
def tensor_to_base64(img_tensor: torch.Tensor) -> str:
    """Encode a channels-first float image tensor as a base64 PNG string.

    Args:
        img_tensor: 3-D tensor laid out as (C, H, W). Values are expected in
            [0, 1]; anything outside that range is clamped.

    Returns:
        The PNG-encoded image as a base64 ``str`` (UTF-8 decoded), ready to
        embed in a JSON response.
    """
    # detach() lets .numpy() work even if the tensor still tracks gradients.
    img = img_tensor.detach().clamp(0, 1)
    # Round to the nearest 8-bit level (+0.5 then truncate, as
    # torchvision.utils.save_image does); a bare .byte() truncates and biases
    # every pixel down by up to one level.
    img = img.mul(255).add(0.5).clamp(0, 255).byte()
    arr = img.permute(1, 2, 0).cpu().numpy()
    if arr.shape[-1] == 1:
        # PIL cannot build an image from an (H, W, 1) array; drop the channel axis.
        arr = arr[:, :, 0]
    pil_img = Image.fromarray(arr)
    buf = io.BytesIO()
    pil_img.save(buf, format="PNG")
    return base64.b64encode(buf.getvalue()).decode("utf-8")
# Shared secret expected in the X-API-Key request header; None when the
# API_KEY environment variable is not set.
API_KEY = os.environ.get("API_KEY")
async def verify_key(request: Request, x_api_key: str = Header(...)):
    """FastAPI dependency that checks the ``X-API-Key`` header against ``API_KEY``.

    Args:
        request: Incoming request; only its HTTP method is inspected.
        x_api_key: Value of the ``X-API-Key`` request header.

    Raises:
        HTTPException: 403 when no (or an empty) ``API_KEY`` is configured on
            the server, or when the supplied key does not match it.
    """
    # Let CORS preflight requests through without a key. NOTE(review): because
    # the header is declared required, FastAPI presumably rejects a key-less
    # OPTIONS request during parameter validation before this line runs.
    if request.method == "OPTIONS":
        return
    # Fail closed: an unset or empty API_KEY must never authorize a request.
    # With a plain equality check, API_KEY="" would accept an empty header value.
    if not API_KEY:
        raise HTTPException(status_code=403, detail="Invalid API key")
    # Constant-time comparison, so response timing does not reveal how much of
    # the key was correct. Compare bytes: compare_digest rejects str arguments
    # that contain non-ASCII characters.
    if not secrets.compare_digest(x_api_key.encode("utf-8"), API_KEY.encode("utf-8")):
        raise HTTPException(status_code=403, detail="Invalid API key")
# Serializes calls into the shared pipeline so only one generation runs at a
# time. NOTE(review): confirm whether LDMPipeline.generate is thread-safe
# (seeding, device memory) before relaxing this.
_GENERATE_LOCK = threading.Lock()


def _generate_b64_images(req: GenerateRequest) -> list[str]:
    """Run the pipeline for ``req`` and return the images as base64 PNG strings.

    Synchronous and potentially long-running; intended to run in a worker thread.
    """
    with _GENERATE_LOCK:
        images = pipeline.generate(
            caption=req.caption,
            num_images=req.num_images,
            num_steps=req.num_steps,
            guidance_scale=req.guidance_scale,
            seed=req.seed,
            eta=req.eta,
        )
        return [tensor_to_base64(img) for img in images]


@app.post("/generate", response_model=GenerateResponse)
async def generate(req: GenerateRequest, _=Depends(verify_key)):
    """Generate images for ``req.caption`` and return them base64-encoded.

    Requires a valid ``X-API-Key`` header (checked by the ``verify_key``
    dependency).

    Raises:
        HTTPException: 503 if no pipeline is loaded; 500 carrying the
            underlying error message if generation or encoding fails.
    """
    if pipeline is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    try:
        # pipeline.generate is a blocking call. Running it in a worker thread
        # keeps the event loop free to serve other requests (e.g. /health)
        # while a batch is being generated.
        b64_images = await asyncio.to_thread(_generate_b64_images, req)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    return GenerateResponse(images=b64_images, num_generated=len(b64_images))