Eli181927 committed on
Commit
d9c92a8
·
verified ·
1 Parent(s): a6e7936

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +617 -0
app.py ADDED
@@ -0,0 +1,617 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import gradio as gr
4
+ import gradio.routes as gr_routes
5
+ from PIL import Image, ImageOps
6
+ from pathlib import Path
7
+ import importlib.util
8
+ import json
9
+
10
+
11
# Number of classes the app reports probabilities for (labels "00".."99").
OUTPUT_CLASSES = 100
# Size, in pixels, that every drawing is resized to before inference (rows, columns).
TARGET_HEIGHT, TARGET_WIDTH = 28, 56
# Lower bound applied to the per-pixel std before it is used as a divisor.
STD_FLOOR = 1e-8
# Inclusive (min, max) range each tracked diagnostic metric is checked against.
METRIC_TARGETS = {
    "mass_fraction": (0.08, 0.35),
    "stroke_density": (0.12, 0.65),
    "center_offset": (0.0, 8.0),
    "mean_abs_z_score": (0.0, 2.5),
    "max_abs_z_score": (0.0, 6.0),
    "std_abs_z_score": (0.0, 1.5),
}
22
+
23
+
24
def _load_training_module():
    """Load the sibling ``training-100.py`` script and return it as a module.

    The file name contains a hyphen, so it cannot be pulled in with a plain
    ``import`` statement; it is loaded from its path instead and registered
    under the module name ``mnist100_training``.

    Returns:
        The executed module object.

    Raises:
        ImportError: if no import spec or loader can be created for the path.
    """
    module_path = Path(__file__).resolve().parent / "training-100.py"
    spec = importlib.util.spec_from_file_location("mnist100_training", module_path)
    # Validate before use: ``assert`` is stripped under ``python -O`` and the
    # spec must be checked before ``module_from_spec`` dereferences it.
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load training module from '{module_path}'.")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module
31
+
32
+
33
# Load the training script once at import time and reuse its inference helpers
# under short module-level names.
training_mod = _load_training_module()
forward_prop = training_mod.forward_prop
get_predictions = training_mod.get_predictions
softmax = training_mod.softmax
37
+
38
+
39
def _metric_status(name, value):
    """Classify ``value`` against the target range configured for ``name``.

    Returns:
        A ``(status, target_dict)`` pair. ``status`` is ``"not_tracked"`` when
        ``name`` has no entry in ``METRIC_TARGETS``, ``"invalid"`` when the
        value is ``None`` or NaN, ``"ok"`` when it lies inside the inclusive
        range, and ``"out_of_range"`` otherwise. ``target_dict`` is
        ``{"min": low, "max": high}`` for tracked metrics and ``None`` for
        untracked ones.
    """
    bounds = METRIC_TARGETS.get(name)
    if bounds is None:
        return "not_tracked", None

    lower, upper = bounds
    limits = {"min": lower, "max": upper}
    if value is None or np.isnan(value):
        return "invalid", limits
    if lower <= value <= upper:
        return "ok", limits
    return "out_of_range", limits
53
+
54
+
55
def load_trained_artifacts(model_path=None):
    """Read the trained model archive and split it into parameters and stats.

    Args:
        model_path: Optional path to the ``.npz`` archive. A relative path is
            resolved against this file's directory. When ``None``, the default
            ``archive/trained_model_mnist100.npz`` next to this file is used.

    Returns:
        ``(params, mean, std)`` where ``params`` maps every archive key other
        than ``"mean"`` and ``"std"`` to its array.

    Raises:
        RuntimeError: if the archive file does not exist.
        KeyError: if the archive has no ``"mean"`` or ``"std"`` entry.
    """
    base_dir = Path(__file__).resolve().parent
    if model_path is None:
        resolved_path = base_dir / "archive" / "trained_model_mnist100.npz"
    else:
        candidate = Path(model_path)
        resolved_path = candidate if candidate.is_absolute() else base_dir / candidate
    if not resolved_path.exists():
        raise RuntimeError(
            f"Model file '{resolved_path}' not found. Train the MNIST-100 model first by running 'python training-100.py'."
        )
    # The archive object keeps a file handle open; read every array while it
    # is open and let the ``with`` block close it afterwards.
    with np.load(resolved_path) as loaded:
        arrays = {key: loaded[key] for key in loaded.files}
    mean = arrays.pop("mean")
    std = arrays.pop("std")
    return arrays, mean, std
