stephenebert commited on
Commit
a6de95a
·
verified ·
1 Parent(s): 0b1423b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -45
app.py CHANGED
@@ -7,17 +7,18 @@ from pathlib import Path
7
  from typing import List
8
 
9
  from fastapi import FastAPI, UploadFile, File, Query
10
- from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse
11
  from fastapi.middleware.cors import CORSMiddleware
12
-
13
  from PIL import Image
14
  import gradio as gr
15
 
16
- # your tagger module (already loads BLIP and returns top-k tags + writes sidecar)
17
  import tagger as tg
18
 
19
- APP = FastAPI(title="Image Tagger API")
20
- APP.add_middleware(
 
 
21
  CORSMiddleware,
22
  allow_origins=["*"], allow_methods=["*"], allow_headers=["*"], allow_credentials=True
23
  )
@@ -26,33 +27,30 @@ DATA_DIR = Path("/app/data")
26
  DATA_DIR.mkdir(parents=True, exist_ok=True)
27
 
28
 
29
- # ---------- Helpers ----------
30
  def _caption_with_tagger(img: Image.Image) -> str:
31
- """Use the BLIP model objects loaded by tagger.py to generate a caption."""
32
  try:
33
  proc = getattr(tg, "_processor")
34
  model = getattr(tg, "_model")
35
  ids = model.generate(**proc(images=img, return_tensors="pt"), max_length=30)
36
  return proc.decode(ids[0], skip_special_tokens=True)
37
  except Exception:
38
- # Caption is optional; only tags are required per your latest request.
39
  return ""
40
 
41
 
42
- # ---------- FastAPI endpoints ----------
43
- @APP.get("/healthz")
44
  def healthz():
45
  return {"ok": True}
46
 
47
- @APP.get("/", response_class=HTMLResponse)
48
  def root():
49
- # Keep this simple (no schema generation). Link to /docs and /ui.
50
  return """<!doctype html>
51
  <html>
52
  <head><meta charset="utf-8" /><title>Image Tagger API</title></head>
53
  <body style="font-family: system-ui; max-width: 720px; margin: 40px auto">
54
  <h2>Image Tagger API</h2>
55
- <p>Upload via <a href="/docs">/docs</a> or try the UI at <a href="/ui">/ui</a>.</p>
56
  <form action="/upload" method="post" enctype="multipart/form-data" style="display:grid; gap:12px">
57
  <input type="file" name="file" accept="image/png,image/jpeg,image/webp" required />
58
  <label>top_k: <input type="number" name="top_k" value="5" min="1" max="20" /></label>
@@ -61,61 +59,44 @@ def root():
61
  </body>
62
  </html>"""
63
 
64
- @APP.post("/upload")
65
  async def upload_image(
66
  file: UploadFile = File(...),
67
  top_k: int = Query(5, ge=1, le=20),
68
  ):
69
- # Read file into PIL
70
  content = await file.read()
71
  img = Image.open(io.BytesIO(content)).convert("RGB")
72
 
73
  stem = Path(file.filename).stem
74
- # Get tags (tagger will also write a JSON sidecar under tg.CAP_TAG_DIR)
75
- tags: List[str] = tg.tag_pil_image(img, stem, top_k=top_k)
76
-
77
- # Optional caption (doesn't affect tags)
78
  caption = _caption_with_tagger(img)
79
 
80
- payload = {
81
- "filename": file.filename,
82
- "caption": caption,
83
- "tags": tags,
84
- }
85
 
86
- # Also store a copy under /app/data for convenience
87
  (DATA_DIR / f"{stem}.json").write_text(json.dumps(payload, indent=2))
88
  img.save(DATA_DIR / file.filename)
89
 
90
  return JSONResponse(payload)
91
 
92
 
93
- # ---------- Gradio UI (mounted at /ui) ----------
94
  def _gr_predict(img: Image.Image, k: int):
95
  if img is None:
96
- return "", "", "{}"
97
  tags = tg.tag_pil_image(img.convert("RGB"), "ui_upload", top_k=int(k))
98
  caption = _caption_with_tagger(img)
99
- payload = {"filename": "ui_upload", "caption": caption, "tags": tags}
100
- return caption, ", ".join(tags), json.dumps(payload, indent=2)
101
-
102
- with gr.Blocks(css="""
103
- .gradio-container {max-width: 860px !important}
104
- """) as demo:
105
- gr.Markdown("## Image Tagger — UI\nUpload an image, choose `top_k`, and get tags.")
106
 
 
 
107
  with gr.Row():
108
  in_img = gr.Image(type="pil", label="Image")
109
  k = gr.Slider(1, 20, value=5, step=1, label="Top-k tags")
 
 
 
 
110
 
