Eli181927 committed on
Commit
6d8cc91
·
verified ·
1 Parent(s): 4ba948e

Upload app.py

Browse files
Files changed (1) hide show
  1. 2.CNN/app.py +247 -52
2.CNN/app.py CHANGED
@@ -187,42 +187,244 @@ def _auto_balance_stroke(arr: np.ndarray, *, target_mass_fraction: float, clamp:
187
  return adjusted, scale, new_mass_fraction
188
 
189
 
190
- def compose_dual_canvas(left_input, right_input):
191
- left_img = extract_canvas_array(left_input)
192
- right_img = extract_canvas_array(right_input)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
193
 
194
- if left_img is None and right_img is None:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
195
  return None
 
 
 
196
 
197
- if left_img is None:
198
- if right_img is None:
199
- return None
200
- base_size = right_img.size
201
- left_img = Image.new("L", base_size, color=255)
202
- if right_img is None:
203
- base_size = left_img.size
204
- right_img = Image.new("L", base_size, color=255)
205
-
206
- left_img = left_img.convert("L")
207
- right_img = right_img.convert("L")
208
-
209
- if left_img.height != right_img.height:
210
- target_height = min(left_img.height, right_img.height)
211
- left_img = left_img.resize(
212
- (left_img.width, target_height), Image.Resampling.LANCZOS
213
- )
214
- right_img = right_img.resize(
215
- (right_img.width, target_height), Image.Resampling.LANCZOS
216
- )
217
 
218
- combined = Image.new(
219
- "L",
220
- (left_img.width + right_img.width, left_img.height),
221
- color=255,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
222
  )
223
- combined.paste(left_img, (0, 0))
224
- combined.paste(right_img, (left_img.width, 0))
225
- return combined
 
 
 
 
 
 
 
226
 
227
 
228
  def preprocess_image(img_input, stroke_scale: float = 1.0):
@@ -523,25 +725,26 @@ def enrich_diagnostics(stats, probs):
523
  return stats
524
 
525
 
526
- def predict_number(left_canvas, right_canvas, stroke_scale):
527
  ensure_model_loaded()
528
- combined_canvas = compose_dual_canvas(left_canvas, right_canvas)
529
- if combined_canvas is None:
530
  blank_probs = {f"{i:02d}": 0.0 for i in range(OUTPUT_CLASSES)}
531
  empty_preview = np.zeros((TARGET_HEIGHT, TARGET_WIDTH), dtype=np.uint8)
532
  empty_diff = np.zeros((TARGET_HEIGHT, TARGET_WIDTH), dtype=np.uint8)
533
- diagnostics = {"warnings": ["Draw both digits to see diagnostics."]}
534
  return None, blank_probs, empty_preview, empty_diff, json.dumps(diagnostics, indent=2)
535
 
536
- result = preprocess_image(
537
- combined_canvas,
538
  stroke_scale=stroke_scale,
 
539
  )
540
  if result is None:
541
  blank_probs = {f"{i:02d}": 0.0 for i in range(OUTPUT_CLASSES)}
542
  empty_preview = np.zeros((TARGET_HEIGHT, TARGET_WIDTH), dtype=np.uint8)
543
  empty_diff = np.zeros((TARGET_HEIGHT, TARGET_WIDTH), dtype=np.uint8)
544
- diagnostics = {"warnings": ["Draw a number to see diagnostics."]}
545
  return None, blank_probs, empty_preview, empty_diff, json.dumps(diagnostics, indent=2)
546
 
547
  standardized_variants, preview, mean_diff, diagnostics = result
@@ -567,15 +770,13 @@ with gr.Blocks() as demo:
567
  gr.Markdown(
568
  """
569
  # Elliot's MNIST-100 Classifier
570
- Draw a two-digit number (00-99). Use the left canvas for the tens digit and the right canvas for the ones digit. The model will predict the number, show the top class probabilities, and display diagnostics for the processed input.
571
  """
572
  )
573
 
574
  with gr.Row():
575
  with gr.Column(scale=1):
576
- with gr.Row():
577
- left_canvas = gr.Sketchpad(label="Tens Digit")
578
- right_canvas = gr.Sketchpad(label="Ones Digit")
579
  stroke_slider = gr.Slider(
580
  minimum=0.3,
581
  maximum=1.2,
@@ -598,8 +799,7 @@ with gr.Blocks() as demo:
598
  predict_btn = gr.Button("Predict", variant="primary")
599
  clear_btn = gr.ClearButton(
600
  [
601
- left_canvas,
602
- right_canvas,
603
  stroke_slider,
604
  pred_box,
605
  prob_table,
@@ -611,19 +811,14 @@ with gr.Blocks() as demo:
611
 
612
  predict_btn.click(
613
  fn=predict_number,
614
- inputs=[left_canvas, right_canvas, stroke_slider],
615
  outputs=[pred_box, prob_table, preview, mean_diff_view, diagnostics_box],
616
  )
617
  # On Spaces, avoid per-stroke inference to prevent event floods
618
  if not IS_SPACE:
619
- left_canvas.change(
620
- fn=predict_number,
621
- inputs=[left_canvas, right_canvas, stroke_slider],
622
- outputs=[pred_box, prob_table, preview, mean_diff_view, diagnostics_box],
623
- )
624
- right_canvas.change(
625
  fn=predict_number,
626
- inputs=[left_canvas, right_canvas, stroke_slider],
627
  outputs=[pred_box, prob_table, preview, mean_diff_view, diagnostics_box],
628
  )
629
 
 
187
  return adjusted, scale, new_mass_fraction
188
 
189
 
190
+ def _valley_split(mask: np.ndarray) -> int | None:
191
+ # Find a vertical seam (column) with minimal foreground to split two digits
192
+ H, W = mask.shape
193
+ if W < 8:
194
+ return None
195
+ col_sums = mask.sum(axis=0)
196
+ start = max(1, int(W * 0.25))
197
+ end = min(W - 1, int(W * 0.75))
198
+ if end <= start:
199
+ start, end = 1, W - 1
200
+ idx = int(np.argmin(col_sums[start:end])) + start
201
+ left_mass = int(col_sums[:idx].sum())
202
+ right_mass = int(col_sums[idx:].sum())
203
+ if left_mass > 50 and right_mass > 50:
204
+ return idx
205
+ return None
206
+
207
+
208
+ def _connected_components(mask: np.ndarray):
209
+ H, W = mask.shape
210
+ visited = np.zeros_like(mask, dtype=bool)
211
+ comps = []
212
+ for y in range(H):
213
+ row = mask[y]
214
+ for x in range(W):
215
+ if row[x] and not visited[y, x]:
216
+ stack = [(y, x)]
217
+ visited[y, x] = True
218
+ ys, xs = [], []
219
+ while stack:
220
+ cy, cx = stack.pop()
221
+ ys.append(cy)
222
+ xs.append(cx)
223
+ # 4-connectivity
224
+ if cy > 0 and mask[cy - 1, cx] and not visited[cy - 1, cx]:
225
+ visited[cy - 1, cx] = True
226
+ stack.append((cy - 1, cx))
227
+ if cy + 1 < H and mask[cy + 1, cx] and not visited[cy + 1, cx]:
228
+ visited[cy + 1, cx] = True
229
+ stack.append((cy + 1, cx))
230
+ if cx > 0 and mask[cy, cx - 1] and not visited[cy, cx - 1]:
231
+ visited[cy, cx - 1] = True
232
+ stack.append((cy, cx - 1))
233
+ if cx + 1 < W and mask[cy, cx + 1] and not visited[cy, cx + 1]:
234
+ visited[cy, cx + 1] = True
235
+ stack.append((cy, cx + 1))
236
+ y1, y2 = min(ys), max(ys) + 1
237
+ x1, x2 = min(xs), max(xs) + 1
238
+ comps.append({"bbox": (y1, y2, x1, x2), "size": len(ys)})
239
+ return comps
240
+
241
+
242
+ def canonicalize_digit_28x28(arr: np.ndarray) -> np.ndarray:
243
+ # Input arr: float32 in [0,1], arbitrary HxW; output: 28x28 centered tile
244
+ if arr.size == 0:
245
+ return np.zeros((TARGET_HEIGHT, TARGET_HEIGHT), dtype=np.float32)
246
+ thr = arr > 0.05
247
+ if not thr.any():
248
+ return np.zeros((TARGET_HEIGHT, TARGET_HEIGHT), dtype=np.float32)
249
+ ys, xs = np.where(thr)
250
+ y1, y2 = ys.min(), ys.max() + 1
251
+ x1, x2 = xs.min(), xs.max() + 1
252
+ # small padding
253
+ pad = 2
254
+ y1 = max(0, y1 - pad)
255
+ x1 = max(0, x1 - pad)
256
+ y2 = min(arr.shape[0], y2 + pad)
257
+ x2 = min(arr.shape[1], x2 + pad)
258
+ crop = arr[y1:y2, x1:x2]
259
+ h, w = crop.shape
260
+ if h == 0 or w == 0:
261
+ return np.zeros((TARGET_HEIGHT, TARGET_HEIGHT), dtype=np.float32)
262
+ # resize shorter side to 20
263
+ if h >= w:
264
+ new_h = 20
265
+ new_w = max(1, int(round(w * (20.0 / h))))
266
+ else:
267
+ new_w = 20
268
+ new_h = max(1, int(round(h * (20.0 / w))))
269
+ small = Image.fromarray((crop * 255.0).astype(np.uint8)).resize(
270
+ (new_w, new_h), Image.Resampling.LANCZOS
271
+ )
272
+ tile = Image.new("L", (TARGET_HEIGHT, TARGET_HEIGHT), color=0)
273
+ # paste centered
274
+ top = (TARGET_HEIGHT - new_h) // 2
275
+ left = (TARGET_HEIGHT - new_w) // 2
276
+ tile.paste(small, (left, top))
277
+ tile_arr = np.array(tile, dtype=np.float32) / 255.0
278
+ # center-of-mass shift to exact center
279
+ mass = tile_arr
280
+ tot = float(mass.sum())
281
+ if tot > 1e-6:
282
+ gy, gx = np.indices(mass.shape)
283
+ cy = float((gy * mass).sum() / tot)
284
+ cx = float((gx * mass).sum() / tot)
285
+ ideal = (TARGET_HEIGHT - 1) / 2.0
286
+ dy = int(np.clip(round(ideal - cy), -2, 2))
287
+ dx = int(np.clip(round(ideal - cx), -2, 2))
288
+ if dy != 0 or dx != 0:
289
+ tile_arr = shift_with_zero_pad(tile_arr, dy, dx)
290
+ return tile_arr.astype(np.float32, copy=False)
291
+
292
 
293
+ def compose_from_single_canvas(img_input):
294
+ img = extract_canvas_array(img_input)
295
+ if img is None:
296
+ return None, {"warnings": ["No image provided."]}
297
+ try:
298
+ bands = img.getbands()
299
+ except Exception:
300
+ bands = ()
301
+ if "A" in bands:
302
+ rgba = img.convert("RGBA")
303
+ white_bg = Image.new("RGBA", rgba.size, (255, 255, 255, 255))
304
+ img = Image.alpha_composite(white_bg, rgba).convert("RGB")
305
+ gray = img.convert("L")
306
+ inv = ImageOps.invert(gray)
307
+ arr_u8 = np.array(inv, dtype=np.uint8)
308
+ mask = arr_u8 > 10
309
+ if not mask.any():
310
+ return None, {"warnings": ["Empty drawing detected."]}
311
+
312
+ # Global bbox trim for speed
313
+ ys, xs = np.where(mask)
314
+ y1, y2 = ys.min(), ys.max() + 1
315
+ x1, x2 = xs.min(), xs.max() + 1
316
+ pad = 4
317
+ y1 = max(0, y1 - pad)
318
+ x1 = max(0, x1 - pad)
319
+ y2 = min(arr_u8.shape[0], y2 + pad)
320
+ x2 = min(arr_u8.shape[1], x2 + pad)
321
+ arr_u8 = arr_u8[y1:y2, x1:x2]
322
+ mask = mask[y1:y2, x1:x2]
323
+
324
+ method = "valley"
325
+ split = _valley_split(mask)
326
+ left_arr = right_arr = None
327
+ if split is not None:
328
+ left_area = arr_u8[:, :split]
329
+ right_area = arr_u8[:, split:]
330
+ if (left_area > 10).any():
331
+ l_ys, l_xs = np.where(left_area > 10)
332
+ ly1, ly2 = l_ys.min(), l_ys.max() + 1
333
+ lx1, lx2 = l_xs.min(), l_xs.max() + 1
334
+ left_arr = left_area[ly1:ly2, lx1:lx2]
335
+ if (right_area > 10).any():
336
+ r_ys, r_xs = np.where(right_area > 10)
337
+ ry1, ry2 = r_ys.min(), r_ys.max() + 1
338
+ rx1, rx2 = r_xs.min(), r_xs.max() + 1
339
+ right_arr = right_area[ry1:ry2, rx1:rx2]
340
+ else:
341
+ method = "components"
342
+ comps = _connected_components(mask)
343
+ if len(comps) >= 2:
344
+ comps.sort(key=lambda c: c["size"], reverse=True)
345
+ a, b = comps[0], comps[1]
346
+ # sort left/right by x1
347
+ if a["bbox"][2] <= b["bbox"][2]:
348
+ left_bbox, right_bbox = a["bbox"], b["bbox"]
349
+ else:
350
+ left_bbox, right_bbox = b["bbox"], a["bbox"]
351
+ ly1, ly2, lx1, lx2 = left_bbox
352
+ ry1, ry2, rx1, rx2 = right_bbox
353
+ left_arr = arr_u8[ly1:ly2, lx1:lx2]
354
+ right_arr = arr_u8[ry1:ry2, rx1:rx2]
355
+ else:
356
+ # Fallback: split the single bbox in half
357
+ method = "fallback_center_split"
358
+ W = arr_u8.shape[1]
359
+ split = W // 2
360
+ left_arr = arr_u8[:, :split]
361
+ right_arr = arr_u8[:, split:]
362
+
363
+ # Convert to float and canonicalize per digit
364
+ left_tile = canonicalize_digit_28x28((left_arr.astype(np.float32) / 255.0) if left_arr is not None else np.zeros((1, 1), dtype=np.float32))
365
+ right_tile = canonicalize_digit_28x28((right_arr.astype(np.float32) / 255.0) if right_arr is not None else np.zeros((1, 1), dtype=np.float32))
366
+ composed = np.concatenate([left_tile, right_tile], axis=1)
367
+ diag = {
368
+ "segmentation": {
369
+ "method": method,
370
+ "canvas_crop": {"top": int(y1), "bottom": int(y2), "left": int(x1), "right": int(x2)},
371
+ }
372
+ }
373
+ return composed.astype(np.float32, copy=False), diag
374
+
375
+
376
+ def preprocess_composed_28x56(arr_28x56: np.ndarray, stroke_scale: float = 1.0, *, extra_diag: dict | None = None):
377
+ ensure_model_loaded()
378
+ if arr_28x56 is None:
379
  return None
380
+ arr_resized = np.clip(arr_28x56.astype(np.float32), 0.0, 1.0)
381
+ mean_image = mean.reshape(TARGET_HEIGHT, TARGET_WIDTH)
382
+ std_safe = np.maximum(std, STD_FLOOR)
383
 
384
+ stroke_scale = float(stroke_scale)
385
+ stroke_scale = max(0.3, min(stroke_scale, 1.5))
386
+ arr_resized = np.clip(arr_resized * stroke_scale, 0.0, 1.0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
387
 
388
+ auto_balance_scale = 1.0
389
+ pre_balance_mass_fraction = float(arr_resized.mean())
390
+ target_mass = float(mean.mean())
391
+ arr_resized, auto_balance_scale, balanced_mass_fraction = _auto_balance_stroke(
392
+ arr_resized,
393
+ target_mass_fraction=target_mass,
394
+ clamp=(0.6, 1.6),
395
+ )
396
+
397
+ # We already centered per 28x28 tile; skip whole-image recentering here
398
+ arr_centered = arr_resized
399
+
400
+ augmented_arrays = [arr_centered, *generate_inference_variants(arr_centered, fast=IS_SPACE)]
401
+ augmented_standardized = []
402
+ for arr in augmented_arrays:
403
+ z = (arr.reshape(TARGET_HEIGHT * TARGET_WIDTH, 1) - mean) / std_safe
404
+ z = np.clip(z, -8.0, 8.0)
405
+ augmented_standardized.append(z.astype(np.float32, copy=False))
406
+
407
+ mean_diff = np.abs(arr_centered - mean_image)
408
+ mean_diff_uint8 = (mean_diff / (mean_diff.max() + 1e-8) * 255.0).astype(np.uint8)
409
+
410
+ diagnostics = compute_diagnostics(
411
+ arr_centered,
412
+ None,
413
+ arr_centered.shape,
414
+ mean_image,
415
+ augmented_standardized[0],
416
+ std_safe,
417
  )
418
+ diagnostics["applied_auto_balance"] = {
419
+ "enabled": True,
420
+ "scale": float(auto_balance_scale),
421
+ "mass_fraction_after": float(balanced_mass_fraction),
422
+ "mass_fraction_before": float(pre_balance_mass_fraction),
423
+ "target_mass_fraction": float(target_mass),
424
+ }
425
+ if extra_diag:
426
+ diagnostics.update(extra_diag)
427
+ return augmented_standardized, arr_centered, mean_diff_uint8, diagnostics
428
 
429
 
430
  def preprocess_image(img_input, stroke_scale: float = 1.0):
 
725
  return stats
726
 
727
 
728
+ def predict_number(main_canvas, stroke_scale):
729
  ensure_model_loaded()
730
+ composed, seg_diag = compose_from_single_canvas(main_canvas)
731
+ if composed is None:
732
  blank_probs = {f"{i:02d}": 0.0 for i in range(OUTPUT_CLASSES)}
733
  empty_preview = np.zeros((TARGET_HEIGHT, TARGET_WIDTH), dtype=np.uint8)
734
  empty_diff = np.zeros((TARGET_HEIGHT, TARGET_WIDTH), dtype=np.uint8)
735
+ diagnostics = {"warnings": ["Draw two digits to see diagnostics."]}
736
  return None, blank_probs, empty_preview, empty_diff, json.dumps(diagnostics, indent=2)
737
 
738
+ result = preprocess_composed_28x56(
739
+ composed,
740
  stroke_scale=stroke_scale,
741
+ extra_diag=seg_diag,
742
  )
743
  if result is None:
744
  blank_probs = {f"{i:02d}": 0.0 for i in range(OUTPUT_CLASSES)}
745
  empty_preview = np.zeros((TARGET_HEIGHT, TARGET_WIDTH), dtype=np.uint8)
746
  empty_diff = np.zeros((TARGET_HEIGHT, TARGET_WIDTH), dtype=np.uint8)
747
+ diagnostics = {"warnings": ["Draw two digits to see diagnostics."]}
748
  return None, blank_probs, empty_preview, empty_diff, json.dumps(diagnostics, indent=2)
749
 
750
  standardized_variants, preview, mean_diff, diagnostics = result
 
770
  gr.Markdown(
771
  """
772
  # Elliot's MNIST-100 Classifier
773
+ Draw a two-digit number (00-99) on the single canvas. The app automatically segments, centers, and scales each digit to match the training layout (28×28 per digit), then predicts and shows diagnostics.
774
  """
775
  )
776
 
777
  with gr.Row():
778
  with gr.Column(scale=1):
779
+ main_canvas = gr.Sketchpad(label="Draw Two Digits (00–99)")
 
 
780
  stroke_slider = gr.Slider(
781
  minimum=0.3,
782
  maximum=1.2,
 
799
  predict_btn = gr.Button("Predict", variant="primary")
800
  clear_btn = gr.ClearButton(
801
  [
802
+ main_canvas,
 
803
  stroke_slider,
804
  pred_box,
805
  prob_table,
 
811
 
812
  predict_btn.click(
813
  fn=predict_number,
814
+ inputs=[main_canvas, stroke_slider],
815
  outputs=[pred_box, prob_table, preview, mean_diff_view, diagnostics_box],
816
  )
817
  # On Spaces, avoid per-stroke inference to prevent event floods
818
  if not IS_SPACE:
819
+ main_canvas.change(
 
 
 
 
 
820
  fn=predict_number,
821
+ inputs=[main_canvas, stroke_slider],
822
  outputs=[pred_box, prob_table, preview, mean_diff_view, diagnostics_box],
823
  )
824