71
+
72
+
73
# Module-level cache for the trained parameters and normalisation statistics;
# filled in on first use by ensure_model_loaded().
params, mean, std = None, None, None


def ensure_model_loaded():
    """Populate the global ``params``/``mean``/``std`` if any is still unset."""
    global params, mean, std
    if any(item is None for item in (params, mean, std)):
        params, mean, std = load_trained_artifacts()
80
+
81
+
82
def extract_canvas_array(img_input):
    """Normalise a canvas payload to a PIL image.

    Accepts ``None``, a dict wrapping the payload, a PIL image, or a NumPy
    array. Despite the function name, the result is a ``PIL.Image.Image``
    (or ``None`` when nothing usable is found).
    """
    if img_input is None:
        return None

    if isinstance(img_input, dict):
        # Use the first of the known keys that holds a non-None payload.
        candidates = (
            img_input.get(key)
            for key in ("image", "composite", "background", "value")
        )
        img_input = next((item for item in candidates if item is not None), None)
        if img_input is None:
            return None

    if isinstance(img_input, Image.Image):
        return img_input

    if not isinstance(img_input, np.ndarray):
        return None

    pixels = img_input
    if pixels.dtype != np.uint8:
        # Non-uint8 data whose peak is at most 1.5 is treated as being on a
        # 0..1 scale and stretched to 0..255; everything else is only clipped.
        peak = float(pixels.max()) if pixels.size else 1.0
        if peak <= 1.5:
            pixels = pixels * 255.0
        pixels = np.clip(pixels, 0, 255).astype(np.uint8)

    has_alpha = pixels.ndim == 3 and pixels.shape[2] == 4
    if has_alpha:
        return Image.fromarray(pixels, mode="RGBA")
    return Image.fromarray(pixels)
111
+
112
+
113
def shift_with_zero_pad(arr, shift_y=0, shift_x=0):
    """Translate a 2-D array by whole pixels, filling vacated cells with zero.

    Positive ``shift_y`` moves content down and positive ``shift_x`` moves it
    right. When both shifts are zero the input array itself is returned
    (not a copy).
    """
    if not (shift_y or shift_x):
        return arr

    # np.roll wraps content around the edges and returns a new array; the
    # wrapped-in rows/columns are blanked below.
    shifted = np.roll(np.roll(arr, shift=shift_y, axis=0), shift=shift_x, axis=1)

    if shift_y:
        wrapped_rows = slice(None, shift_y) if shift_y > 0 else slice(shift_y, None)
        shifted[wrapped_rows, :] = 0.0
    if shift_x:
        wrapped_cols = slice(None, shift_x) if shift_x > 0 else slice(shift_x, None)
        shifted[:, wrapped_cols] = 0.0
    return shifted
128
+
129
+
130
def dilate_binary_like(arr, radius=1):
    """Grey-scale dilation: each pixel becomes the max of its square window.

    Args:
        arr: 2-D array to dilate.
        radius: Half-width of the window; the window is
            ``(2 * radius + 1)`` pixels on each side.

    Returns:
        A new array with the same shape and dtype as ``arr``. Pixels outside
        the array are treated as 0.
    """
    rows, cols = arr.shape
    side = 2 * radius + 1
    padded = np.pad(arr, radius, mode="constant", constant_values=0.0)
    # One shifted view per window offset; reducing across them with an
    # element-wise maximum replaces the per-pixel Python loop.
    views = [
        padded[dy : dy + rows, dx : dx + cols]
        for dy in range(side)
        for dx in range(side)
    ]
    return np.maximum.reduce(views)
