Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| # FastAPI REST API + Gradio UI at / | |
| # Endpoints: | |
| # GET /healthz | |
| # POST /upload -> {filename, caption, tags} | |
| # UI: | |
| # / (upload image, choose top_k, see caption + tags) | |
| # Docs: | |
| # /docs | |
import io
from pathlib import Path
from typing import List

import gradio as gr
from fastapi import FastAPI, File, HTTPException, Query, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse
from PIL import Image
from pydantic import BaseModel

from tagger import tag_pil_image  # returns (caption: str, tags: List[str])
# The FastAPI application object. Its title and version are what the
# generated OpenAPI docs (/docs) display; the Gradio UI is mounted onto this
# same object at the bottom of the module.
app = FastAPI(
    version="0.4.3",
    title="Image Tagger API",
)
| # ---------- Pydantic model ---------- | |
class TagOut(BaseModel):
    # Shape of the /upload JSON payload ({filename, caption, tags}).
    # Documented with comments rather than a docstring because pydantic copies
    # a class docstring into the OpenAPI schema description.
    filename: str  # name of the uploaded file as sent by the client
    caption: str  # caption text returned by the tagger
    tags: list[str]  # tag strings returned by the tagger
| # ---------- Health ---------- | |
@app.get("/healthz")
def healthz() -> dict[str, bool]:
    """Report that the API process is up and answering requests.

    Returns:
        The constant payload ``{"ok": True}``.
    """
    return {"ok": True}
| # ---------- REST endpoint ---------- | |
@app.post("/upload", response_model=TagOut)
async def upload(
    file: UploadFile = File(...),
    top_k: int = Query(5, ge=1, le=20, description="Max number of tags"),
):
    """Caption and tag an uploaded PNG, JPEG, or WebP image.

    The upload is decoded to RGB and passed to ``tag_pil_image`` together with
    the file-name stem and ``top_k``. The response is
    ``{"filename", "caption", "tags"}``.

    Raises:
        HTTPException: 415 for an unsupported content type, 400 when the bytes
            cannot be decoded as an image, 500 when tagging fails.
    """
    # content_type is taken from the client's multipart header; a missing
    # header (None) is rejected here as well.
    if file.content_type not in {"image/png", "image/jpeg", "image/webp"}:
        raise HTTPException(
            status_code=415, detail="Only PNG, JPEG, or WebP images are supported"
        )
    try:
        data = await file.read()
        img = Image.open(io.BytesIO(data)).convert("RGB")
    except Exception as e:
        raise HTTPException(status_code=400, detail="Could not decode image") from e

    # UploadFile.filename is Optional; Path(None) would raise TypeError, so
    # fall back to "" and then to the generic "upload" stem.
    stem = Path(file.filename or "").stem or "upload"
    try:
        # tag_pil_image is a plain synchronous function. Calling it directly
        # inside this async endpoint would block the event loop for its whole
        # duration, so run it in the worker thread pool instead.
        caption, tags = await run_in_threadpool(tag_pil_image, img, stem, top_k=top_k)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Tagging failed: {e}") from e
    return JSONResponse({"filename": file.filename, "caption": caption, "tags": tags})
| # ---------- Gradio UI at root ---------- | |
| def _ui_tag(image: Image.Image, top_k: int): | |
| if image is None: | |
| return "", "" | |
| caption, tags = tag_pil_image(image.convert("RGB"), "upload", top_k=top_k) | |
| return caption, ", ".join(tags) | |
# Browser UI: one image input plus a top-k slider, producing a caption box and
# a comma-separated tag box. flagging_mode="never" hides Gradio's flag button.
demo = gr.Interface(
    title="Image Tagger",
    description="Upload an image to get a caption and top-k tags. Programmatic API at /docs.",
    flagging_mode="never",
    fn=_ui_tag,
    inputs=[
        gr.Image(type="pil", label="Upload image"),
        gr.Slider(minimum=1, maximum=20, value=5, step=1, label="Top-k tags"),
    ],
    outputs=[
        gr.Textbox(label="Caption", lines=2),
        gr.Textbox(label="Tags (comma-separated)", lines=2),
    ],
)

# Serve the UI from the same FastAPI app at "/" (no separate server, no
# redirect). The mount is added last, so routes already on `app` (e.g. /docs)
# are matched before this catch-all path.
app = gr.mount_gradio_app(app, demo, path="/")