stephenebert committed on
Commit
d3e85eb
·
verified ·
1 Parent(s): 0e857c8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +90 -106
app.py CHANGED
@@ -1,42 +1,36 @@
1
- # app.py
2
  from __future__ import annotations
3
 
4
  import io
5
- import json
6
  from pathlib import Path
7
  from typing import List
8
 
9
- from fastapi import FastAPI, UploadFile, File, Query
10
- from fastapi.middleware.cors import CORSMiddleware
11
- from fastapi.responses import HTMLResponse, JSONResponse
12
- from PIL import Image
13
  import gradio as gr
 
 
 
 
14
 
15
- # Your BLIP + tagging logic lives here
16
  import tagger as tg
17
 
18
  # -------------------- FastAPI --------------------
19
# Application object that every route below is registered on.
app = FastAPI(title="Image Tagger API")

# CORS: any origin may call the API, with any method and any header.
# Credentials are NOT allowed: a wildcard origin combined with
# allow_credentials=True would let any website issue credentialed
# cross-site requests, and the CORS spec does not permit answering a
# credentialed request with "Access-Control-Allow-Origin: *".
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=False,
)

# Directory that receives uploaded images and their JSON sidecars.
# It is created (with parents) as a side effect of importing this module.
DATA_DIR = Path("/app/data")
DATA_DIR.mkdir(parents=True, exist_ok=True)
28
 
 
 
 
 
29
 
30
def _caption_with_blip(img: Image.Image) -> str:
    """Caption *img* with the BLIP objects held by the ``tagger`` module.

    Reads ``tg._processor`` and ``tg._model``, runs ``generate`` with
    ``max_length=30`` and decodes the first returned sequence with special
    tokens skipped.

    Args:
        img: The image to caption.

    Returns:
        The decoded caption, or ``""`` if any step raises. The failure is
        logged so a broken model does not go unnoticed.
    """
    import logging

    try:
        processor = tg._processor
        model = tg._model
        inputs = processor(images=img, return_tensors="pt")
        token_ids = model.generate(**inputs, max_length=30)
        return processor.decode(token_ids[0], skip_special_tokens=True)
    except Exception:
        # Best effort: a captioning failure must not break the UI or the API,
        # so fall back to an empty caption after recording what went wrong.
        logging.getLogger(__name__).exception("BLIP captioning failed")
        return ""
40
 
41
 
42
  @app.get("/healthz")
@@ -46,103 +40,93 @@ def healthz():
46
 
47
@app.get("/", response_class=HTMLResponse)
def root():
    """Serve a tiny landing page with links to /ui and /docs and an upload form.

    The upload endpoint declares ``top_k`` as a query parameter, while an HTML
    form with ``method="post"`` sends its fields in the request body. The
    form's ``onsubmit`` handler therefore copies the chosen value into the
    action URL's query string; without it the number typed into the form would
    be ignored and the endpoint default would always apply.
    """
    return """<!doctype html>
<html>
<head><meta charset="utf-8" /><title>Image Tagger API</title></head>
<body style="font-family:system-ui;max-width:760px;margin:40px auto">
<h2>Image Tagger API</h2>
<p>Try the UI at <a href="/ui">/ui</a> or the API docs at <a href="/docs">/docs</a>.</p>
<form action="/upload" method="post" enctype="multipart/form-data" style="display:grid;gap:12px" onsubmit="this.action='/upload?top_k='+encodeURIComponent(this.elements.top_k.value)">
<input type="file" name="file" accept="image/png,image/jpeg,image/webp" required />
<label>top_k: <input type="number" name="top_k" value="5" min="1" max="20" /></label>
<button type="submit">Upload</button>
</form>
</body>
</html>"""
63
 
64
 
