stephenebert committed on
Commit
3452c0c
·
verified ·
1 Parent(s): 7b6684e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +104 -68
app.py CHANGED
@@ -1,86 +1,122 @@
 
1
  from __future__ import annotations
2
 
3
- # FastAPI REST API + Gradio UI at /
4
- # Endpoints:
5
- # GET /healthz
6
- # POST /upload -> {filename, caption, tags}
7
- # UI:
8
- # / (upload image, choose top_k, see caption + tags)
9
- # Docs:
10
- # /docs
11
-
12
- from fastapi import FastAPI, File, HTTPException, Query, UploadFile
13
- from fastapi.responses import JSONResponse
14
- from pydantic import BaseModel
15
- from typing import List
16
- from pathlib import Path
17
- from PIL import Image
18
  import io
 
 
 
 
 
 
 
19
 
 
20
  import gradio as gr
21
- from tagger import tag_pil_image # returns (caption: str, tags: List[str])
22
 
23
# ASGI application object; the REST routes below are registered on it.
app = FastAPI(
    title="Image Tagger API",
    version="0.4.3",
)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
- # ---------- Pydantic model ----------
26
class TagOut(BaseModel):
    """Response schema declared for the POST /upload endpoint."""

    # Name of the uploaded file, echoed back from the request.
    filename: str
    # Caption produced for the image.
    caption: str
    # Tags produced for the image.
    tags: List[str]
30
 
31
- # ---------- Health ----------
32
@app.get("/healthz")
def healthz():
    """Health-check endpoint: always reports ``{"ok": True}``."""
    status = {"ok": True}
    return status
35
 
36
- # ---------- REST endpoint ----------
37
@app.post("/upload", response_model=TagOut)
async def upload(
    file: UploadFile = File(...),
    top_k: int = Query(5, ge=1, le=20, description="Max number of tags"),
):
    """Caption and tag an uploaded image.

    Args:
        file: Multipart upload; must be declared as PNG, JPEG, or WebP.
        top_k: Maximum number of tags to return (1-20).

    Returns:
        A JSON response with the keys ``filename``, ``caption`` and ``tags``.

    Raises:
        HTTPException: 415 for an unsupported content type, 400 when the
            bytes cannot be decoded as an image, 500 when tagging fails.
    """
    if file.content_type not in {"image/png", "image/jpeg", "image/webp"}:
        raise HTTPException(
            status_code=415, detail="Only PNG, JPEG, or WebP images are supported"
        )

    try:
        data = await file.read()
        img = Image.open(io.BytesIO(data)).convert("RGB")
    except Exception as exc:
        # Chain the cause so the decode error stays visible in server logs.
        raise HTTPException(status_code=400, detail="Could not decode image") from exc

    # The multipart filename is optional; guard so Path() never receives None.
    stem = Path(file.filename or "").stem or "upload"

    try:
        caption, tags = tag_pil_image(img, stem, top_k=top_k)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Tagging failed: {exc}") from exc

    return JSONResponse({"filename": file.filename, "caption": caption, "tags": tags})
61
-
62
- # ---------- Gradio UI at root ----------
63
def _ui_tag(image: Image.Image, top_k: int):
    """Gradio callback: return ``(caption, comma-joined tags)`` for an image.

    When no image is supplied both outputs are empty strings.
    """
    if image is not None:
        rgb_image = image.convert("RGB")
        caption, tags = tag_pil_image(rgb_image, "upload", top_k=top_k)
        joined_tags = ", ".join(tags)
        return caption, joined_tags
    return "", ""
68
-
69
# Gradio front end: an image input plus a top-k slider, rendered as a
# caption text box and a comma-separated tags text box.
demo = gr.Interface(
    title="Image Tagger",
    description="Upload an image to get a caption and top-k tags. Programmatic API at /docs.",
    fn=_ui_tag,
    inputs=[
        gr.Image(type="pil", label="Upload image"),
        gr.Slider(minimum=1, maximum=20, value=5, step=1, label="Top-k tags"),
    ],
    outputs=[
        gr.Textbox(label="Caption", lines=2),
        gr.Textbox(label="Tags (comma-separated)", lines=2),
    ],
    flagging_mode="never",
)