139
+
140
+
141
def erode_binary_like(arr, radius=1):
    """Grey-scale erosion: each pixel becomes the min of its square window.

    Args:
        arr: 2-D array to erode.
        radius: Half-width of the window; the window is
            ``(2 * radius + 1)`` pixels on each side.

    Returns:
        A new array with the same shape and dtype as ``arr``. Pixels outside
        the array are treated as 1, so the image border does not erode
        content by itself.
    """
    rows, cols = arr.shape
    side = 2 * radius + 1
    padded = np.pad(arr, radius, mode="constant", constant_values=1.0)
    # One shifted view per window offset; reducing across them with an
    # element-wise minimum replaces the per-pixel Python loop.
    views = [
        padded[dy : dy + rows, dx : dx + cols]
        for dy in range(side)
        for dx in range(side)
    ]
    return np.minimum.reduce(views)
150
+
151
+
152
def generate_inference_variants(arr):
    """Return perturbed copies of ``arr`` used alongside it at inference time.

    The list holds the eight one-pixel shifts (every combination of -1/0/+1
    in y and x except the identity), followed by one dilated and one eroded
    version, in that order.
    """
    offsets = [
        (dy, dx)
        for dy in (-1, 0, 1)
        for dx in (-1, 0, 1)
        if (dy, dx) != (0, 0)
    ]
    shifted = [shift_with_zero_pad(arr, dy, dx) for dy, dx in offsets]
    # Dilation/erosion cover strokes that are thinner or thicker than usual.
    morphed = [
        dilate_binary_like(arr, radius=1),
        erode_binary_like(arr, radius=1),
    ]
    return shifted + morphed
164
+
165
+
166
def _auto_balance_stroke(arr: np.ndarray, *, target_mass_fraction: float, clamp: tuple[float, float]):
    """Rescale pixel intensities so the image's mass fraction nears a target.

    Mass fraction is the sum of ``arr`` divided by
    ``TARGET_HEIGHT * TARGET_WIDTH``.

    Args:
        arr: Image intensities (clipped to 0..1 after scaling).
        target_mass_fraction: Desired mass fraction.
        clamp: ``(min_scale, max_scale)`` limits for the scale factor.

    Returns:
        ``(adjusted, scale, mass_fraction)``. For a (near-)empty image
        (mass fraction <= 1e-6) the input is returned unchanged with
        scale 1.0.
    """
    pixel_count = TARGET_HEIGHT * TARGET_WIDTH
    current_fraction = float(arr.sum() / pixel_count)
    if current_fraction <= 1e-6:
        return arr, 1.0, current_fraction

    lower, upper = clamp
    # Square root of the mass ratio, limited to the allowed range.
    gain = float(np.clip(np.sqrt(target_mass_fraction / current_fraction), lower, upper))
    rescaled = np.clip(arr * gain, 0.0, 1.0)
    return rescaled, gain, float(rescaled.sum() / pixel_count)
176
+
177
+
178
def compose_dual_canvas(left_input, right_input):
    """Join the two digit canvases side by side into one greyscale image.

    Args:
        left_input: Payload of the left (tens) canvas.
        right_input: Payload of the right (ones) canvas.

    Returns:
        A mode-"L" PIL image with the left canvas on the left and the right
        canvas on the right, or ``None`` when neither canvas yields an image.
    """
    left_img = extract_canvas_array(left_input)
    right_img = extract_canvas_array(right_input)

    if left_img is None and right_img is None:
        return None

    # At most one side is missing here; stand in a blank white canvas of the
    # other side's size for it.
    if left_img is None:
        left_img = Image.new("L", right_img.size, color=255)
    elif right_img is None:
        right_img = Image.new("L", left_img.size, color=255)

    left_img = left_img.convert("L")
    right_img = right_img.convert("L")

    if left_img.height != right_img.height:
        # Bring both halves to the smaller height; widths are kept as they are.
        target_height = min(left_img.height, right_img.height)
        left_img = left_img.resize(
            (left_img.width, target_height), Image.Resampling.LANCZOS
        )
        right_img = right_img.resize(
            (right_img.width, target_height), Image.Resampling.LANCZOS
        )

    combined = Image.new(
        "L",
        (left_img.width + right_img.width, left_img.height),
        color=255,
    )
    combined.paste(left_img, (0, 0))
    combined.paste(right_img, (left_img.width, 0))
    return combined