65
@app.post("/upload")
async def upload_image(
    file: UploadFile = File(...),
    top_k: int = Query(5, ge=1, le=20),
):
    """Tag and caption an uploaded image, then return the result as JSON.

    Args:
        file: The uploaded image (multipart form field ``file``).
        top_k: Number of tags to request from the tagger (1-20, query string).

    Returns:
        A ``JSONResponse`` with keys ``filename``, ``caption`` and ``tags``.

    Raises:
        HTTPException: 400 when the upload cannot be decoded as an image.
    """
    import logging

    from fastapi import HTTPException

    content = await file.read()
    try:
        img = Image.open(io.BytesIO(content)).convert("RGB")
    except Exception as exc:
        # Undecodable uploads are a client error, not a server crash.
        raise HTTPException(status_code=400, detail=f"Invalid image: {exc}") from exc

    # The client controls the filename: keep only its final path component so
    # a name such as "../../x.png" cannot escape DATA_DIR, and fall back to a
    # fixed name when it is missing or degenerate.
    safe_name = Path(file.filename or "upload").name
    if safe_name in {"", ".", ".."}:
        safe_name = "upload"
    stem = Path(safe_name).stem

    # Tagger returns the top-k tags (no POS filters).
    tags: List[str] = tg.tag_pil_image(img, stem, top_k=top_k)
    caption = _caption_with_blip(img)

    payload = {"filename": file.filename, "caption": caption, "tags": tags}

    # Persist a sidecar plus the uploaded image. This is a debugging aid only,
    # so a failure here (read-only disk, unknown image extension, ...) is
    # logged instead of failing a request that was otherwise successful.
    try:
        (DATA_DIR / f"{stem}.json").write_text(json.dumps(payload, indent=2))
        img.save(DATA_DIR / safe_name)
    except (OSError, ValueError):
        logging.getLogger(__name__).warning(
            "Could not persist upload %r", safe_name, exc_info=True
        )

    return JSONResponse(payload)
86
-
87
-
88
- # -------------------- Gradio UI (/ui) --------------------
89
def _predict_ui(img: Image.Image | None, k: int):
    """Gradio callback: caption the image and build the ranked tag table.

    Args:
        img: Image from the UI, or ``None`` when nothing was uploaded.
        k: Number of tags requested; coerced with ``int`` before use.

    Returns:
        ``(caption, rows)`` where ``rows`` is a list of ``[rank, tag]`` pairs
        with 1-based ranks rendered as strings, or ``("", [])`` for no image.
    """
    if img is None:
        return "", []

    rgb = img.convert("RGB")
    caption = _caption_with_blip(rgb)
    ranked = tg.tag_pil_image(rgb, "ui_upload", top_k=int(k))

    # One row per tag, e.g. [["1", "lion"], ["2", "rock"], ...]
    table = [[str(rank), tag] for rank, tag in enumerate(ranked, start=1)]
    return caption, table
 
 
 
 
 
 
 
 
 
 
 
100
 
 
101
 
102
# Extra stylesheet handed to gr.Blocks; the selectors match the elem_classes
# assigned to the caption textbox and the tag table below.
CSS = """
.gradio-container { max-width: 1200px !important; }
.caption-box textarea { font-size: 16px !important; }
.rank-table thead th { font-weight: 600; }
"""

with gr.Blocks(theme=gr.themes.Default(primary_hue="blue"), css=CSS) as demo:
    gr.Markdown("### 🔎 Image Tagger (BLIP ➜ Caption + Tags)")

    with gr.Row(equal_height=True):
        # Left-hand side: the large image input.
        with gr.Column(scale=7):
            in_img = gr.Image(label="Upload Image", type="pil", height=540)

        # Right-hand side: caption on top, ranked tags underneath.
        with gr.Column(scale=5):
            gr.Markdown("**BLIP Generated Caption**")
            out_caption = gr.Textbox(
                label=None,
                lines=2,
                interactive=False,
                elem_classes=["caption-box"],
            )

            gr.Markdown("**Top-k Tags (ranked)**")
            out_table = gr.Dataframe(
                headers=["Rank", "Tag"],
                datatype=["str", "str"],
                col_count=(2, "fixed"),
                row_count=(0, "dynamic"),
                wrap=True,
                interactive=False,
                elem_classes=["rank-table"],
                height=380,
            )

    with gr.Row():
        k = gr.Slider(minimum=1, maximum=20, step=1, value=5, label="Number of Tags (k)")

    with gr.Row():
        clear_btn = gr.ClearButton(
            components=[in_img, out_caption, out_table],
            value="Clear",
            size="sm",
        )
        submit_btn = gr.Button(value="Submit", variant="primary")

    # Submit runs the tagger and fills both outputs.
    submit_btn.click(fn=_predict_ui, inputs=[in_img, k], outputs=[out_caption, out_table])

