yenslife commited on
Commit
0066f5e
·
1 Parent(s): 192d641

feat: improve demo model selection workflow

Browse files

新增 demo 頁面的模型下拉選單,支援在同一張圖片上切換不同模型重新推論。

調整圖片預覽縮放與顯示尺寸,並放寬 multipart 上傳大小限制,避免大圖被 1MB 預設限制擋下。

Files changed (2) hide show
  1. app.py +106 -17
  2. templates/demo.html +66 -7
app.py CHANGED
@@ -2,15 +2,17 @@ import base64
2
  from contextlib import asynccontextmanager
3
  from io import BytesIO
4
  from pathlib import Path
 
5
 
6
- from fastapi import FastAPI, File, HTTPException, Request, UploadFile
7
  from fastapi.responses import HTMLResponse
8
  from fastapi.templating import Jinja2Templates
9
  from PIL import Image, UnidentifiedImageError
10
 
11
- from model_service import get_model_config, get_model_service
12
 
13
  BASE_DIR = Path(__file__).resolve().parent
 
14
  templates = Jinja2Templates(directory=str(BASE_DIR / "templates"))
15
  ACTIVE_MODEL_CONFIG = get_model_config()
16
 
@@ -39,30 +41,112 @@ def _build_demo_context(**overrides):
39
  "confidence": "-",
40
  "acc": "-",
41
  "error": None,
 
 
 
 
 
 
 
 
 
42
  }
43
  context.update(overrides)
44
  return context
45
 
46
 
47
- async def _predict_upload(file: UploadFile) -> tuple[dict, bytes]:
48
- if not file.content_type or not file.content_type.startswith("image/"):
49
- raise HTTPException(status_code=400, detail="Uploaded file must be an image.")
50
 
51
- data = await file.read()
52
- if not data:
53
- raise HTTPException(status_code=400, detail="Uploaded file is empty.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
 
55
  try:
56
  image = Image.open(BytesIO(data)).convert("RGB")
57
- except UnidentifiedImageError as exc:
58
  raise HTTPException(status_code=400, detail="Invalid image file.") from exc
59
 
60
- result = get_model_service().predict_image(image)
61
- result["filename"] = file.filename
62
- result["content_type"] = file.content_type
 
 
 
63
  return result, data
64
 
65
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  @app.get("/")
67
  def root():
68
  return {
@@ -89,14 +173,17 @@ def demo_page(request: Request):
89
 
90
 
91
  @app.post("/demo", response_class=HTMLResponse)
92
- async def demo_predict(request: Request, file: UploadFile = File(...)):
 
 
 
93
  try:
94
- result, data = await _predict_upload(file)
95
  except HTTPException as exc:
96
  return templates.TemplateResponse(
97
  request,
98
  "demo.html",
99
- _build_demo_context(error=exc.detail),
100
  status_code=exc.status_code,
101
  )
102
 
@@ -116,11 +203,13 @@ async def demo_predict(request: Request, file: UploadFile = File(...)):
116
  class_label=pred_label,
117
  confidence=f"{pred_conf * 100:.2f}%",
118
  acc=f"{pred_conf * 100:.2f}%",
 
119
  ),
120
  )
121
 
122
 
123
  @app.post("/predict")
124
- async def predict(file: UploadFile = File(...)):
125
- result, _ = await _predict_upload(file)
 
126
  return result
 
2
  from contextlib import asynccontextmanager
3
  from io import BytesIO
4
  from pathlib import Path
5
+ from urllib.parse import unquote_to_bytes
6
 
7
+ from fastapi import FastAPI, HTTPException, Request, UploadFile
8
  from fastapi.responses import HTMLResponse
9
  from fastapi.templating import Jinja2Templates
10
  from PIL import Image, UnidentifiedImageError
11
 
12
+ from model_service import MODEL_CONFIGS, get_model_config, get_model_service
13
 
14
  BASE_DIR = Path(__file__).resolve().parent
15
+ MAX_UPLOAD_SIZE = 10 * 1024 * 1024
16
  templates = Jinja2Templates(directory=str(BASE_DIR / "templates"))
17
  ACTIVE_MODEL_CONFIG = get_model_config()
18
 
 
41
  "confidence": "-",
42
  "acc": "-",
