from __future__ import annotations import io import os from pathlib import Path from typing import List import gradio as gr from fastapi import FastAPI, File, HTTPException, Query, UploadFile from fastapi.responses import HTMLResponse from pydantic import BaseModel, Field from PIL import Image import tagger as tg # -------------------- FastAPI -------------------- app = FastAPI( title="Image Tagger API", version="1.0.0", description="Generate a caption with BLIP, then return top-K tags derived from that caption.", ) WRITE_SIDECAR = os.getenv("WRITE_SIDECAR", "1") != "0" class TagResponse(BaseModel): filename: str = Field(..., examples=["photo.jpg"]) caption: str = Field(..., examples=["a lion rests on a rock in the wild"]) tags: List[str] = Field(..., examples=[["lion", "rests", "rock", "wild"]]) @app.on_event("startup") def _load_once() -> None: tg.init_models() @app.get("/healthz") def healthz(): return {"ok": True} @app.get("/", response_class=HTMLResponse) def root(): return """ Image Tagger API

🖼️ Image Tagger API

Use /docs for Swagger or try the simple UI at /ui.

Quick upload

Top K tags:

""" @app.post("/upload", response_model=TagResponse) async def upload_image( file: UploadFile = File(...), top_k: int = Query(5, ge=1, le=20, description="How many tags to return"), ): try: content = await file.read() img = Image.open(io.BytesIO(content)).convert("RGB") except Exception as e: raise HTTPException(status_code=400, detail=f"Invalid image: {e}") # caption with BLIP caption = tg.caption_image(img) # top-K tags (ensure tagger returns ONLY the list) stem = Path(file.filename).stem tags = tg.caption_to_tags(caption, top_k=top_k) # optional sidecar (same content shape as JSON response) if WRITE_SIDECAR: try: (Path(os.getenv("DATA_DIR", "/app/data"))).mkdir(parents=True, exist_ok=True) (Path(os.getenv("DATA_DIR", "/app/data")) / f"{stem}.json").write_text( TagResponse(filename=file.filename, caption=caption, tags=tags).model_dump_json(indent=2) ) except Exception: # ignore filesystem errors; do not fail the request pass return TagResponse(filename=file.filename, caption=caption, tags=tags) # -------------------- Gradio (mounted at /ui) -------------------- def _infer(image: Image.Image, top_k: int): """Wraps the same logic used by the API, but returns simple types so the schema is trivial for Gradio (avoids JSON/dict outputs).""" if image is None: return "", "" cap = tg.caption_image(image) tags = tg.caption_to_tags(cap, top_k=top_k) return cap, ", ".join(tags) with gr.Blocks(title="Image Tagger UI") as demo: gr.Markdown("### 🔍 Image → Caption → Tags\nUpload an image → BLIP generates a caption → we extract up to **K** simple tags.") with gr.Row(): with gr.Column(scale=3): in_img = gr.Image(type="pil", label="Upload image", height=480) k = gr.Slider(1, 20, value=5, step=1, label="Number of tags (K)") submit = gr.Button("Submit", variant="primary") clear = gr.Button("Clear") with gr.Column(scale=2): out_cap = gr.Textbox(label="Generated Caption", lines=2) out_tags = gr.Textbox(label="Tags (comma-separated)", lines=2) submit.click(_infer, inputs=[in_img, k], outputs=[out_cap, out_tags]) clear.click(lambda: (None, 5, "", ""), outputs=[in_img, k, out_cap, out_tags]) # mount Gradio under FastAPI app = gr.mount_gradio_app(app, demo, path="/ui")