# Re-bind lowercase `app` to the combined app so Uvicorn can serve "app:app"
# with the Gradio UI available under /ui.
app = gr.mount_gradio_app(app, demo, path="/ui")
 
 
1
  from __future__ import annotations
2
 
3
  import io
4
+ import os
5
  from pathlib import Path
6
  from typing import List
7
 
 
 
 
 
8
  import gradio as gr
9
+ from fastapi import FastAPI, File, HTTPException, Query, UploadFile
10
+ from fastapi.responses import HTMLResponse
11
+ from pydantic import BaseModel, Field
12
+ from PIL import Image
13
 
 
14
  import tagger as tg
15
 
16
  # -------------------- FastAPI --------------------
17
# Application object that every route below is registered on; title, version
# and description are passed straight to the FastAPI constructor.
app = FastAPI(
    title="Image Tagger API",
    version="1.0.0",
    description=(
        "Generate a caption with BLIP, "
        "then return top-K tags derived from that caption."
    ),
)
22
 
23
# Whether a JSON sidecar is written for each upload. Enabled by default and
# whenever WRITE_SIDECAR holds anything other than a recognised "off" value.
# Besides "0", the common spellings false/no/off (any case, surrounding
# whitespace ignored) also disable it, so WRITE_SIDECAR=false behaves as
# an operator would expect.
WRITE_SIDECAR = os.getenv("WRITE_SIDECAR", "1").strip().lower() not in {
    "0",
    "false",
    "no",
    "off",
}
 
24
 
25
class TagResponse(BaseModel):
    """Schema of the tagging result: file name, caption and tag list.

    All three fields are required (``...`` default). Each carries an
    ``examples`` value alongside its type.
    """

    # Name of the file the result belongs to.
    filename: str = Field(
        ...,
        examples=["photo.jpg"],
    )
    # Caption text generated for the image.
    caption: str = Field(
        ...,
        examples=["a lion rests on a rock in the wild"],
    )
    # Tags as a list of plain strings.
    tags: List[str] = Field(
        ...,
        examples=[["lion", "rests", "rock", "wild"]],
    )
29
 
30
+
31
@app.on_event("startup")
def _load_once() -> None:
    """Startup hook: ask the tagger module to initialise its models.

    Registered for the application's "startup" event, so it runs when the
    app starts rather than on a request.
    """
    tg.init_models()
 
 
 
 
 
 
34
 
35
 
36
  @app.get("/healthz")
 
40
 
41
@app.get("/", response_class=HTMLResponse)
def root():
    """Serve the landing page: links to /docs and /ui plus a quick-upload form.

    The upload endpoint declares ``top_k`` as a query parameter, while an HTML
    form with ``method="post"`` sends its fields in the request body. The
    form's ``onsubmit`` handler therefore copies the chosen value into the
    action URL's query string; without it the "Top K tags" input would be
    ignored and the endpoint default would always apply.
    """
    return """
<!doctype html>
<html>
<head>
<meta charset="utf-8" />
<title>Image Tagger API</title>
<style>
body{font-family: system-ui, -apple-system, Segoe UI, Roboto, Ubuntu, sans-serif; max-width: 820px; margin: 48px auto; padding: 0 16px;}
.card{border:1px solid #e5e7eb; border-radius:12px; padding:20px;}
.btn{background:#111; color:#fff; padding:.6rem 1rem; border-radius:10px; text-decoration:none;}
.btn:focus,.btn:hover{opacity:.9}
input[type=number]{width:80px;}
</style>
</head>
<body>
<h2>🖼️ Image Tagger API</h2>
<p>Use <a href="/docs">/docs</a> for Swagger or try the simple UI at <a class="btn" href="/ui">/ui</a>.</p>
<div class="card">
<h3>Quick upload</h3>
<form action="/upload" method="post" enctype="multipart/form-data" onsubmit="this.action='/upload?top_k='+encodeURIComponent(this.elements.top_k.value)">
<p><input type="file" name="file" accept="image/png,image/jpeg,image/webp" required></p>
<p>Top K tags: <input type="number" name="top_k" min="1" max="20" value="5"></p>
<p><button class="btn" type="submit">Upload</button></p>
</form>
</div>
</body>
</html>"""
70
 
71
 
