stephenebert commited on
Commit
7d79380
·
verified ·
1 Parent(s): ca79a30

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +61 -21
app.py CHANGED
@@ -1,51 +1,58 @@
1
  from __future__ import annotations
 
 
 
 
 
 
 
 
2
  from fastapi import FastAPI, File, HTTPException, Query, UploadFile
3
- from fastapi.responses import HTMLResponse, JSONResponse
4
  from pydantic import BaseModel
5
  from typing import List
6
  from pathlib import Path
7
  from PIL import Image
8
  import io
9
 
10
- from tagger import tag_pil_image # returns (caption, tags)
 
11
 
12
- app = FastAPI(title="Image Tagger API", version="0.3.0")
13
 
 
 
14
  class TagOut(BaseModel):
15
  filename: str
16
  caption: str
17
  tags: List[str]
18
 
19
- @app.get("/", response_class=HTMLResponse)
20
- def home() -> str:
21
- return """
22
- <html><head><meta charset="utf-8"><title>Image Tagger API</title></head>
23
- <body style="font-family:system-ui,-apple-system,Segoe UI,Roboto,Ubuntu,sans-serif;max-width:720px;margin:40px auto;padding:0 16px">
24
- <h2>Image Tagger API</h2>
25
- <p>Use <a href="/docs">/docs</a> for Swagger, or upload here:</p>
26
- <form action="/upload" method="post" enctype="multipart/form-data" style="display:grid;gap:12px">
27
- <input type="file" name="file" accept="image/png,image/jpeg,image/webp" required />
28
- <label>top_k:
29
- <input type="number" name="top_k" value="5" min="1" max="20">
30
- </label>
31
- <button type="submit" style="padding:.6rem 1rem;border-radius:10px;border:1px solid #ddd;background:#111;color:#fff">Upload</button>
32
- </form>
33
- </body></html>
34
- """
35
 
36
  @app.get("/healthz")
37
  def healthz():
38
  return {"ok": True}
39
 
 
 
40
  @app.post("/upload", response_model=TagOut)
41
  async def upload(
42
  file: UploadFile = File(...),
43
  top_k: int = Query(5, ge=1, le=20, description="Max number of tags"),
44
  ):
 
45
  if file.content_type not in {"image/png", "image/jpeg", "image/webp"}:
46
- raise HTTPException(status_code=415, detail="Only PNG, JPEG, or WebP images are supported")
 
 
47
 
48
- # Load image
49
  try:
50
  data = await file.read()
51
  img = Image.open(io.BytesIO(data)).convert("RGB")
@@ -54,9 +61,42 @@ async def upload(
54
 
55
  stem = Path(file.filename).stem or "upload"
56
 
 
57
  try:
58
  caption, tags = tag_pil_image(img, stem, top_k=top_k)
59
  except Exception as e:
60
  raise HTTPException(status_code=500, detail=f"Tagging failed: {e}")
61
 
62
  return JSONResponse({"filename": file.filename, "caption": caption, "tags": tags})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from __future__ import annotations
2
+
3
+ # FastAPI REST API + a Gradio UI mounted at /ui
4
+ # Endpoints:
5
+ # GET /healthz
6
+ # POST /upload (multipart image -> {filename, caption, tags})
7
+ # UI:
8
+ # /ui (upload image, choose top_k, see caption + tags)
9
+
10
  from fastapi import FastAPI, File, HTTPException, Query, UploadFile
11
+ from fastapi.responses import JSONResponse, RedirectResponse
12
  from pydantic import BaseModel
13
  from typing import List
14
  from pathlib import Path
15
  from PIL import Image
16
  import io
17
 
18
+ import gradio as gr
19
+ from tagger import tag_pil_image # returns (caption: str, tags: List[str])
20
 
21
+ app = FastAPI(title="Image Tagger API", version="0.4.0")
22
 
23
+
24
+ # ---------- Pydantic model for OpenAPI ----------
25
  class TagOut(BaseModel):
26
  filename: str
27
  caption: str
28
  tags: List[str]
29
 
30
+
31
+ # ---------- Basic routes ----------
32
+ @app.get("/", include_in_schema=False)
33
+ def root_redirect():
34
+ # Send users to the nicer UI
35
+ return RedirectResponse(url="/ui", status_code=302)
36
+
 
 
 
 
 
 
 
 
 
37
 
38
  @app.get("/healthz")
39
  def healthz():
40
  return {"ok": True}
41
 
42
+
43
+ # ---------- REST endpoint ----------
44
  @app.post("/upload", response_model=TagOut)
45
  async def upload(
46
  file: UploadFile = File(...),
47
  top_k: int = Query(5, ge=1, le=20, description="Max number of tags"),
48
  ):
49
+ # MIME guard
50
  if file.content_type not in {"image/png", "image/jpeg", "image/webp"}:
51
+ raise HTTPException(
52
+ status_code=415, detail="Only PNG, JPEG, or WebP images are supported"
53
+ )
54
 
55
+ # Decode image
56
  try:
57
  data = await file.read()
58
  img = Image.open(io.BytesIO(data)).convert("RGB")
 
61
 
62
  stem = Path(file.filename).stem or "upload"
63
 
64
+ # Tag
65
  try:
66
  caption, tags = tag_pil_image(img, stem, top_k=top_k)
67
  except Exception as e:
68
  raise HTTPException(status_code=500, detail=f"Tagging failed: {e}")
69
 
70
  return JSONResponse({"filename": file.filename, "caption": caption, "tags": tags})
71
+
72
+
73
+ # ---------- Gradio UI (mounted at /ui) ----------
74
+ def _ui_tag(image: Image.Image, top_k: int):
75
+ if image is None:
76
+ return "", []
77
+ caption, tags = tag_pil_image(image.convert("RGB"), "upload", top_k=top_k)
78
+ return caption, tags
79
+
80
+
81
+ with gr.Blocks(title="Image Tagger", analytics_enabled=False) as gr_app:
82
+ gr.Markdown(
83
+ "## 🏷️ Image Tagger\n"
84
+ "Upload an image to get a caption and top-k tags. "
85
+ "Programmatic API is available at **/docs**."
86
+ )
87
+ with gr.Row():
88
+ with gr.Column(scale=1):
89
+ inp = gr.Image(
90
+ type="pil", label="Upload image", sources=["upload"], height=360
91
+ )
92
+ k = gr.Slider(1, 20, value=5, step=1, label="Top-k tags")
93
+ btn = gr.Button("Tag it")
94
+ with gr.Column(scale=1):
95
+ cap = gr.Textbox(label="Caption", lines=2)
96
+ tags = gr.Tags(label="Tags")
97
+
98
+ btn.click(_ui_tag, inputs=[inp, k], outputs=[cap, tags])
99
+
100
+ # Mount Gradio on the same FastAPI app
101
+ app = gr.mount_gradio_app(app, gr_app, path="/ui")
102
+