214
+
215
+
216
def preprocess_image(img_input, stroke_scale: float = 1.0, *, auto_balance: bool = True):
    """Turn a drawn canvas into standardized model inputs plus diagnostics.

    Steps: flatten transparency onto white, convert to greyscale and invert,
    crop to the drawn region (with padding), pad to the target aspect ratio,
    resize to ``TARGET_HEIGHT x TARGET_WIDTH``, scale intensities, build the
    inference variants and standardize each with the stored mean/std.

    Args:
        img_input: Canvas payload accepted by ``extract_canvas_array``.
        stroke_scale: Intensity multiplier; limited to the range 0.3..1.5.
        auto_balance: When true, intensities are additionally rescaled
            towards the midpoint of the ``mass_fraction`` target range.

    Returns:
        ``None`` when no image can be extracted or the cropped image is
        empty; otherwise a tuple ``(augmented_standardized, arr_resized,
        mean_diff_uint8, diagnostics)`` where the first item is a list of
        column vectors (the unmodified image first, then its variants).
    """
    ensure_model_loaded()
    img = extract_canvas_array(img_input)
    if img is None:
        return None

    # Composite any alpha channel onto a white background.
    if "A" in img.getbands():
        rgba = img.convert("RGBA")
        white_bg = Image.new("RGBA", rgba.size, (255, 255, 255, 255))
        img = Image.alpha_composite(white_bg, rgba).convert("RGB")

    # Invert the greyscale image before thresholding and cropping.
    img = ImageOps.invert(img.convert("L"))

    arr_u8 = np.array(img, dtype=np.uint8)
    original_canvas_shape = arr_u8.shape

    # Crop to the pixels above a small noise threshold, plus a 4-pixel margin.
    coords = np.column_stack(np.where(arr_u8 > 10))
    bbox = None
    if coords.size > 0:
        y_min, x_min = coords.min(axis=0)
        y_max, x_max = coords.max(axis=0) + 1
        pad = 4
        y_min = max(0, y_min - pad)
        x_min = max(0, x_min - pad)
        y_max = min(arr_u8.shape[0], y_max + pad)
        x_max = min(arr_u8.shape[1], x_max + pad)
        bbox = (int(y_min), int(y_max), int(x_min), int(x_max))
        arr_u8 = arr_u8[y_min:y_max, x_min:x_max]

    # A 2-D array is empty exactly when one of its dimensions is zero, so
    # this single check also guards the division by ``h`` below.
    if arr_u8.size == 0:
        return None

    h, w = arr_u8.shape
    target_ratio = TARGET_WIDTH / TARGET_HEIGHT
    current_ratio = w / h

    # Pad symmetrically (extra pixel goes bottom/right) to the target ratio.
    if current_ratio > target_ratio:
        new_height = int(round(w / target_ratio))
        pad_total = max(new_height - h, 0)
        pad_top = pad_total // 2
        pad_bottom = pad_total - pad_top
        pad_left = pad_right = 0
    else:
        new_width = int(round(h * target_ratio))
        pad_total = max(new_width - w, 0)
        pad_left = pad_total // 2
        pad_right = pad_total - pad_left
        pad_top = pad_bottom = 0

    arr_padded = np.pad(
        arr_u8,
        ((pad_top, pad_bottom), (pad_left, pad_right)),
        mode="constant",
        constant_values=0,
    )

    resized = Image.fromarray(arr_padded).resize(
        (TARGET_WIDTH, TARGET_HEIGHT), Image.Resampling.LANCZOS
    )
    arr_resized = np.array(resized, dtype=np.float32) / 255.0

    mean_image = mean.reshape(TARGET_HEIGHT, TARGET_WIDTH)
    # Floor the std so the standardization below never divides by ~0.
    std_safe = np.maximum(std, STD_FLOOR)

    stroke_scale = float(stroke_scale)
    stroke_scale = max(0.3, min(stroke_scale, 1.5))
    arr_resized = np.clip(arr_resized * stroke_scale, 0.0, 1.0)

    auto_balance_scale = 1.0
    balanced_mass_fraction = float(arr_resized.sum() / (TARGET_HEIGHT * TARGET_WIDTH))
    if auto_balance:
        # Aim for the middle of the accepted mass-fraction range.
        target_mass = sum(METRIC_TARGETS["mass_fraction"]) / 2.0
        arr_resized, auto_balance_scale, balanced_mass_fraction = _auto_balance_stroke(
            arr_resized,
            target_mass_fraction=target_mass,
            clamp=(0.7, 1.4),
        )

    augmented_arrays = [arr_resized, *generate_inference_variants(arr_resized)]
    augmented_standardized = [
        (arr.reshape(TARGET_HEIGHT * TARGET_WIDTH, 1) - mean) / std_safe
        for arr in augmented_arrays
    ]

    # Absolute difference from the training mean, stretched to 0..255 for display.
    mean_diff = np.abs(arr_resized - mean_image)
    mean_diff_uint8 = (mean_diff / (mean_diff.max() + 1e-8) * 255.0).astype(np.uint8)

    diagnostics = compute_diagnostics(
        arr_resized,
        bbox,
        original_canvas_shape,
        mean_image,
        augmented_standardized[0],
        std_safe,
    )
    diagnostics["applied_auto_balance"] = {
        "enabled": bool(auto_balance),
        "scale": float(auto_balance_scale),
        "mass_fraction_after": float(balanced_mass_fraction),
    }

    return augmented_standardized, arr_resized, mean_diff_uint8, diagnostics