111
- run = gr.Button("Tag Image")
112
-
113
- with gr.Row():
114
- out_caption = gr.Textbox(label="Caption", lines=2)
115
- with gr.Row():
116
- out_tags = gr.Textbox(label="Tags (comma-separated)", lines=2)
117
- out_json = gr.Textbox(label="Raw JSON", lines=10)
118
-
119
- run.click(_gr_predict, inputs=[in_img, k], outputs=[out_caption, out_tags, out_json])
120
-
121
- APP = gr.mount_gradio_app(APP, demo, path="/ui")
 
7
  from typing import List
8
 
9
  from fastapi import FastAPI, UploadFile, File, Query
 
10
  from fastapi.middleware.cors import CORSMiddleware
11
+ from fastapi.responses import HTMLResponse, JSONResponse
12
  from PIL import Image
13
  import gradio as gr
14
 
15
+ # Uses your existing tagger.py (BLIP loaded there; returns top-k tags and writes sidecar)
16
  import tagger as tg
17
 
18
+ # ---------- FastAPI base app ----------
19
+ app = FastAPI(title="Image Tagger API")
20
+
21
+ app.add_middleware(
22
  CORSMiddleware,
23
  allow_origins=["*"], allow_methods=["*"], allow_headers=["*"], allow_credentials=True
24
  )
 
27
  DATA_DIR.mkdir(parents=True, exist_ok=True)
28
 
29
 
 
30
  def _caption_with_tagger(img: Image.Image) -> str:
31
+ """Optional caption via BLIP objects already loaded in tagger.py."""
32
  try:
33
  proc = getattr(tg, "_processor")
34
  model = getattr(tg, "_model")
35
  ids = model.generate(**proc(images=img, return_tensors="pt"), max_length=30)
36
  return proc.decode(ids[0], skip_special_tokens=True)
37
  except Exception:
 
38
  return ""
39
 
40
 
41
+ # ---------- API endpoints ----------
42
+ @app.get("/healthz")
43
  def healthz():
44
  return {"ok": True}
45
 
46
+ @app.get("/", response_class=HTMLResponse)
47
  def root():
 
48
  return """<!doctype html>
49
  <html>
50
  <head><meta charset="utf-8" /><title>Image Tagger API</title></head>
51
  <body style="font-family: system-ui; max-width: 720px; margin: 40px auto">
52
  <h2>Image Tagger API</h2>
53
+ <p>Try the Swagger docs at <a href="/docs">/docs</a> or the UI at <a href="/ui">/ui</a>.</p>
54
  <form action="/upload" method="post" enctype="multipart/form-data" style="display:grid; gap:12px">
55
  <input type="file" name="file" accept="image/png,image/jpeg,image/webp" required />
56
  <label>top_k: <input type="number" name="top_k" value="5" min="1" max="20" /></label>
 
59
  </body>
60
  </html>"""
61
 
62
+ @app.post("/upload")
63
  async def upload_image(
64
  file: UploadFile = File(...),
65
  top_k: int = Query(5, ge=1, le=20),
66
  ):
 
67
  content = await file.read()
68
  img = Image.open(io.BytesIO(content)).convert("RGB")
69
 
70
  stem = Path(file.filename).stem
71
+ tags: List[str] = tg.tag_pil_image(img, stem, top_k=top_k) # tagger writes sidecar in /app/data (per your tagger)
 
 
 
72
  caption = _caption_with_tagger(img)
73
 
74
+ payload = {"filename": file.filename, "caption": caption, "tags": tags}
 
 
 
 
75
 
76
+ # Save a copy for convenience
77
  (DATA_DIR / f"{stem}.json").write_text(json.dumps(payload, indent=2))
78
  img.save(DATA_DIR / file.filename)
79
 
80
  return JSONResponse(payload)
81
 
82
 
83
+ # ---------- Gradio UI at /ui ----------
84
  def _gr_predict(img: Image.Image, k: int):
85
  if img is None:
86
+ return "", ""
87
  tags = tg.tag_pil_image(img.convert("RGB"), "ui_upload", top_k=int(k))
88
  caption = _caption_with_tagger(img)
89
+ return caption, ", ".join(tags)
 
 
 
 
 
 
90
 
91
+ with gr.Blocks(css=".gradio-container {max-width: 860px !important}") as demo:
92
+ gr.Markdown("## Image Tagger — UI")
93
  with gr.Row():
94
  in_img = gr.Image(type="pil", label="Image")
95
  k = gr.Slider(1, 20, value=5, step=1, label="Top-k tags")
96
+ run = gr.Button("Tag image")
97
+ out_caption = gr.Textbox(label="Caption", lines=2)
98
+ out_tags = gr.Textbox(label="Tags (comma-separated)", lines=2)
99
+ run.click(_gr_predict, inputs=[in_img, k], outputs=[out_caption, out_tags])
100
 
101
+ # IMPORTANT: export lowercase `app` (Uvicorn expects app:app)
102
+ app = gr.mount_gradio_app(app, demo, path="/ui")