Image_Tagger / app.py
stephenebert's picture
Update app.py
7667eb1 verified
raw
history blame
2.56 kB
from __future__ import annotations
# FastAPI REST API + Gradio UI at /
# Endpoints:
# GET /healthz
# POST /upload?top_k=N (multipart field "file") -> {filename, caption, tags}
# UI:
# / (upload image, choose top_k, see caption + tags)
# Docs:
# /docs
from fastapi import FastAPI, File, HTTPException, Query, UploadFile
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from typing import List
from pathlib import Path
from PIL import Image
import io
import gradio as gr
from tagger import tag_pil_image # returns (caption: str, tags: List[str])
# The single ASGI application: the REST routes below are registered on it,
# and the Gradio UI is mounted onto it at the end of the module.
app = FastAPI(
    version="0.4.3",
    title="Image Tagger API",
)
# ---------- Pydantic model ----------
# Response body of POST /upload.
# (Documented with comments rather than a docstring: pydantic would copy a
# class docstring into the OpenAPI schema description.)
class TagOut(BaseModel):
    # Name of the uploaded file, as supplied with the multipart upload.
    filename: str
    # Caption text produced by tag_pil_image.
    caption: str
    # Tags produced by tag_pil_image (the caller requests at most top_k).
    tags: List[str]
# ---------- Health ----------
# Liveness probe. It always reports {"ok": true} and never touches the
# tagger, so it only shows that the process is serving HTTP.
# (Comments instead of a docstring: FastAPI would publish a docstring as
# the route description in /openapi.json.)
@app.get("/healthz")
def healthz():
    status = dict(ok=True)
    return status
# ---------- REST endpoint ----------
@app.post("/upload", response_model=TagOut)
async def upload(
    file: UploadFile = File(...),
    top_k: int = Query(5, ge=1, le=20, description="Max number of tags"),
):
    """Caption and tag an uploaded PNG, JPEG, or WebP image.

    Returns ``{filename, caption, tags}``. Responds with 415 for any other
    content type, 400 when the bytes cannot be decoded as an image, and 500
    when the tagger raises.
    """
    # Local import keeps this block self-contained. FastAPI re-exports
    # Starlette's helper for running a blocking callable in a worker thread.
    from fastapi.concurrency import run_in_threadpool

    if file.content_type not in {"image/png", "image/jpeg", "image/webp"}:
        raise HTTPException(
            status_code=415, detail="Only PNG, JPEG, or WebP images are supported"
        )

    # Read outside the decode guard so transport/read failures are not
    # misreported as "Could not decode image".
    data = await file.read()
    try:
        # Image.open is lazy; convert() forces a full decode, so corrupt or
        # truncated data fails here rather than inside the tagger.
        img = Image.open(io.BytesIO(data)).convert("RGB")
    except Exception as e:
        raise HTTPException(status_code=400, detail="Could not decode image") from e

    # UploadFile.filename is Optional; fall back instead of crashing in Path(None).
    filename = file.filename or "upload"
    stem = Path(filename).stem or "upload"
    try:
        # tag_pil_image is a synchronous call. Run it in the threadpool so a slow
        # call does not block the event loop, which would stall every
        # other request (including /healthz) while it runs.
        caption, tags = await run_in_threadpool(tag_pil_image, img, stem, top_k=top_k)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Tagging failed: {e}") from e

    # Return plain data so FastAPI validates and serializes it through
    # response_model=TagOut; a JSONResponse would bypass that entirely.
    return {"filename": filename, "caption": caption, "tags": tags}
# ---------- Gradio UI at root ----------
def _ui_tag(image: Image.Image, top_k: int):
    """Gradio callback: return ``(caption, comma-separated tags)`` for *image*.

    Returns two empty strings when no image has been uploaded yet.
    """
    if image is None:
        return "", ""
    # gr.Slider passes its value as a float (e.g. 5.0). Coerce it so the
    # tagger gets the same int type the REST endpoint passes.
    caption, tags = tag_pil_image(image.convert("RGB"), "upload", top_k=int(top_k))
    return caption, ", ".join(tags)
# Input widgets: the image to tag and the maximum number of tags. The 1-20
# range matches the REST endpoint's top_k bounds.
_ui_inputs = [
    gr.Image(type="pil", label="Upload image"),
    gr.Slider(1, 20, value=5, step=1, label="Top-k tags"),
]

# Output widgets, in the order _ui_tag returns them: caption, then tags.
_ui_outputs = [
    gr.Textbox(label="Caption", lines=2),
    gr.Textbox(label="Tags (comma-separated)", lines=2),
]

# Single-page UI wrapping _ui_tag; flagging is disabled.
demo = gr.Interface(
    fn=_ui_tag,
    inputs=_ui_inputs,
    outputs=_ui_outputs,
    title="Image Tagger",
    description="Upload an image to get a caption and top-k tags. Programmatic API at /docs.",
    flagging_mode="never",
)
# Serve the Gradio UI from the root of this same FastAPI app, rather than a
# sub-path, so visiting / needs no redirect. The mount is added after the
# API routes were registered, so those routes are presumably matched first
# (Starlette checks routes in registration order).
app = gr.mount_gradio_app(app, blocks=demo, path="/")