325
+
326
+
327
def compute_diagnostics(arr_float, bbox, original_shape, mean_image, standardized, std_safe):
    """Summarise the processed image as a dictionary of diagnostic statistics.

    Args:
        arr_float: Processed image intensities.
        bbox: Fallback bounding box ``(top, bottom, left, right)`` used only
            when no pixel of ``arr_float`` exceeds 0.05; may be ``None``.
        original_shape: Shape of the source canvas, stored as-is.
        mean_image: Training mean, same shape as ``arr_float``.
        standardized: Standardized version of ``arr_float``.
        std_safe: Floored per-pixel standard deviation.

    Returns:
        A dict of statistics including a ``"metric_checks"`` entry that
        classifies each tracked metric via ``_metric_status``.
    """
    pixel_count = TARGET_HEIGHT * TARGET_WIDTH
    total_intensity = float(arr_float.sum())

    # Bounding box of every pixel brighter than 5%.
    ink_mask = arr_float > 0.05
    bbox_est = None
    if ink_mask.any():
        ink_rows, ink_cols = np.where(ink_mask)
        bbox_est = (
            int(ink_rows.min()),
            int(ink_rows.max()) + 1,
            int(ink_cols.min()),
            int(ink_cols.max()) + 1,
        )

    # Intensity-weighted centre of mass (row, col); None for an empty image.
    cy = cx = None
    if total_intensity > 1e-6:
        grid_y, grid_x = np.indices(arr_float.shape)
        weighted_sum = arr_float.sum()
        cy = float((grid_y * arr_float).sum() / weighted_sum)
        cx = float((grid_x * arr_float).sum() / weighted_sum)

    bbox_use = bbox_est or bbox
    if bbox_use:
        top, bottom, left, right = bbox_use
        height = bottom - top
        width = right - left
        bbox_area = height * width
        aspect_ratio = float(width / height) if height else None
        area_ratio = float(bbox_area / pixel_count) if bbox_area else 0.0
    else:
        top = bottom = left = right = None
        height = width = bbox_area = 0
        aspect_ratio = None
        area_ratio = 0.0
    bbox_metrics = {
        "top": top,
        "bottom": bottom,
        "left": left,
        "right": right,
        "height": height,
        "width": width,
        "aspect_ratio": aspect_ratio,
        "area": bbox_area,
        "area_ratio": area_ratio,
    }

    # Total intensity per bounding-box pixel.
    density = float(total_intensity / bbox_area) if bbox_area else 0.0

    # Euclidean distance between the centre of mass and the image centre.
    center_offset = None
    if cy is not None and cx is not None:
        ideal_cy = (TARGET_HEIGHT - 1) / 2.0
        ideal_cx = (TARGET_WIDTH - 1) / 2.0
        center_offset = float(np.sqrt((cy - ideal_cy) ** 2 + (cx - ideal_cx) ** 2))

    z_flat = standardized.flatten()
    mean_flat = mean_image.flatten()
    arr_flat = arr_float.flatten()
    std_flat = std_safe.flatten()

    norm_input = np.linalg.norm(arr_flat)
    norm_mean = np.linalg.norm(mean_flat)
    cosine_similarity = None
    if norm_input > 0.0 and norm_mean > 0.0:
        cosine_similarity = float(np.dot(arr_flat, mean_flat) / (norm_input * norm_mean))

    abs_z = np.abs(z_flat)
    mean_abs_z = float(np.mean(abs_z))
    max_abs_z = float(np.max(abs_z))
    # Note: this is the std of the signed z-scores, despite the key name
    # "std_abs_z_score" it is stored under.
    std_of_z = float(np.std(z_flat))

    # Pixels whose std sits at the floor, and how many of those differ
    # from the mean by more than 1e-3 in this input.
    low_var_mask = std_flat <= STD_FLOOR + 1e-12
    deviates_from_mean = np.abs(arr_flat - mean_flat) > 1e-3
    activated_low_var = int(np.count_nonzero(low_var_mask & deviates_from_mean))

    stats = {
        "total_intensity": total_intensity,
        "mass_fraction": float(total_intensity / pixel_count),
        "center_of_mass": {"row": cy, "col": cx},
        "center_offset": center_offset,
        "bbox": bbox_metrics,
        "original_canvas_shape": original_shape,
        "stroke_density": density,
        "warnings": [],
        "mean_intensity": float(arr_float.mean()),
        "pixel_intensity_range": {
            "min": float(arr_float.min()),
            "max": float(arr_float.max()),
        },
        "cosine_similarity_vs_mean": cosine_similarity,
        "mean_abs_z_score": mean_abs_z,
        "max_abs_z_score": max_abs_z,
        "std_abs_z_score": std_of_z,
        "low_variance_pixels_triggered": activated_low_var,
        "low_variance_threshold": STD_FLOOR,
        "low_variance_pixels_fraction": float(activated_low_var / max(1, int(low_var_mask.sum()))),
    }

    # mean_image has already been dereferenced above, so it cannot be None here.
    stats["distance_from_mean"] = float(np.linalg.norm(arr_float - mean_image))

    metric_checks = {}
    for metric_name in (
        "mass_fraction",
        "stroke_density",
        "center_offset",
        "mean_abs_z_score",
        "max_abs_z_score",
        "std_abs_z_score",
    ):
        value = stats.get(metric_name)
        if value is not None:
            value = float(value)
        status, target_dict = _metric_status(metric_name, value)
        entry = {"value": value, "status": status}
        if target_dict is not None:
            entry["target"] = target_dict
        metric_checks[metric_name] = entry
    stats["metric_checks"] = metric_checks

    return stats
