File size: 4,659 Bytes
ae57f86
7d79380
e7d0635
d3e85eb
3452c0c
 
 
7d79380
d3e85eb
 
 
 
cd3cf65
3452c0c
 
cb8b918
d3e85eb
 
 
 
3452c0c
 
d3e85eb
3452c0c
d3e85eb
 
 
 
3452c0c
d3e85eb
 
 
 
cd3cf65
 
a6de95a
cd3cf65
 
 
cb8b918
a6de95a
3452c0c
d3e85eb
 
3452c0c
d3e85eb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3452c0c
 
 
cb8b918
d3e85eb
3452c0c
ae57f86
d3e85eb
cd3cf65
d3e85eb
 
 
 
 
cb8b918
d3e85eb
 
cb8b918
d3e85eb
 
 
 
 
 
 
 
 
 
 
 
 
 
cb8b918
d3e85eb
cb8b918
 
d3e85eb
 
 
 
 
 
 
 
 
cb8b918
 
d3e85eb
 
cb8b918
 
d3e85eb
 
 
 
 
 
 
 
 
 
 
 
 
a6de95a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
from __future__ import annotations

import io
import logging
import os
from pathlib import Path
from typing import List, Optional

import gradio as gr
from fastapi import FastAPI, File, Form, HTTPException, Query, UploadFile
from fastapi.responses import HTMLResponse
from pydantic import BaseModel, Field
from PIL import Image

import tagger as tg

# -------------------- FastAPI --------------------
# OpenAPI metadata shown on /docs; kept in one mapping so it is easy to edit.
_API_METADATA = {
    "title": "Image Tagger API",
    "version": "1.0.0",
    "description": "Generate a caption with BLIP, then return top-K tags derived from that caption.",
}
app = FastAPI(**_API_METADATA)

def _env_flag(name: str, *, default: bool = True) -> bool:
    """Read a boolean feature flag from the environment variable *name*.

    "0", "false", "no" and "off" (case-insensitive, surrounding whitespace
    ignored) disable the flag; any other non-blank value enables it.  An unset
    or blank variable yields *default*.
    """
    raw = os.getenv(name)
    if raw is None or not raw.strip():
        return default
    return raw.strip().lower() not in {"0", "false", "no", "off"}


# Whether /upload also writes a <DATA_DIR>/<stem>.json sidecar; on unless disabled.
WRITE_SIDECAR = _env_flag("WRITE_SIDECAR", default=True)

class TagResponse(BaseModel):
    # Response body of POST /upload. The optional sidecar file written by the
    # upload handler is this same model serialised to JSON.
    # (Comments rather than a docstring: a docstring would become the model's
    # OpenAPI description.)
    filename: str = Field(..., examples=["photo.jpg"])  # uploaded file's name as sent by the client
    caption: str = Field(..., examples=["a lion rests on a rock in the wild"])  # caption from tg.caption_image
    tags: List[str] = Field(..., examples=[["lion", "rests", "rock", "wild"]])  # result of tg.caption_to_tags(caption, top_k=...)


@app.on_event("startup")
def _load_once() -> None:
    """Call ``tg.init_models()`` once, when the application starts up."""
    init_models = tg.init_models
    init_models()


@app.get("/healthz")
def healthz():
    """Health-check endpoint; always answers ``{"ok": True}`` (no dependency checks)."""
    status = dict(ok=True)
    return status


# Static landing page served at "/": links to Swagger (/docs) and the Gradio
# UI (/ui), plus a plain multipart form that posts straight to /upload.
_LANDING_HTML = """
<!doctype html>
<html>
  <head>
    <meta charset="utf-8" />
    <title>Image Tagger API</title>
    <style>
      body{font-family: system-ui, -apple-system, Segoe UI, Roboto, Ubuntu, sans-serif; max-width: 820px; margin: 48px auto; padding: 0 16px;}
      .card{border:1px solid #e5e7eb; border-radius:12px; padding:20px;}
      .btn{background:#111; color:#fff; padding:.6rem 1rem; border-radius:10px; text-decoration:none;}
      .btn:focus,.btn:hover{opacity:.9}
      input[type=number]{width:80px;}
    </style>
  </head>
  <body>
    <h2>🖼️ Image Tagger API</h2>
    <p>Use <a href="/docs">/docs</a> for Swagger or try the simple UI at <a class="btn" href="/ui">/ui</a>.</p>
    <div class="card">
      <h3>Quick upload</h3>
      <form action="/upload" method="post" enctype="multipart/form-data">
        <p><input type="file" name="file" accept="image/png,image/jpeg,image/webp" required></p>
        <p>Top K tags: <input type="number" name="top_k" min="1" max="20" value="5"></p>
        <p><button class="btn" type="submit">Upload</button></p>
      </form>
    </div>
  </body>
</html>"""


@app.get("/", response_class=HTMLResponse)
def root():
    """Return the static landing page HTML."""
    return _LANDING_HTML


_logger = logging.getLogger(__name__)