# Serve the UI from the same FastAPI app at the root path (/) to avoid redirects.
app = gr.mount_gradio_app(app=app, blocks=demo, path="/")
86
 
 
1
+ # app.py
2
  from __future__ import annotations
3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  import io
5
+ import json
6
+ from pathlib import Path
7
+ from typing import List
8
+
9
+ from fastapi import FastAPI, UploadFile, File, Query
10
+ from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse
11
+ from fastapi.middleware.cors import CORSMiddleware
12
 
13
+ from PIL import Image
14
  import gradio as gr
 
15
 
16
+ # your tagger module (already loads BLIP and returns top-k tags + writes sidecar)
17
+ import tagger as tg
18
+
19
# ASGI application object; the REST routes and the Gradio UI are attached to it.
APP = FastAPI(title="Image Tagger API")

# Open CORS policy for any origin, method and header.
# Credentials stay disabled: wildcard origins must not be combined with
# allow_credentials=True, otherwise any site could issue credentialed
# cross-origin requests against this API.
APP.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=False,
)

# Directory where /upload stores the JSON payload and a copy of each image.
# Created at import time so request handlers can write to it directly.
DATA_DIR = Path("/app/data")
DATA_DIR.mkdir(parents=True, exist_ok=True)
27
+
28
+
29
+ # ---------- Helpers ----------
30
def _caption_with_tagger(img: Image.Image) -> str:
    """Generate a caption with the model objects held by the ``tagger`` module.

    Reads ``tg._processor`` and ``tg._model`` (the BLIP objects, per the
    original note), generates up to 30 tokens and decodes the first sequence.

    Captioning is best-effort: only the tags are required, so any failure is
    logged and an empty string is returned instead of raising.

    Args:
        img: Image to caption.

    Returns:
        The decoded caption, or ``""`` when captioning fails.
    """
    import logging  # local import keeps this helper self-contained

    try:
        proc = tg._processor
        model = tg._model
        inputs = proc(images=img, return_tensors="pt")
        ids = model.generate(**inputs, max_length=30)
        return proc.decode(ids[0], skip_special_tokens=True)
    except Exception:
        # Deliberately broad: the caption is optional, but leave a trace so a
        # broken model or processor does not go unnoticed.
        logging.getLogger(__name__).exception(
            "Caption generation failed; returning an empty caption"
        )
        return ""
40
 
 
 
 
 
 
41
 
42
+ # ---------- FastAPI endpoints ----------
43
@APP.get("/healthz")
def healthz():
    """Health-check endpoint: always reports ``{"ok": True}``."""
    status = {"ok": True}
    return status
46
 