449
+
450
+
451
def enrich_diagnostics(stats, probs):
    """Return a copy of ``stats`` extended with warnings and confidence data.

    Args:
        stats: Diagnostics dict; its ``"metric_checks"`` and ``"bbox"``
            entries are read if present. The input dict is not modified.
        probs: Array of class probabilities (any shape; it is flattened).

    Returns:
        A shallow copy of ``stats`` with ``"warnings"`` replaced by the list
        built here, plus ``"top_confidence"``, ``"second_confidence"`` and
        ``"prob_margin"`` entries.
    """
    notes = []

    # One warning per tracked metric that fell outside its target range.
    for name, info in stats.get("metric_checks", {}).items():
        if info.get("status") != "out_of_range":
            continue
        target = info.get("target")
        value = info.get("value")
        value_str = "None" if value is None else f"{value:.4f}"
        message = f"{name}: value={value_str}"
        if target is not None:
            message += f", target=[{target['min']:.4f},{target['max']:.4f}]"
        notes.append(message)

    aspect_ratio = stats.get("bbox", {}).get("aspect_ratio")
    if aspect_ratio is not None and (aspect_ratio < 1.0 or aspect_ratio > 3.5):
        notes.append(f"aspect_ratio: value={aspect_ratio:.4f}, expected≈[1.00,3.50]")

    # Probabilities in descending order; the margin is best minus runner-up.
    ranked = np.sort(probs.flatten())[::-1]
    if ranked.size >= 2:
        margin = ranked[0] - ranked[1]
        margin_entry = {
            "value": float(margin),
            "status": "ok" if margin >= 0.05 else "low_margin",
            "target": {"min": 0.05, "max": 1.0},
        }
        if margin < 0.05:
            notes.append(f"prob_margin: value={margin:.4f}, target≥0.0500")
    else:
        margin_entry = {"value": None, "status": "insufficient_classes"}

    enriched = dict(stats)
    enriched["warnings"] = notes
    enriched["top_confidence"] = float(ranked[0]) if ranked.size else None
    enriched["second_confidence"] = float(ranked[1]) if ranked.size > 1 else None
    enriched["prob_margin"] = margin_entry
    return enriched