43
  "error": None,
44
+ "selected_model": ACTIVE_MODEL_CONFIG.name,
45
+ "model_options": [
46
+ {
47
+ "name": config.name,
48
+ "backend": config.backend,
49
+ "path": config.model_path.name,
50
+ }
51
+ for config in MODEL_CONFIGS.values()
52
+ ],
53
  }
54
  context.update(overrides)
55
  return context
56
 
57
 
58
+ def _parse_data_url(data_url: str) -> tuple[bytes, str]:
59
+ if not data_url.startswith("data:") or "," not in data_url:
60
+ raise HTTPException(status_code=400, detail="Invalid preview image data.")
61
 
62
+ header, encoded = data_url.split(",", 1)
63
+ content_type = header[5:].split(";")[0] or "image/png"
64
+ if not content_type.startswith("image/"):
65
+ raise HTTPException(status_code=400, detail="Preview data must be an image.")
66
+
67
+ if ";base64" in header:
68
+ try:
69
+ return base64.b64decode(encoded), content_type
70
+ except ValueError as exc:
71
+ raise HTTPException(status_code=400, detail="Invalid preview image data.") from exc
72
+
73
+ return unquote_to_bytes(encoded), content_type
74
+
75
+
76
+ async def _read_image_data(
77
+ file: UploadFile | None,
78
+ existing_image_data_url: str | None = None,
79
+ ) -> tuple[bytes, str, str | None]:
80
+ if file and file.filename:
81
+ if not file.content_type or not file.content_type.startswith("image/"):
82
+ raise HTTPException(status_code=400, detail="Uploaded file must be an image.")
83
+
84
+ data = await file.read()
85
+ if not data:
86
+ raise HTTPException(status_code=400, detail="Uploaded file is empty.")
87
+ return data, file.content_type, file.filename
88
+
89
+ if existing_image_data_url:
90
+ data, content_type = _parse_data_url(existing_image_data_url)
91
+ if not data:
92
+ raise HTTPException(status_code=400, detail="Preview image is empty.")
93
+ return data, content_type, None
94
+
95
+ raise HTTPException(status_code=400, detail="Please upload an image first.")
96
+
97
+
98
+ async def _predict_upload(
99
+ file: UploadFile | None,
100
+ model_name: str | None = None,
101
+ existing_image_data_url: str | None = None,
102
+ ) -> tuple[dict, bytes]:
103
+ data, content_type, filename = await _read_image_data(file, existing_image_data_url)
104
 
105
  try:
106
  image = Image.open(BytesIO(data)).convert("RGB")
107
+ except (UnidentifiedImageError, OSError) as exc:
108
  raise HTTPException(status_code=400, detail="Invalid image file.") from exc
109
 
110
+ try:
111
+ result = get_model_service(model_name).predict_image(image)
112
+ except ValueError as exc:
113
+ raise HTTPException(status_code=400, detail=str(exc)) from exc
114
+ result["filename"] = filename
115
+ result["content_type"] = content_type
116
  return result, data
117
 
118
 
119
+ async def _parse_demo_form(request: Request) -> tuple[UploadFile | None, str, str | None]:
120
+ try:
121
+ form = await request.form(max_part_size=MAX_UPLOAD_SIZE)
122
+ except HTTPException:
123
+ raise
124
+ except Exception as exc:
125
+ raise HTTPException(status_code=400, detail=str(exc)) from exc
126
+
127
+ file_value = form.get("file")
128
+ file = file_value if isinstance(file_value, UploadFile) else None
129
+ model_name = str(form.get("model_name") or ACTIVE_MODEL_CONFIG.name)
130
+ existing_image_data_url = form.get("existing_image_data_url")
131
+ if existing_image_data_url is not None:
132
+ existing_image_data_url = str(existing_image_data_url)
133
+ return file, model_name, existing_image_data_url
134
+
135
+
136
+ async def _parse_predict_form(request: Request) -> tuple[UploadFile | None, str]:
137
+ try:
138
+ form = await request.form(max_part_size=MAX_UPLOAD_SIZE)
139
+ except HTTPException:
140
+ raise
141
+ except Exception as exc:
142
+ raise HTTPException(status_code=400, detail=str(exc)) from exc
143
+
144
+ file_value = form.get("file")
145
+ file = file_value if isinstance(file_value, UploadFile) else None
146
+ model_name = str(form.get("model_name") or ACTIVE_MODEL_CONFIG.name)
147
+ return file, model_name
148
+
149
+
150
  @app.get("/")