def _decode_image(content: bytes) -> Image.Image:
    """Decode raw upload bytes into an RGB PIL image.

    Raises whatever PIL raises for unreadable or unsupported data; the caller
    maps any failure to HTTP 400.
    """
    with Image.open(io.BytesIO(content)) as src:
        # convert() returns a new image with its pixels loaded, so closing
        # the source image afterwards is safe.
        return src.convert("RGB")


def _caption_and_tags(img: Image.Image, top_k: int) -> tuple[str, List[str]]:
    """Caption *img* with tg.caption_image, then derive tags from that caption."""
    caption = tg.caption_image(img)
    # The original comment warned that the tagger must return ONLY the tag list.
    return caption, tg.caption_to_tags(caption, top_k=top_k)


def _write_sidecar(stem: str, response: TagResponse) -> None:
    """Best-effort: write *response* as JSON to ``<DATA_DIR>/<stem>.json``.

    DATA_DIR defaults to ``/app/data``. Filesystem errors are logged and
    swallowed so they never fail the request.
    """
    data_dir = Path(os.getenv("DATA_DIR", "/app/data"))
    try:
        data_dir.mkdir(parents=True, exist_ok=True)
        (data_dir / f"{stem}.json").write_text(
            response.model_dump_json(indent=2), encoding="utf-8"
        )
    except OSError:
        _logger.warning(
            "Could not write sidecar %r in %s", stem, data_dir, exc_info=True
        )


@app.post("/upload", response_model=TagResponse)
async def upload_image(
    file: UploadFile = File(...),
    top_k: int = Query(5, ge=1, le=20, description="How many tags to return"),
    top_k_form: Optional[int] = Form(
        None,
        alias="top_k",
        ge=1,
        le=20,
        description="Same as the top_k query parameter, sent as a multipart "
        "form field (as the HTML form on / does); takes precedence when present.",
    ),
):
    """Caption an uploaded image and return the caption plus top-K tags.

    ``top_k`` may be given as a query parameter or as a ``top_k`` form field.
    The form field wins when both are sent.

    Raises:
        HTTPException(400): the upload could not be read or decoded as an image.
    """
    # Imported locally to keep this block self-contained.
    from fastapi.concurrency import run_in_threadpool

    try:
        content = await file.read()
        # Decoding is CPU-bound; keep it off the event loop.
        img = await run_in_threadpool(_decode_image, content)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Invalid image: {e}") from e

    k = top_k_form if top_k_form is not None else top_k
    # Model inference is synchronous and slow. Running it inline in this
    # coroutine would stall every other request, including /healthz.
    caption, tags = await run_in_threadpool(_caption_and_tags, img, k)

    # UploadFile.filename can be None; use "" so validation and Path() don't fail.
    filename = file.filename or ""
    result = TagResponse(filename=filename, caption=caption, tags=tags)

    # Optional sidecar (same content shape as the JSON response).
    if WRITE_SIDECAR:
        _write_sidecar(Path(filename).stem or "upload", result)

    return result


# -------------------- Gradio (mounted at /ui) --------------------
def _infer(image: Image.Image, top_k: int):
    """Gradio callback: caption *image*, then extract up to *top_k* tags.

    Uses the same tg.caption_image / tg.caption_to_tags pipeline as the API.
    It returns plain strings, which keeps the Gradio output schema trivial
    (no JSON/dict outputs).

    Returns:
        (caption, comma-separated tags), or ("", "") when no image is provided.
    """
    if image is None:
        return "", ""
    cap = tg.caption_image(image)
    # gr.Slider passes its value as a float (e.g. 5.0). Coerce it so the
    # tagger gets the same int type the FastAPI endpoint passes.
    tags = tg.caption_to_tags(cap, top_k=int(top_k))
    return cap, ", ".join(tags)


# Gradio UI: an image and a K slider in; a caption and comma-separated tags out.
with gr.Blocks(title="Image Tagger UI") as demo:
    gr.Markdown("### 🔍 Image → Caption → Tags\nUpload an image → BLIP generates a caption → we extract up to **K** simple tags.")

    with gr.Row():
        # Left column: inputs and action buttons.
        with gr.Column(scale=3):
            in_img = gr.Image(type="pil", label="Upload image", height=480)  # type="pil": _infer receives a PIL image
            # Same 1..20 range and default of 5 as the /upload endpoint's top_k.
            k = gr.Slider(1, 20, value=5, step=1, label="Number of tags (K)")
            submit = gr.Button("Submit", variant="primary")
            clear = gr.Button("Clear")
        # Right column: outputs filled by _infer.
        with gr.Column(scale=2):
            out_cap = gr.Textbox(label="Generated Caption", lines=2)
            out_tags = gr.Textbox(label="Tags (comma-separated)", lines=2)

    submit.click(_infer, inputs=[in_img, k], outputs=[out_cap, out_tags])
    # Reset every component: no image, K back to 5, empty output boxes.
    clear.click(lambda: (None, 5, "", ""), outputs=[in_img, k, out_cap, out_tags])

# mount Gradio under FastAPI
# The result is rebound to `app`, so the module-level name now refers to the
# app that also serves the Blocks UI at /ui.
app = gr.mount_gradio_app(app, demo, path="/ui")