493
+
494
+
495
def _blank_prediction(message):
    """Build the placeholder outputs returned when there is nothing to classify."""
    # Same row layout as the success path so the probability table always
    # receives [class, prob] rows.
    blank_rows = [[f"{i:02d}", 0.0] for i in range(OUTPUT_CLASSES)]
    empty_preview = np.zeros((TARGET_HEIGHT, TARGET_WIDTH), dtype=np.uint8)
    empty_diff = np.zeros((TARGET_HEIGHT, TARGET_WIDTH), dtype=np.uint8)
    diagnostics = {"warnings": [message]}
    return None, blank_rows, empty_preview, empty_diff, json.dumps(diagnostics, indent=2)


def predict_number(left_canvas, right_canvas, stroke_scale, auto_balance):
    """Classify the two-digit number drawn on the left and right canvases.

    Args:
        left_canvas: Payload of the tens-digit canvas.
        right_canvas: Payload of the ones-digit canvas.
        stroke_scale: Intensity multiplier forwarded to ``preprocess_image``.
        auto_balance: Whether to auto-balance stroke intensity.

    Returns:
        A 5-tuple ``(prediction, prob_rows, preview, mean_diff,
        diagnostics_json)``. ``prediction`` is ``None`` and the other items
        are blank placeholders when there is nothing to classify.
        ``prob_rows`` is a list of ``[class_label, probability]`` rows,
        sorted by descending probability on the success path.
    """
    ensure_model_loaded()
    combined_canvas = compose_dual_canvas(left_canvas, right_canvas)
    if combined_canvas is None:
        return _blank_prediction("Draw both digits to see diagnostics.")

    result = preprocess_image(
        combined_canvas,
        stroke_scale=stroke_scale,
        auto_balance=bool(auto_balance),
    )
    if result is None:
        return _blank_prediction("Draw a number to see diagnostics.")

    standardized_variants, preview, mean_diff, diagnostics = result

    # One column per variant; all variants go through the network together.
    variants_matrix = np.concatenate(standardized_variants, axis=1).astype(np.float32, copy=False)
    cache, probs_matrix = forward_prop(variants_matrix, params, training=False)
    # "Z_fc2" is presumably the final layer's pre-softmax output - confirm
    # against training-100.py. It is averaged across variants before softmax.
    logits_matrix = cache["Z_fc2"]
    avg_logits = np.mean(logits_matrix, axis=1, keepdims=True)
    probs = softmax(avg_logits)

    pred = int(get_predictions(probs)[0])

    prob_rows = [[f"{i:02d}", float(probs[i, 0])] for i in range(OUTPUT_CLASSES)]
    prob_rows.sort(key=lambda r: r[1], reverse=True)
    diagnostics = enrich_diagnostics(diagnostics, probs)
    diagnostics["variants_used"] = int(probs_matrix.shape[1])
    # Per-variant probability assigned to the finally predicted class.
    diagnostics["variant_top_confidences"] = [
        float(probs_matrix[pred, idx]) for idx in range(probs_matrix.shape[1])
    ]
    return pred, prob_rows, (preview * 255).astype(np.uint8), mean_diff, json.dumps(diagnostics, indent=2)