151
  def root():
152
  return {
 
173
 
174
 
175
  @app.post("/demo", response_class=HTMLResponse)
176
+ async def demo_predict(
177
+ request: Request,
178
+ ):
179
+ file, model_name, existing_image_data_url = await _parse_demo_form(request)
180
  try:
181
+ result, data = await _predict_upload(file, model_name, existing_image_data_url)
182
  except HTTPException as exc:
183
  return templates.TemplateResponse(
184
  request,
185
  "demo.html",
186
+ _build_demo_context(error=exc.detail, selected_model=model_name),
187
  status_code=exc.status_code,
188
  )
189
 
 
203
  class_label=pred_label,
204
  confidence=f"{pred_conf * 100:.2f}%",
205
  acc=f"{pred_conf * 100:.2f}%",
206
+ selected_model=result["model_name"],
207
  ),
208
  )
209
 
210
 
211
  @app.post("/predict")
212
+ async def predict(request: Request):
213
+ file, model_name = await _parse_predict_form(request)
214
+ result, _ = await _predict_upload(file, model_name)
215
  return result
templates/demo.html CHANGED
@@ -51,24 +51,50 @@
51
  gap: 10px;
52
  }
53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  .upload-box {
55
  width: 100%;
56
- aspect-ratio: 1.25 / 1;
57
  border: 2px solid var(--line);
58
  background: var(--box-bg);
59
  color: #111;
60
- display: grid;
61
- place-items: center;
 
 
62
  text-align: center;
63
  cursor: pointer;
64
- overflow: hidden;
65
  position: relative;
66
  font-size: 22px;
67
  }
68
 
69
  .upload-box img {
70
- width: 100%;
71
- height: 100%;
 
 
 
 
72
  object-fit: contain;
73
  background: #fff;
74
  }
@@ -142,9 +168,17 @@
142
  }
143
 
144
  .upload-box {
 
145
  font-size: 18px;
146
  }
147
 
 
 
 
 
 
 
 
148
  .stats {
149
  font-size: 22px;
150
  }
@@ -153,8 +187,20 @@
153
  </head>
154
  <body>
155
  <form class="panel" action="/demo" method="post" enctype="multipart/form-data">
 
156
  <div class="top-row">
157
  <div class="upload-wrap">
 
 
 
 
 
 
 
 
 
 
 
158
  <label class="upload-box" for="fileInput" id="uploadBox">
159
  {% if image_data_url %}
160
  <img src="{{ image_data_url }}" alt="Input image preview" id="previewImage">
@@ -163,11 +209,12 @@
163
  <img src="" alt="Input image preview" id="previewImage" style="display:none;">
164
  {% endif %}
165
  </label>
166
- <input class="file-input" id="fileInput" type="file" name="file" accept="image/*" required>
167
  <div class="upload-hint">點擊圖片區塊選擇檔案</div>
168
  </div>
169
 
170
  <div class="stats">
 
171
  <div>Class: {{ class_label }}</div>
172
  <div>Confidence: {{ confidence }}</div>
173
  <div>Acc: {{ acc }}</div>
@@ -184,9 +231,11 @@
184
  </form>
185
 
186
  <script>
 
187
  const fileInput = document.getElementById("fileInput");
188
  const previewImage = document.getElementById("previewImage");
189
  const placeholderText = document.getElementById("placeholderText");
 
190
 
191
  fileInput.addEventListener("change", (event) => {
192
  const file = event.target.files?.[0];
@@ -197,9 +246,19 @@
197
  if (placeholderText) placeholderText.style.display = "none";
198
  previewImage.src = reader.result;
199
  previewImage.style.display = "block";
 
200
  };
201
  reader.readAsDataURL(file);
202
  });
 
 
 
 
 
 
 
 
 
203
  </script>
204
  </body>
205
  </html>
 
51
  gap: 10px;
52
  }
53
 