72
@app.post("/upload", response_model=TagResponse)
async def upload_image(
    file: UploadFile = File(...),
    top_k: int = Query(5, ge=1, le=20, description="How many tags to return"),
):
    """Caption an uploaded image and return tags derived from that caption.

    Args:
        file: The uploaded image (multipart form field ``file``).
        top_k: Number of tags to request (1-20, read from the query string).

    Returns:
        A ``TagResponse`` holding the file name, the caption and the tags.

    Raises:
        HTTPException: 400 when the upload cannot be read or decoded as an image.
    """
    import logging

    try:
        content = await file.read()
        img = Image.open(io.BytesIO(content)).convert("RGB")
    except Exception as e:
        # Keep the original error as the cause for server-side debugging.
        raise HTTPException(status_code=400, detail=f"Invalid image: {e}") from e

    # A multipart part may carry no filename; fall back to a fixed name so
    # neither Path() nor the `filename: str` response field fails on None.
    filename = file.filename or "upload"

    # Caption first, then derive the tags from that caption.
    caption = tg.caption_image(img)
    tags = tg.caption_to_tags(caption, top_k=top_k)
    result = TagResponse(filename=filename, caption=caption, tags=tags)

    # Optional sidecar with the same content as the JSON response.
    if WRITE_SIDECAR:
        # .stem has no directory part, so the sidecar always lands inside
        # data_dir regardless of what the client sent as the filename.
        stem = Path(filename).stem or "upload"
        data_dir = Path(os.getenv("DATA_DIR", "/app/data"))
        try:
            data_dir.mkdir(parents=True, exist_ok=True)
            (data_dir / f"{stem}.json").write_text(
                result.model_dump_json(indent=2), encoding="utf-8"
            )
        except Exception:
            # Best effort: never fail the request over the sidecar, but leave
            # a trace instead of swallowing the error silently.
            logging.getLogger(__name__).warning(
                "Could not write sidecar for %r", filename, exc_info=True
            )

    return result
102
 
 
 
 
 
 
103
 
104
+ # -------------------- Gradio (mounted at /ui) --------------------
105
def _infer(image: Image.Image | None, top_k: int):
    """Gradio callback mirroring the API logic, returning plain strings.

    Returns simple types so the schema stays trivial for Gradio (avoids
    JSON/dict outputs).

    Args:
        image: Image from the UI, or ``None`` when nothing was uploaded.
        top_k: Number of tags requested; coerced with ``int`` because the
            value comes from a UI widget rather than a validated parameter.

    Returns:
        ``(caption, tags)`` with the tags joined by ``", "``, or ``("", "")``
        when no image was supplied.
    """
    if image is None:
        return "", ""
    # Normalise to RGB exactly as the /upload endpoint does, so both entry
    # points hand the captioner the same kind of image.
    cap = tg.caption_image(image.convert("RGB"))
    tags = tg.caption_to_tags(cap, top_k=int(top_k))
    return cap, ", ".join(tags)
113
 
 
 
 
 
114
 
115
with gr.Blocks(title="Image Tagger UI") as demo:
    gr.Markdown(
        "### 🔍 Image → Caption → Tags\n"
        "Upload an image → BLIP generates a caption → we extract up to **K** simple tags."
    )

    with gr.Row():
        # Inputs and action buttons.
        with gr.Column(scale=3):
            in_img = gr.Image(label="Upload image", type="pil", height=480)
            k = gr.Slider(minimum=1, maximum=20, step=1, value=5, label="Number of tags (K)")
            submit = gr.Button(value="Submit", variant="primary")
            clear = gr.Button(value="Clear")
        # Read-outs for the caption and the tag string.
        with gr.Column(scale=2):
            out_cap = gr.Textbox(label="Generated Caption", lines=2)
            out_tags = gr.Textbox(label="Tags (comma-separated)", lines=2)

    # Submit runs inference; Clear resets image, slider (back to 5) and both text boxes.
    submit.click(fn=_infer, inputs=[in_img, k], outputs=[out_cap, out_tags])
    clear.click(fn=lambda: (None, 5, "", ""), outputs=[in_img, k, out_cap, out_tags])

# Mount the Gradio app under FastAPI at /ui and re-bind `app` to the result.
app = gr.mount_gradio_app(app, demo, path="/ui")