535
+
536
+
537
# UI definition. NOTE(review): the original indentation was lost in this view,
# so the layout nesting below (which widgets sit in which row/column) and the
# leading whitespace inside the Markdown string are reconstructed - confirm
# against the real file.
with gr.Blocks() as demo:
    gr.Markdown(
        """
        # Elliot's MNIST-100 Classifier
        Draw a two-digit number (00-99). Use the left canvas for the tens digit and the right canvas for the ones digit. The model will predict the number, show the top class probabilities, and display diagnostics for the processed input.
        """
    )

    with gr.Row():
        # Left column: the two drawing canvases and the preprocessing controls.
        with gr.Column(scale=1):
            with gr.Row():
                left_canvas = gr.Sketchpad(label="Tens Digit")
                right_canvas = gr.Sketchpad(label="Ones Digit")
            stroke_slider = gr.Slider(
                minimum=0.3,
                maximum=1.2,
                value=1.0,
                step=0.05,
                label="Stroke Intensity (scale)",
            )
            auto_balance = gr.Checkbox(
                value=True,
                label="Auto Balance Stroke Thickness",
                info="Automatically rescales the digit to match training mass and brightness.",
            )

        # Right column: prediction outputs and action buttons.
        with gr.Column(scale=1):
            pred_box = gr.Number(label="Predicted Number", precision=0, value=None)
            prob_table = gr.Dataframe(
                label="Class Probabilities",
                headers=["class", "prob"],
                datatype=["str", "number"],
                interactive=False,
            )
            preview = gr.Image(label="Model Input Preview (28x56)", image_mode="L")
            mean_diff_view = gr.Image(label="Difference vs Training Mean", image_mode="L")
            diagnostics_box = gr.Code(label="Diagnostics (JSON)", language="json")
            predict_btn = gr.Button("Predict", variant="primary")
            clear_btn = gr.ClearButton(
                [
                    left_canvas,
                    right_canvas,
                    stroke_slider,
                    auto_balance,
                    pred_box,
                    prob_table,
                    preview,
                    mean_diff_view,
                    diagnostics_box,
                ]
            )

    # The button click and a change on either canvas all run the same
    # prediction with the same inputs and outputs.
    prediction_inputs = [left_canvas, right_canvas, stroke_slider, auto_balance]
    prediction_outputs = [pred_box, prob_table, preview, mean_diff_view, diagnostics_box]
    for register_event in (predict_btn.click, left_canvas.change, right_canvas.change):
        register_event(
            fn=predict_number,
            inputs=prediction_inputs,
            outputs=prediction_outputs,
        )
604
+
605
+
606
def _disable_gradio_api_schema(*_args, **_kwargs):
    """Work around Gradio schema bug on Python 3.13 by returning empty metadata."""
    return {}


# Install the override BEFORE launching: launch() blocks the main thread
# while the server runs, so an assignment placed after it would only execute
# once the server has shut down.
# NOTE(review): whether gradio.routes actually consults a module-level
# ``api_info`` depends on the installed Gradio version - confirm.
gr_routes.api_info = _disable_gradio_api_schema


if __name__ == "__main__":
    # SPACE_ID being set is used as the signal that the app runs on a hosted
    # Space, where the default launch settings are kept.
    space_env = os.getenv("SPACE_ID")
    if space_env:
        demo.launch(show_api=False)
    else:
        # Local run: listen on all interfaces and request a public share link.
        demo.launch(server_name="0.0.0.0", share=True, show_api=False)