Spaces:
Sleeping
Sleeping
File size: 4,659 Bytes
ae57f86 7d79380 e7d0635 d3e85eb 3452c0c 7d79380 d3e85eb cd3cf65 3452c0c cb8b918 d3e85eb 3452c0c d3e85eb 3452c0c d3e85eb 3452c0c d3e85eb cd3cf65 a6de95a cd3cf65 cb8b918 a6de95a 3452c0c d3e85eb 3452c0c d3e85eb 3452c0c cb8b918 d3e85eb 3452c0c ae57f86 d3e85eb cd3cf65 d3e85eb cb8b918 d3e85eb cb8b918 d3e85eb cb8b918 d3e85eb cb8b918 d3e85eb cb8b918 d3e85eb cb8b918 d3e85eb a6de95a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 |
from __future__ import annotations

import io
import logging
import os
from pathlib import Path
from typing import List, Optional

import gradio as gr
from fastapi import FastAPI, File, Form, HTTPException, Query, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import HTMLResponse
from PIL import Image
from pydantic import BaseModel, Field

import tagger as tg
# -------------------- FastAPI --------------------
# Application metadata; FastAPI publishes it in the generated OpenAPI schema
# (browsable at /docs).
_API_META = dict(
    title="Image Tagger API",
    version="1.0.0",
    description="Generate a caption with BLIP, then return top-K tags derived from that caption.",
)
app = FastAPI(**_API_META)
def _env_flag(name: str, default: bool) -> bool:
    """Read a boolean feature flag from the environment variable *name*.

    Returns *default* when the variable is unset. Otherwise the values
    ``0``, ``false``, ``no`` and ``off`` (case-insensitive, surrounding
    whitespace ignored) mean "disabled"; any other value, including the
    empty string, means "enabled".
    """
    raw = os.getenv(name)
    if raw is None:
        return default
    return raw.strip().lower() not in {"0", "false", "no", "off"}


# Whether /upload also writes a JSON sidecar file per upload (on by default;
# disable with e.g. WRITE_SIDECAR=0 or WRITE_SIDECAR=false).
WRITE_SIDECAR = _env_flag("WRITE_SIDECAR", True)
class TagResponse(BaseModel):
    """JSON body returned by ``POST /upload``; the optional sidecar file uses the same shape."""

    # Name of the uploaded file, as received with the upload.
    filename: str = Field(default=..., examples=["photo.jpg"])
    # Caption generated for the image.
    caption: str = Field(default=..., examples=["a lion rests on a rock in the wild"])
    # Tags derived from the caption (count controlled by the request's top_k).
    tags: List[str] = Field(default=..., examples=[["lion", "rests", "rock", "wild"]])
@app.on_event("startup")
def _load_once() -> None:
    """Call ``tg.init_models()`` once, from FastAPI's startup hook.

    NOTE(review): ``on_event`` is deprecated in recent FastAPI releases in
    favour of lifespan handlers; consider migrating.
    """
    initialise = tg.init_models
    initialise()
@app.get("/healthz")
def healthz():
    """Health check: always answers ``{"ok": true}`` without calling the models."""
    status = dict(ok=True)
    return status
# Static landing page: links to /docs and /ui plus a plain multipart upload form
# that posts to /upload.
_INDEX_HTML = """
<!doctype html>
<html>
<head>
<meta charset="utf-8" />
<title>Image Tagger API</title>
<style>
body{font-family: system-ui, -apple-system, Segoe UI, Roboto, Ubuntu, sans-serif; max-width: 820px; margin: 48px auto; padding: 0 16px;}
.card{border:1px solid #e5e7eb; border-radius:12px; padding:20px;}
.btn{background:#111; color:#fff; padding:.6rem 1rem; border-radius:10px; text-decoration:none;}
.btn:focus,.btn:hover{opacity:.9}
input[type=number]{width:80px;}
</style>
</head>
<body>
<h2>🖼️ Image Tagger API</h2>
<p>Use <a href="/docs">/docs</a> for Swagger or try the simple UI at <a class="btn" href="/ui">/ui</a>.</p>
<div class="card">
<h3>Quick upload</h3>
<form action="/upload" method="post" enctype="multipart/form-data">
<p><input type="file" name="file" accept="image/png,image/jpeg,image/webp" required></p>
<p>Top K tags: <input type="number" name="top_k" min="1" max="20" value="5"></p>
<p><button class="btn" type="submit">Upload</button></p>
</form>
</div>
</body>
</html>"""


@app.get("/", response_class=HTMLResponse)
def root():
    """Serve the static HTML landing page."""
    return _INDEX_HTML
