stephenebert committed on
Commit
cb8b918
·
verified ·
1 Parent(s): a6de95a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -27
app.py CHANGED
@@ -12,10 +12,10 @@ from fastapi.responses import HTMLResponse, JSONResponse
12
  from PIL import Image
13
  import gradio as gr
14
 
15
- # Uses your existing tagger.py (BLIP loaded there; returns top-k tags and writes sidecar)
16
  import tagger as tg
17
 
18
- # ---------- FastAPI base app ----------
19
  app = FastAPI(title="Image Tagger API")
20
 
21
  app.add_middleware(
@@ -27,31 +27,33 @@ DATA_DIR = Path("/app/data")
27
  DATA_DIR.mkdir(parents=True, exist_ok=True)
28
 
29
 
30
- def _caption_with_tagger(img: Image.Image) -> str:
31
- """Optional caption via BLIP objects already loaded in tagger.py."""
32
  try:
33
  proc = getattr(tg, "_processor")
34
  model = getattr(tg, "_model")
35
  ids = model.generate(**proc(images=img, return_tensors="pt"), max_length=30)
36
  return proc.decode(ids[0], skip_special_tokens=True)
37
  except Exception:
 
38
  return ""
39
 
40
 
41
- # ---------- API endpoints ----------
42
  @app.get("/healthz")
43
  def healthz():
44
  return {"ok": True}
45
 
 
46
  @app.get("/", response_class=HTMLResponse)
47
  def root():
 
48
  return """<!doctype html>
49
  <html>
50
  <head><meta charset="utf-8" /><title>Image Tagger API</title></head>
51
- <body style="font-family: system-ui; max-width: 720px; margin: 40px auto">
52
  <h2>Image Tagger API</h2>
53
- <p>Try the Swagger docs at <a href="/docs">/docs</a> or the UI at <a href="/ui">/ui</a>.</p>
54
- <form action="/upload" method="post" enctype="multipart/form-data" style="display:grid; gap:12px">
55
  <input type="file" name="file" accept="image/png,image/jpeg,image/webp" required />
56
  <label>top_k: <input type="number" name="top_k" value="5" min="1" max="20" /></label>
57
  <button type="submit">Upload</button>
@@ -59,6 +61,7 @@ def root():
59
  </body>
60
  </html>"""
61
 
 
62
  @app.post("/upload")
63
  async def upload_image(
64
  file: UploadFile = File(...),
@@ -68,35 +71,78 @@ async def upload_image(
68
  img = Image.open(io.BytesIO(content)).convert("RGB")
69
 
70
  stem = Path(file.filename).stem
71
- tags: List[str] = tg.tag_pil_image(img, stem, top_k=top_k) # tagger writes sidecar in /app/data (per your tagger)
72
- caption = _caption_with_tagger(img)
 
 
73
 
74
  payload = {"filename": file.filename, "caption": caption, "tags": tags}
75
 
76
- # Save a copy for convenience
77
  (DATA_DIR / f"{stem}.json").write_text(json.dumps(payload, indent=2))
78
  img.save(DATA_DIR / file.filename)
79
 
80
  return JSONResponse(payload)
81
 
82
 
83
- # ---------- Gradio UI at /ui ----------
84
- def _gr_predict(img: Image.Image, k: int):
85
  if img is None:
86
- return "", ""
87
- tags = tg.tag_pil_image(img.convert("RGB"), "ui_upload", top_k=int(k))
88
- caption = _caption_with_tagger(img)
89
- return caption, ", ".join(tags)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
 
91
- with gr.Blocks(css=".gradio-container {max-width: 860px !important}") as demo:
92
- gr.Markdown("## Image Tagger — UI")
93
  with gr.Row():
94
- in_img = gr.Image(type="pil", label="Image")
95
- k = gr.Slider(1, 20, value=5, step=1, label="Top-k tags")
96
- run = gr.Button("Tag image")
97
- out_caption = gr.Textbox(label="Caption", lines=2)
98
- out_tags = gr.Textbox(label="Tags (comma-separated)", lines=2)
99
- run.click(_gr_predict, inputs=[in_img, k], outputs=[out_caption, out_tags])
100
-
101
- # IMPORTANT: export lowercase `app` (Uvicorn expects app:app)
 
 
 
 
102
  app = gr.mount_gradio_app(app, demo, path="/ui")
 
12
  from PIL import Image
13
  import gradio as gr
14
 
15
+ # Your BLIP + tagging logic lives here
16
  import tagger as tg
17
 
18
+ # -------------------- FastAPI --------------------
19
  app = FastAPI(title="Image Tagger API")
20
 
21
  app.add_middleware(
 
27
  DATA_DIR.mkdir(parents=True, exist_ok=True)
28
 
29
 
30
+ def _caption_with_blip(img: Image.Image) -> str:
31
+ """Produce a caption using the BLIP objects that tagger.py already loaded."""
32
  try:
33
  proc = getattr(tg, "_processor")
34
  model = getattr(tg, "_model")
35
  ids = model.generate(**proc(images=img, return_tensors="pt"), max_length=30)
36
  return proc.decode(ids[0], skip_special_tokens=True)
37
  except Exception:
38
+ # If anything goes wrong, don't break the caller (UI or /upload) — just return an empty caption
39
  return ""
40
 
41
 
 
42
  @app.get("/healthz")
43
  def healthz():
44
  return {"ok": True}
45
 
46
+
47
  @app.get("/", response_class=HTMLResponse)
48
  def root():
49
+ # tiny landing page for folks who hit the root
50
  return """<!doctype html>
51
  <html>
52
  <head><meta charset="utf-8" /><title>Image Tagger API</title></head>
53
+ <body style="font-family:system-ui;max-width:760px;margin:40px auto">
54
  <h2>Image Tagger API</h2>
55
+ <p>Try the UI at <a href="/ui">/ui</a> or the API docs at <a href="/docs">/docs</a>.</p>
56
+ <form action="/upload" method="post" enctype="multipart/form-data" style="display:grid;gap:12px">
57
  <input type="file" name="file" accept="image/png,image/jpeg,image/webp" required />
58
  <label>top_k: <input type="number" name="top_k" value="5" min="1" max="20" /></label>
59
  <button type="submit">Upload</button>
 
61
  </body>
62
  </html>"""
63
 
64
+
65
  @app.post("/upload")
66
  async def upload_image(
67
  file: UploadFile = File(...),
 
71
  img = Image.open(io.BytesIO(content)).convert("RGB")
72
 
73
  stem = Path(file.filename).stem
74
+
75
+ # Use your tagger (returns top-k tags; no POS filters)
76
+ tags: List[str] = tg.tag_pil_image(img, stem, top_k=top_k)
77
+ caption = _caption_with_blip(img)
78
 
79
  payload = {"filename": file.filename, "caption": caption, "tags": tags}
80
 
81
+ # Persist a sidecar + the uploaded image (handy for debugging)
82
  (DATA_DIR / f"{stem}.json").write_text(json.dumps(payload, indent=2))
83
  img.save(DATA_DIR / file.filename)
84
 
85
  return JSONResponse(payload)
86
 
87
 
88
+ # -------------------- Gradio UI (/ui) --------------------
89
+ def _predict_ui(img: Image.Image | None, k: int):
90
  if img is None:
91
+ return "", [] # caption, table rows
92
+
93
+ img = img.convert("RGB")
94
+ caption = _caption_with_blip(img)
95
+ tags = tg.tag_pil_image(img, "ui_upload", top_k=int(k))
96
+
97
+ # Build a ranked table: [["1", "lion"], ["2", "rock"], ...]
98
+ rows = [[f"{i+1}", tag] for i, tag in enumerate(tags)]
99
+ return caption, rows
100
+
101
+
102
+ CSS = """
103
+ .gradio-container { max-width: 1200px !important; }
104
+ .caption-box textarea { font-size: 16px !important; }
105
+ .rank-table thead th { font-weight: 600; }
106
+ """
107
+
108
+ with gr.Blocks(css=CSS, theme=gr.themes.Default(primary_hue="blue")) as demo:
109
+ gr.Markdown("### 🔎 Image Tagger (BLIP ➜ Caption + Tags)")
110
+
111
+ with gr.Row(equal_height=True):
112
+ # LEFT: big image
113
+ with gr.Column(scale=7):
114
+ in_img = gr.Image(type="pil", label="Upload Image", height=540)
115
+
116
+ # RIGHT: caption + ranked list
117
+ with gr.Column(scale=5):
118
+ gr.Markdown("**BLIP Generated Caption**")
119
+ out_caption = gr.Textbox(
120
+ label=None, lines=2, elem_classes=["caption-box"], interactive=False
121
+ )
122
+
123
+ gr.Markdown("**Top-k Tags (ranked)**")
124
+ out_table = gr.Dataframe(
125
+ headers=["Rank", "Tag"],
126
+ datatype=["str", "str"],
127
+ col_count=(2, "fixed"),
128
+ row_count=(0, "dynamic"),
129
+ wrap=True,
130
+ interactive=False,
131
+ elem_classes=["rank-table"],
132
+ height=380,
133
+ )
134
 
 
 
135
  with gr.Row():
136
+ k = gr.Slider(1, 20, value=5, step=1, label="Number of Tags (k)")
137
+ with gr.Row():
138
+ clear_btn = gr.ClearButton(
139
+ [in_img, out_caption, out_table],
140
+ value="Clear",
141
+ size="sm",
142
+ )
143
+ submit_btn = gr.Button("Submit", variant="primary")
144
+
145
+ submit_btn.click(_predict_ui, inputs=[in_img, k], outputs=[out_caption, out_table])
146
+
147
+ # Export lowercase `app` for Uvicorn ("app:app")
148
  app = gr.mount_gradio_app(app, demo, path="/ui")