47
@APP.get("/", response_class=HTMLResponse)
def root():
    """Serve a static landing page with links to /docs and /ui and an upload form.

    The page is a plain string (no schema generation). Because /upload
    declares ``top_k`` as a query parameter, the form rewrites its own action
    on submit so the chosen value travels in the URL; a value posted only in
    the multipart body would not be read by the endpoint.
    """
    return """<!doctype html>
<html>
<head><meta charset="utf-8" /><title>Image Tagger API</title></head>
<body style="font-family: system-ui; max-width: 720px; margin: 40px auto">
<h2>Image Tagger API</h2>
<p>Upload via <a href="/docs">/docs</a> or try the UI at <a href="/ui">/ui</a>.</p>
<form action="/upload" method="post" enctype="multipart/form-data" style="display:grid; gap:12px"
      onsubmit="this.action = '/upload?top_k=' + encodeURIComponent(this.elements['top_k'].value)">
<input type="file" name="file" accept="image/png,image/jpeg,image/webp" required />
<label>top_k: <input type="number" name="top_k" value="5" min="1" max="20" required /></label>
<button type="submit">Upload</button>
</form>
</body>
</html>"""
63
+
64
@APP.post("/upload")
async def upload_image(
    file: UploadFile = File(...),
    top_k: int = Query(5, ge=1, le=20),
):
    """Tag (and optionally caption) an uploaded image and store a copy.

    Args:
        file: Multipart image upload.
        top_k: Maximum number of tags to return (1-20), read from the query string.

    Returns:
        A JSON response with the keys ``filename``, ``caption`` and ``tags``.

    Raises:
        HTTPException: 400 when the upload cannot be decoded as an image.

    Side effects:
        Writes ``<stem>.json`` and a copy of the image under ``DATA_DIR``.
    """
    from fastapi import HTTPException  # not part of the module-level imports

    content = await file.read()
    try:
        img = Image.open(io.BytesIO(content)).convert("RGB")
    except Exception as exc:
        # A bad upload is a client error, not a server crash.
        raise HTTPException(status_code=400, detail="Could not decode image") from exc

    # The filename is client-controlled: keep only its last path component so
    # "../x.png" or an absolute path can never escape DATA_DIR. It may also be
    # missing entirely, in which case a fixed fallback name is used.
    safe_name = Path(file.filename or "").name
    if safe_name in {"", ".", ".."}:
        safe_name = "upload"
    stem = Path(safe_name).stem

    # Get tags (per the original note, the tagger also writes its own JSON sidecar).
    tags: List[str] = tg.tag_pil_image(img, stem, top_k=top_k)

    # Optional caption; it does not affect the tags.
    caption = _caption_with_tagger(img)

    payload = {
        "filename": file.filename,
        "caption": caption,
        "tags": tags,
    }

    # Also store a copy under DATA_DIR for convenience.
    (DATA_DIR / f"{stem}.json").write_text(json.dumps(payload, indent=2))
    target = DATA_DIR / safe_name
    try:
        img.save(target)
    except ValueError:
        # Pillow could not infer an output format from the extension;
        # keep the uploaded bytes as they were received instead.
        target.write_bytes(content)

    return JSONResponse(payload)
91
+
92
+
93
+ # ---------- Gradio UI (mounted at /ui) ----------
94
def _gr_predict(img: Image.Image, k: int):
    """Gradio callback: tag and caption an image.

    Args:
        img: Image from the UI, or ``None`` when nothing was uploaded.
        k: Slider value for the number of tags; coerced to ``int``.

    Returns:
        A ``(caption, comma-joined tags, pretty-printed JSON payload)`` tuple;
        ``("", "", "{}")`` when no image was supplied.
    """
    if img is None:
        return "", "", "{}"

    # Convert once and feed the same RGB image to both models, as /upload does;
    # previously the caption path received the unconverted image.
    rgb = img.convert("RGB")
    tags = tg.tag_pil_image(rgb, "ui_upload", top_k=int(k))
    caption = _caption_with_tagger(rgb)

    payload = {"filename": "ui_upload", "caption": caption, "tags": tags}
    return caption, ", ".join(tags), json.dumps(payload, indent=2)
101
+
102
# Gradio UI: image + top-k slider in, caption / tags / raw JSON out.
# NOTE(review): the source indentation was lost; `out_json` is assumed to sit
# in the same row as `out_tags` — confirm against the intended layout.
with gr.Blocks(css="""
.gradio-container {max-width: 860px !important}
""") as demo:
    gr.Markdown("## Image Tagger — UI\nUpload an image, choose `top_k`, and get tags.")

    # Inputs
    with gr.Row():
        in_img = gr.Image(type="pil", label="Image")
        k = gr.Slider(minimum=1, maximum=20, value=5, step=1, label="Top-k tags")

    run = gr.Button(value="Tag Image")

    # Outputs
    with gr.Row():
        out_caption = gr.Textbox(label="Caption", lines=2)
    with gr.Row():
        out_tags = gr.Textbox(label="Tags (comma-separated)", lines=2)
        out_json = gr.Textbox(label="Raw JSON", lines=10)

    # Wire the button to the prediction callback.
    run.click(
        fn=_gr_predict,
        inputs=[in_img, k],
        outputs=[out_caption, out_tags, out_json],
    )

# Serve the UI from the same FastAPI app under /ui.
# NOTE(review): the ASGI object is named APP — confirm the server launch
# command references it under that name.
APP = gr.mount_gradio_app(app=APP, blocks=demo, path="/ui")
 
122