54
+ .field-group {
55
+ display: flex;
56
+ flex-direction: column;
57
+ gap: 8px;
58
+ }
59
+
60
+ .field-label {
61
+ font-size: 14px;
62
+ color: #555;
63
+ letter-spacing: 0.2px;
64
+ }
65
+
66
+ .model-select {
67
+ width: 100%;
68
+ padding: 11px 12px;
69
+ border: 2px solid var(--line);
70
+ background: #edf2f7;
71
+ color: #333;
72
+ font-size: 15px;
73
+ }
74
+
75
  .upload-box {
76
  width: 100%;
77
+ min-height: 280px;
78
  border: 2px solid var(--line);
79
  background: var(--box-bg);
80
  color: #111;
81
+ display: flex;
82
+ align-items: center;
83
+ justify-content: center;
84
+ padding: 14px;
85
  text-align: center;
86
  cursor: pointer;
 
87
  position: relative;
88
  font-size: 22px;
89
  }
90
 
91
  .upload-box img {
92
+ width: min(100%, 280px);
93
+ height: 280px;
94
+ min-width: 220px;
95
+ min-height: 220px;
96
+ max-width: 100%;
97
+ max-height: 320px;
98
  object-fit: contain;
99
  background: #fff;
100
  }
 
168
  }
169
 
170
  .upload-box {
171
+ min-height: 220px;
172
  font-size: 18px;
173
  }
174
 
175
+ .upload-box img {
176
+ width: min(100%, 220px);
177
+ height: 220px;
178
+ min-width: 160px;
179
+ min-height: 160px;
180
+ }
181
+
182
  .stats {
183
  font-size: 22px;
184
  }
 
187
  </head>
188
  <body>
189
  <form class="panel" action="/demo" method="post" enctype="multipart/form-data">
190
+ <input type="hidden" name="existing_image_data_url" id="existingImageDataUrl" value="{{ image_data_url or '' }}">
191
  <div class="top-row">
192
  <div class="upload-wrap">
193
+ <div class="field-group">
194
+ <label class="field-label" for="modelSelect">推論模型</label>
195
+ <select class="model-select" id="modelSelect" name="model_name">
196
+ {% for model in model_options %}
197
+ <option value="{{ model.name }}" {% if model.name == selected_model %}selected{% endif %}>
198
+ {{ model.name }} ({{ model.backend }})
199
+ </option>
200
+ {% endfor %}
201
+ </select>
202
+ </div>
203
+
204
  <label class="upload-box" for="fileInput" id="uploadBox">
205
  {% if image_data_url %}
206
  <img src="{{ image_data_url }}" alt="Input image preview" id="previewImage">
 
209
  <img src="" alt="Input image preview" id="previewImage" style="display:none;">
210
  {% endif %}
211
  </label>
212
+ <input class="file-input" id="fileInput" type="file" name="file" accept="image/*">
213
  <div class="upload-hint">點擊圖片區塊選擇檔案</div>
214
  </div>
215
 
216
  <div class="stats">
217
+ <div>Model: {{ selected_model }}</div>
218
  <div>Class: {{ class_label }}</div>
219
  <div>Confidence: {{ confidence }}</div>
220
  <div>Acc: {{ acc }}</div>
 
231
  </form>
232
 
233
  <script>
234
+ const form = document.querySelector(".panel");
235
  const fileInput = document.getElementById("fileInput");
236
  const previewImage = document.getElementById("previewImage");
237
  const placeholderText = document.getElementById("placeholderText");
238
+ const existingImageDataUrl = document.getElementById("existingImageDataUrl");
239
 
240
  fileInput.addEventListener("change", (event) => {
241
  const file = event.target.files?.[0];
 
246
  if (placeholderText) placeholderText.style.display = "none";
247
  previewImage.src = reader.result;
248
  previewImage.style.display = "block";
249
+ existingImageDataUrl.value = reader.result;
250
  };
251
  reader.readAsDataURL(file);
252
  });
253
+
254
+ form.addEventListener("submit", (event) => {
255
+ const hasNewFile = Boolean(fileInput.files?.length);
256
+ const hasExistingImage = Boolean(existingImageDataUrl.value);
257
+ if (!hasNewFile && !hasExistingImage) {
258
+ event.preventDefault();
259
+ alert("請先上傳圖片。");
260
+ }
261
+ });
262
  </script>
263
  </body>
264
  </html>