def _tag_image(img: Image.Image, top_k: int) -> tuple[str, List[str]]:
    """Caption *img* (BLIP via ``tg``) and derive tags from that caption.

    Purely synchronous model work; ``upload_image`` dispatches it to a worker
    thread so inference does not block the event loop.
    """
    caption = tg.caption_image(img)
    return caption, tg.caption_to_tags(caption, top_k=top_k)


def _write_sidecar(stem: str, result: TagResponse) -> None:
    """Best-effort write of *result* to ``$DATA_DIR/<stem>.json`` (DATA_DIR defaults to ``/app/data``).

    Any failure is logged and swallowed: a sidecar problem must never fail the
    request that produced it.
    """
    data_dir = Path(os.getenv("DATA_DIR", "/app/data"))
    try:
        data_dir.mkdir(parents=True, exist_ok=True)
        # Explicit UTF-8: captions may contain non-ASCII text, and the locale
        # default encoding in a container can be ASCII.
        (data_dir / f"{stem}.json").write_text(
            result.model_dump_json(indent=2), encoding="utf-8"
        )
    except Exception:
        logging.getLogger(__name__).warning(
            "Could not write sidecar %s.json in %s", stem, data_dir, exc_info=True
        )


@app.post("/upload", response_model=TagResponse)
async def upload_image(
    file: UploadFile = File(...),
    top_k: int = Query(5, ge=1, le=20, description="How many tags to return"),
    top_k_form: Optional[int] = Form(
        None,
        alias="top_k",
        ge=1,
        le=20,
        description="How many tags to return (multipart form field; overrides the query parameter)",
    ),
):
    """Caption an uploaded image and return the caption plus tags derived from it.

    ``top_k`` is accepted either as a query parameter (``/upload?top_k=7``) or
    as a multipart form field named ``top_k`` (which is what the HTML form on
    ``/`` sends). When both are present the form field wins.

    When ``WRITE_SIDECAR`` is enabled, the response is also written as JSON to
    ``$DATA_DIR/<upload stem>.json`` on a best-effort basis.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image.
    """
    try:
        content = await file.read()
        img = Image.open(io.BytesIO(content)).convert("RGB")
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Invalid image: {e}") from e

    k = top_k if top_k_form is None else top_k_form
    # Model inference is synchronous and heavy; run it off the event loop so
    # other requests (e.g. /healthz) keep being served meanwhile.
    caption, tags = await run_in_threadpool(_tag_image, img, k)

    # UploadFile.filename is Optional; a missing name must not become a 500.
    filename = file.filename or ""
    result = TagResponse(filename=filename, caption=caption, tags=tags)

    if WRITE_SIDECAR:
        # Path(...).stem drops any directory part the client sent; fall back
        # to a fixed name so an empty filename does not produce ".json".
        _write_sidecar(Path(filename).stem or "upload", result)
    return result
# -------------------- Gradio (mounted at /ui) --------------------
def _infer(image: Image.Image, top_k: int):
    """Caption *image* and derive tags for the Gradio UI.

    Follows the same caption -> tags steps as ``/upload`` but returns two plain
    strings (caption, comma-separated tags) so Gradio's output schema stays
    trivial (no JSON/dict outputs).

    Returns ``("", "")`` when no image has been provided.
    """
    if image is None:
        return "", ""
    # Gradio's Slider is typed as float and may yield None; pass an int count,
    # matching the int that /upload passes (None falls back to the slider's
    # default of 5).
    k = 5 if top_k is None else int(top_k)
    cap = tg.caption_image(image)
    tags = tg.caption_to_tags(cap, top_k=k)
    return cap, ", ".join(tags)
def _reset_ui():
    """Values for the Clear button: no image, K back to 5, empty caption and tags."""
    return None, 5, "", ""


with gr.Blocks(title="Image Tagger UI") as demo:
    gr.Markdown("### 🔍 Image → Caption → Tags\nUpload an image → BLIP generates a caption → we extract up to **K** simple tags.")
    with gr.Row():
        # Left column: inputs and actions.
        with gr.Column(scale=3):
            in_img = gr.Image(type="pil", label="Upload image", height=480)
            k = gr.Slider(minimum=1, maximum=20, value=5, step=1, label="Number of tags (K)")
            submit = gr.Button("Submit", variant="primary")
            clear = gr.Button("Clear")
        # Right column: results.
        with gr.Column(scale=2):
            out_cap = gr.Textbox(label="Generated Caption", lines=2)
            out_tags = gr.Textbox(label="Tags (comma-separated)", lines=2)
    submit.click(_infer, inputs=[in_img, k], outputs=[out_cap, out_tags])
    clear.click(_reset_ui, outputs=[in_img, k, out_cap, out_tags])
# Serve the Gradio Blocks UI from the same FastAPI application, under /ui;
# the returned app replaces the module-level `app`.
app = gr.mount_gradio_app(app=app, blocks=demo, path="/ui")
|