Spaces:
Sleeping
Sleeping
| import os | |
| import numpy as np | |
| import gradio as gr | |
| import gradio.routes as gr_routes | |
| from PIL import Image, ImageOps | |
| from pathlib import Path | |
| import importlib.util | |
| import json | |
| OUTPUT_CLASSES = 100 | |
| TARGET_HEIGHT, TARGET_WIDTH = 28, 56 | |
| STD_FLOOR = 1e-8 | |
| METRIC_TARGETS = { | |
| "mass_fraction": (0.08, 0.35), | |
| "stroke_density": (0.12, 0.65), | |
| "center_offset": (0.0, 8.0), | |
| "mean_abs_z_score": (0.0, 2.5), | |
| "max_abs_z_score": (0.0, 6.0), | |
| "std_abs_z_score": (0.0, 1.5), | |
| } | |
| def _load_training_module(): | |
| module_path = Path(__file__).resolve().parent / "training-100.py" | |
| spec = importlib.util.spec_from_file_location("mnist100_training", module_path) | |
| module = importlib.util.module_from_spec(spec) | |
| assert spec.loader is not None | |
| spec.loader.exec_module(module) | |
| return module | |
| training_mod = _load_training_module() | |
| forward_prop = training_mod.forward_prop | |
| get_predictions = training_mod.get_predictions | |
| softmax = training_mod.softmax | |
| # Detect if running on Hugging Face Spaces | |
| IS_SPACE = bool(os.getenv("SPACE_ID")) | |
| def _metric_status(name, value): | |
| target = METRIC_TARGETS.get(name) | |
| status = "not_tracked" | |
| target_dict = None | |
| if target is not None: | |
| low, high = target | |
| target_dict = {"min": low, "max": high} | |
| if value is None or np.isnan(value): | |
| status = "invalid" | |
| elif low <= value <= high: | |
| status = "ok" | |
| else: | |
| status = "out_of_range" | |
| return status, target_dict | |
| def load_trained_artifacts(model_path=None): | |
| base_dir = Path(__file__).resolve().parent | |
| if model_path is None: | |
| resolved_path = base_dir / "archive" / "trained_model_mnist100.npz" | |
| else: | |
| candidate = Path(model_path) | |
| resolved_path = candidate if candidate.is_absolute() else base_dir / candidate | |
| if not resolved_path.exists(): | |
| raise RuntimeError( | |
| f"Model file '{resolved_path}' not found. Train the MNIST-100 model first by running 'python training-100.py'." | |
| ) | |
| loaded = np.load(resolved_path) | |
| params = {key: loaded[key] for key in loaded.files if key not in {"mean", "std"}} | |
| mean = loaded["mean"] | |
| std = loaded["std"] | |
| return params, mean, std | |
| params, mean, std = None, None, None | |
| def ensure_model_loaded(): | |
| global params, mean, std | |
| if params is None or mean is None or std is None: | |
| params, mean, std = load_trained_artifacts() | |
| def extract_canvas_array(img_input): | |
| if img_input is None: | |
| return None | |
| if isinstance(img_input, dict): | |
| for key in ("image", "composite", "background", "value"): | |
| payload = img_input.get(key) | |
| if payload is not None: | |
| img_input = payload | |
| break | |
| else: | |
| return None | |
| if isinstance(img_input, Image.Image): | |
| return img_input | |
| if isinstance(img_input, np.ndarray): | |
| arr_in = img_input | |
| if arr_in.dtype != np.uint8: | |
| max_val = float(arr_in.max()) if arr_in.size else 1.0 | |
| if max_val <= 1.5: | |
| arr_in = (arr_in * 255.0).clip(0, 255).astype(np.uint8) | |
| else: | |
| arr_in = np.clip(arr_in, 0, 255).astype(np.uint8) | |
| if arr_in.ndim == 3 and arr_in.shape[2] == 4: | |
| return Image.fromarray(arr_in, mode="RGBA") | |
| return Image.fromarray(arr_in) | |
| return None | |
| def shift_with_zero_pad(arr, shift_y=0, shift_x=0): | |
| if shift_y == 0 and shift_x == 0: | |
| return arr | |
| rolled = np.roll(arr, shift=shift_y, axis=0) | |
| rolled = np.roll(rolled, shift=shift_x, axis=1) | |
| out = rolled.copy() | |
| if shift_y > 0: | |
| out[:shift_y, :] = 0.0 | |
| elif shift_y < 0: | |
| out[shift_y:, :] = 0.0 | |
| if shift_x > 0: | |
| out[:, :shift_x] = 0.0 | |
| elif shift_x < 0: | |
| out[:, shift_x:] = 0.0 | |
| return out | |
| def dilate_binary_like(arr, radius=1): | |
| # Vectorized dilation via max over shifted windows (3x3 when radius=1) | |
| if radius != 1: | |
| # Fallback to radius=1 behavior for simplicity/perf | |
| radius = 1 | |
| shifts = [] | |
| for dy in (-1, 0, 1): | |
| for dx in (-1, 0, 1): | |
| shifts.append(shift_with_zero_pad(arr, dy, dx)) | |
| stacked = np.stack(shifts, axis=0) | |
| return np.max(stacked, axis=0) | |
| def erode_binary_like(arr, radius=1): | |
| # Vectorized erosion via min over shifted windows (3x3 when radius=1) | |
| if radius != 1: | |
| radius = 1 | |
| shifts = [] | |
| for dy in (-1, 0, 1): | |
| for dx in (-1, 0, 1): | |
| shifts.append(shift_with_zero_pad(arr, dy, dx)) | |
| stacked = np.stack(shifts, axis=0) | |
| return np.min(stacked, axis=0) | |
| def generate_inference_variants(arr, *, fast: bool = False): | |
| variants = [] | |
| if fast: | |
| # Space-optimized: cardinal shifts plus light morphology (6 variants) | |
| for dy, dx in ((-1, 0), (1, 0), (0, -1), (0, 1)): | |
| variants.append(shift_with_zero_pad(arr, dy, dx)) | |
| variants.append(dilate_binary_like(arr, radius=1)) | |
| variants.append(erode_binary_like(arr, radius=1)) | |
| return variants | |
| # Full set: 8 shifts + morphology | |
| for dy in (-1, 0, 1): | |
| for dx in (-1, 0, 1): | |
| if dy == 0 and dx == 0: | |
| continue | |
| variants.append(shift_with_zero_pad(arr, dy, dx)) | |
| variants.append(dilate_binary_like(arr, radius=1)) | |
| variants.append(erode_binary_like(arr, radius=1)) | |
| return variants | |
| def _auto_balance_stroke(arr: np.ndarray, *, target_mass_fraction: float, clamp: tuple[float, float]): | |
| mass_fraction = float(arr.sum() / (TARGET_HEIGHT * TARGET_WIDTH)) | |
| if mass_fraction <= 1e-6: | |
| return arr, 1.0, mass_fraction | |
| scale = np.sqrt(target_mass_fraction / mass_fraction) | |
| min_scale, max_scale = clamp | |
| scale = float(np.clip(scale, min_scale, max_scale)) | |
| adjusted = np.clip(arr * scale, 0.0, 1.0) | |
| new_mass_fraction = float(adjusted.sum() / (TARGET_HEIGHT * TARGET_WIDTH)) | |
| return adjusted, scale, new_mass_fraction | |
| def _valley_split(mask: np.ndarray) -> int | None: | |
| # Find a vertical seam (column) with minimal foreground to split two digits | |
| H, W = mask.shape | |
| if W < 8: | |
| return None | |
| col_sums = mask.sum(axis=0) | |
| start = max(1, int(W * 0.25)) | |
| end = min(W - 1, int(W * 0.75)) | |
| if end <= start: | |
| start, end = 1, W - 1 | |
| idx = int(np.argmin(col_sums[start:end])) + start | |
| left_mass = int(col_sums[:idx].sum()) | |
| right_mass = int(col_sums[idx:].sum()) | |
| if left_mass > 50 and right_mass > 50: | |
| return idx | |
| return None | |
| def _connected_components(mask: np.ndarray): | |
| H, W = mask.shape | |
| visited = np.zeros_like(mask, dtype=bool) | |
| comps = [] | |
| for y in range(H): | |
| row = mask[y] | |
| for x in range(W): | |
| if row[x] and not visited[y, x]: | |
| stack = [(y, x)] | |
| visited[y, x] = True | |
| ys, xs = [], [] | |
| while stack: | |
| cy, cx = stack.pop() | |
| ys.append(cy) | |
| xs.append(cx) | |
| # 4-connectivity | |
| if cy > 0 and mask[cy - 1, cx] and not visited[cy - 1, cx]: | |
| visited[cy - 1, cx] = True | |
| stack.append((cy - 1, cx)) | |
| if cy + 1 < H and mask[cy + 1, cx] and not visited[cy + 1, cx]: | |
| visited[cy + 1, cx] = True | |
| stack.append((cy + 1, cx)) | |
| if cx > 0 and mask[cy, cx - 1] and not visited[cy, cx - 1]: | |
| visited[cy, cx - 1] = True | |
| stack.append((cy, cx - 1)) | |
| if cx + 1 < W and mask[cy, cx + 1] and not visited[cy, cx + 1]: | |
| visited[cy, cx + 1] = True | |
| stack.append((cy, cx + 1)) | |
| y1, y2 = min(ys), max(ys) + 1 | |
| x1, x2 = min(xs), max(xs) + 1 | |
| comps.append({"bbox": (y1, y2, x1, x2), "size": len(ys)}) | |
| return comps | |
| def canonicalize_digit_28x28(arr: np.ndarray) -> np.ndarray: | |
| # Input arr: float32 in [0,1], arbitrary HxW; output: 28x28 centered tile | |
| if arr.size == 0: | |
| return np.zeros((TARGET_HEIGHT, TARGET_HEIGHT), dtype=np.float32) | |
| thr = arr > 0.05 | |
| if not thr.any(): | |
| return np.zeros((TARGET_HEIGHT, TARGET_HEIGHT), dtype=np.float32) | |
| ys, xs = np.where(thr) | |
| y1, y2 = ys.min(), ys.max() + 1 | |
| x1, x2 = xs.min(), xs.max() + 1 | |
| # small padding | |
| pad = 2 | |
| y1 = max(0, y1 - pad) | |
| x1 = max(0, x1 - pad) | |
| y2 = min(arr.shape[0], y2 + pad) | |
| x2 = min(arr.shape[1], x2 + pad) | |
| crop = arr[y1:y2, x1:x2] | |
| h, w = crop.shape | |
| if h == 0 or w == 0: | |
| return np.zeros((TARGET_HEIGHT, TARGET_HEIGHT), dtype=np.float32) | |
| # resize shorter side to 20 | |
| if h >= w: | |
| new_h = 20 | |
| new_w = max(1, int(round(w * (20.0 / h)))) | |
| else: | |
| new_w = 20 | |
| new_h = max(1, int(round(h * (20.0 / w)))) | |
| small = Image.fromarray((crop * 255.0).astype(np.uint8)).resize( | |
| (new_w, new_h), Image.Resampling.LANCZOS | |
| ) | |
| tile = Image.new("L", (TARGET_HEIGHT, TARGET_HEIGHT), color=0) | |
| # paste centered | |
| top = (TARGET_HEIGHT - new_h) // 2 | |
| left = (TARGET_HEIGHT - new_w) // 2 | |
| tile.paste(small, (left, top)) | |
| tile_arr = np.array(tile, dtype=np.float32) / 255.0 | |
| # center-of-mass shift to exact center | |
| mass = tile_arr | |
| tot = float(mass.sum()) | |
| if tot > 1e-6: | |
| gy, gx = np.indices(mass.shape) | |
| cy = float((gy * mass).sum() / tot) | |
| cx = float((gx * mass).sum() / tot) | |
| ideal = (TARGET_HEIGHT - 1) / 2.0 | |
| dy = int(np.clip(round(ideal - cy), -2, 2)) | |
| dx = int(np.clip(round(ideal - cx), -2, 2)) | |
| if dy != 0 or dx != 0: | |
| tile_arr = shift_with_zero_pad(tile_arr, dy, dx) | |
| return tile_arr.astype(np.float32, copy=False) | |
| def compose_from_single_canvas(img_input): | |
| img = extract_canvas_array(img_input) | |
| if img is None: | |
| return None, {"warnings": ["No image provided."]} | |
| try: | |
| bands = img.getbands() | |
| except Exception: | |
| bands = () | |
| if "A" in bands: | |
| rgba = img.convert("RGBA") | |
| white_bg = Image.new("RGBA", rgba.size, (255, 255, 255, 255)) | |
| img = Image.alpha_composite(white_bg, rgba).convert("RGB") | |
| gray = img.convert("L") | |
| inv = ImageOps.invert(gray) | |
| arr_u8 = np.array(inv, dtype=np.uint8) | |
| mask = arr_u8 > 10 | |
| if not mask.any(): | |
| return None, {"warnings": ["Empty drawing detected."]} | |
| # Global bbox trim for speed | |
| ys, xs = np.where(mask) | |
| y1, y2 = ys.min(), ys.max() + 1 | |
| x1, x2 = xs.min(), xs.max() + 1 | |
| pad = 4 | |
| y1 = max(0, y1 - pad) | |
| x1 = max(0, x1 - pad) | |
| y2 = min(arr_u8.shape[0], y2 + pad) | |
| x2 = min(arr_u8.shape[1], x2 + pad) | |
| arr_u8 = arr_u8[y1:y2, x1:x2] | |
| mask = mask[y1:y2, x1:x2] | |
| method = "valley" | |
| split = _valley_split(mask) | |
| left_arr = right_arr = None | |
| if split is not None: | |
| left_area = arr_u8[:, :split] | |
| right_area = arr_u8[:, split:] | |
| if (left_area > 10).any(): | |
| l_ys, l_xs = np.where(left_area > 10) | |
| ly1, ly2 = l_ys.min(), l_ys.max() + 1 | |
| lx1, lx2 = l_xs.min(), l_xs.max() + 1 | |
| left_arr = left_area[ly1:ly2, lx1:lx2] | |
| if (right_area > 10).any(): | |
| r_ys, r_xs = np.where(right_area > 10) | |
| ry1, ry2 = r_ys.min(), r_ys.max() + 1 | |
| rx1, rx2 = r_xs.min(), r_xs.max() + 1 | |
| right_arr = right_area[ry1:ry2, rx1:rx2] | |
| else: | |
| method = "components" | |
| comps = _connected_components(mask) | |
| if len(comps) >= 2: | |
| comps.sort(key=lambda c: c["size"], reverse=True) | |
| a, b = comps[0], comps[1] | |
| # sort left/right by x1 | |
| if a["bbox"][2] <= b["bbox"][2]: | |
| left_bbox, right_bbox = a["bbox"], b["bbox"] | |
| else: | |
| left_bbox, right_bbox = b["bbox"], a["bbox"] | |
| ly1, ly2, lx1, lx2 = left_bbox | |
| ry1, ry2, rx1, rx2 = right_bbox | |
| left_arr = arr_u8[ly1:ly2, lx1:lx2] | |
| right_arr = arr_u8[ry1:ry2, rx1:rx2] | |
| else: | |
| # Fallback: split the single bbox in half | |
| method = "fallback_center_split" | |
| W = arr_u8.shape[1] | |
| split = W // 2 | |
| left_arr = arr_u8[:, :split] | |
| right_arr = arr_u8[:, split:] | |
| # Convert to float and canonicalize per digit | |
| left_tile = canonicalize_digit_28x28((left_arr.astype(np.float32) / 255.0) if left_arr is not None else np.zeros((1, 1), dtype=np.float32)) | |
| right_tile = canonicalize_digit_28x28((right_arr.astype(np.float32) / 255.0) if right_arr is not None else np.zeros((1, 1), dtype=np.float32)) | |
| composed = np.concatenate([left_tile, right_tile], axis=1) | |
| diag = { | |
| "segmentation": { | |
| "method": method, | |
| "canvas_crop": {"top": int(y1), "bottom": int(y2), "left": int(x1), "right": int(x2)}, | |
| } | |
| } | |
| return composed.astype(np.float32, copy=False), diag | |
| def preprocess_composed_28x56(arr_28x56: np.ndarray, stroke_scale: float = 1.0, *, extra_diag: dict | None = None): | |
| ensure_model_loaded() | |
| if arr_28x56 is None: | |
| return None | |
| arr_resized = np.clip(arr_28x56.astype(np.float32), 0.0, 1.0) | |
| mean_image = mean.reshape(TARGET_HEIGHT, TARGET_WIDTH) | |
| std_safe = np.maximum(std, STD_FLOOR) | |
| stroke_scale = float(stroke_scale) | |
| stroke_scale = max(0.3, min(stroke_scale, 1.5)) | |
| arr_resized = np.clip(arr_resized * stroke_scale, 0.0, 1.0) | |
| auto_balance_scale = 1.0 | |
| pre_balance_mass_fraction = float(arr_resized.mean()) | |
| target_mass = float(mean.mean()) | |
| arr_resized, auto_balance_scale, balanced_mass_fraction = _auto_balance_stroke( | |
| arr_resized, | |
| target_mass_fraction=target_mass, | |
| clamp=(0.6, 1.6), | |
| ) | |
| # We already centered per 28x28 tile; skip whole-image recentering here | |
| arr_centered = arr_resized | |
| augmented_arrays = [arr_centered, *generate_inference_variants(arr_centered, fast=IS_SPACE)] | |
| augmented_standardized = [] | |
| for arr in augmented_arrays: | |
| z = (arr.reshape(TARGET_HEIGHT * TARGET_WIDTH, 1) - mean) / std_safe | |
| z = np.clip(z, -8.0, 8.0) | |
| augmented_standardized.append(z.astype(np.float32, copy=False)) | |
| mean_diff = np.abs(arr_centered - mean_image) | |
| mean_diff_uint8 = (mean_diff / (mean_diff.max() + 1e-8) * 255.0).astype(np.uint8) | |
| diagnostics = compute_diagnostics( | |
| arr_centered, | |
| None, | |
| arr_centered.shape, | |
| mean_image, | |
| augmented_standardized[0], | |
| std_safe, | |
| ) | |
| diagnostics["applied_auto_balance"] = { | |
| "enabled": True, | |
| "scale": float(auto_balance_scale), | |
| "mass_fraction_after": float(balanced_mass_fraction), | |
| "mass_fraction_before": float(pre_balance_mass_fraction), | |
| "target_mass_fraction": float(target_mass), | |
| } | |
| if extra_diag: | |
| diagnostics.update(extra_diag) | |
| return augmented_standardized, arr_centered, mean_diff_uint8, diagnostics | |
| def preprocess_image(img_input, stroke_scale: float = 1.0): | |
| ensure_model_loaded() | |
| img = extract_canvas_array(img_input) | |
| if img is None: | |
| return None | |
| try: | |
| bands = img.getbands() | |
| except Exception: | |
| bands = () | |
| if "A" in bands: | |
| rgba = img.convert("RGBA") | |
| white_bg = Image.new("RGBA", rgba.size, (255, 255, 255, 255)) | |
| img = Image.alpha_composite(white_bg, rgba).convert("RGB") | |
| img = img.convert("L") | |
| img = ImageOps.invert(img) | |
| arr_u8 = np.array(img, dtype=np.uint8) | |
| original_canvas_shape = arr_u8.shape | |
| coords = np.column_stack(np.where(arr_u8 > 10)) | |
| bbox = None | |
| if coords.size > 0: | |
| y_min, x_min = coords.min(axis=0) | |
| y_max, x_max = coords.max(axis=0) + 1 | |
| pad = 4 | |
| y_min = max(0, y_min - pad) | |
| x_min = max(0, x_min - pad) | |
| y_max = min(arr_u8.shape[0], y_max + pad) | |
| x_max = min(arr_u8.shape[1], x_max + pad) | |
| bbox = (int(y_min), int(y_max), int(x_min), int(x_max)) | |
| arr_u8 = arr_u8[y_min:y_max, x_min:x_max] | |
| if arr_u8.size == 0: | |
| return None | |
| h, w = arr_u8.shape | |
| target_ratio = TARGET_WIDTH / TARGET_HEIGHT | |
| if h == 0 or w == 0: | |
| return None | |
| current_ratio = w / h if h else target_ratio | |
| if current_ratio > target_ratio: | |
| new_height = int(round(w / target_ratio)) | |
| pad_total = max(new_height - h, 0) | |
| pad_top = pad_total // 2 | |
| pad_bottom = pad_total - pad_top | |
| pad_left = pad_right = 0 | |
| else: | |
| new_width = int(round(h * target_ratio)) | |
| pad_total = max(new_width - w, 0) | |
| pad_left = pad_total // 2 | |
| pad_right = pad_total - pad_left | |
| pad_top = pad_bottom = 0 | |
| arr_padded = np.pad( | |
| arr_u8, | |
| ((pad_top, pad_bottom), (pad_left, pad_right)), | |
| mode="constant", | |
| constant_values=0, | |
| ) | |
| resized = Image.fromarray(arr_padded).resize( | |
| (TARGET_WIDTH, TARGET_HEIGHT), Image.Resampling.LANCZOS | |
| ) | |
| arr_resized = np.array(resized, dtype=np.float32) / 255.0 | |
| mean_image = mean.reshape(TARGET_HEIGHT, TARGET_WIDTH) | |
| std_safe = np.maximum(std, STD_FLOOR) | |
| stroke_scale = float(stroke_scale) | |
| stroke_scale = max(0.3, min(stroke_scale, 1.5)) | |
| arr_resized = np.clip(arr_resized * stroke_scale, 0.0, 1.0) | |
| auto_balance_scale = 1.0 | |
| # Match the dataset's global mean intensity (more faithful than a fixed midpoint) | |
| pre_balance_mass_fraction = float(arr_resized.mean()) | |
| target_mass = float(mean.mean()) | |
| arr_resized, auto_balance_scale, balanced_mass_fraction = _auto_balance_stroke( | |
| arr_resized, | |
| target_mass_fraction=target_mass, | |
| clamp=(0.6, 1.6), | |
| ) | |
| # Light recentering by center-of-mass to reduce sensitivity to placement | |
| mass = arr_resized | |
| total_intensity = float(mass.sum()) | |
| arr_centered = arr_resized | |
| if total_intensity > 1e-6: | |
| gy, gx = np.indices(mass.shape) | |
| cy = float((gy * mass).sum() / total_intensity) | |
| cx = float((gx * mass).sum() / total_intensity) | |
| ideal_cy = (TARGET_HEIGHT - 1) / 2.0 | |
| ideal_cx = (TARGET_WIDTH - 1) / 2.0 | |
| dy = int(np.clip(round(ideal_cy - cy), -2, 2)) | |
| dx = int(np.clip(round(ideal_cx - cx), -2, 2)) | |
| if dy != 0 or dx != 0: | |
| arr_centered = shift_with_zero_pad(arr_resized, dy, dx) | |
| augmented_arrays = [arr_centered, *generate_inference_variants(arr_centered, fast=IS_SPACE)] | |
| # Standardize each variant and clip to tame outliers for stable inference | |
| augmented_standardized = [] | |
| for arr in augmented_arrays: | |
| z = (arr.reshape(TARGET_HEIGHT * TARGET_WIDTH, 1) - mean) / std_safe | |
| z = np.clip(z, -8.0, 8.0) | |
| augmented_standardized.append(z.astype(np.float32, copy=False)) | |
| mean_diff = np.abs(arr_centered - mean_image) | |
| mean_diff_uint8 = (mean_diff / (mean_diff.max() + 1e-8) * 255.0).astype(np.uint8) | |
| diagnostics = compute_diagnostics( | |
| arr_centered, | |
| bbox, | |
| original_canvas_shape, | |
| mean_image, | |
| augmented_standardized[0], | |
| std_safe, | |
| ) | |
| diagnostics["applied_auto_balance"] = { | |
| "enabled": True, | |
| "scale": float(auto_balance_scale), | |
| "mass_fraction_after": float(balanced_mass_fraction), | |
| "mass_fraction_before": float(pre_balance_mass_fraction), | |
| "target_mass_fraction": float(target_mass), | |
| } | |
| return augmented_standardized, arr_centered, mean_diff_uint8, diagnostics | |
| def compute_diagnostics(arr_float, bbox, original_shape, mean_image, standardized, std_safe): | |
| mass = arr_float | |
| total_intensity = float(mass.sum()) | |
| mass_threshold = mass > 0.05 | |
| if mass_threshold.any(): | |
| ys, xs = np.where(mass_threshold) | |
| bbox_est = (int(ys.min()), int(ys.max()) + 1, int(xs.min()), int(xs.max()) + 1) | |
| else: | |
| bbox_est = None | |
| cy = cx = None | |
| if total_intensity > 1e-6: | |
| grid_y, grid_x = np.indices(mass.shape) | |
| weighted_sum = mass.sum() | |
| cy = float((grid_y * mass).sum() / weighted_sum) | |
| cx = float((grid_x * mass).sum() / weighted_sum) | |
| bbox_use = bbox_est or bbox | |
| if bbox_use: | |
| top, bottom, left, right = bbox_use | |
| height = bottom - top | |
| width = right - left | |
| bbox_area = height * width | |
| bbox_metrics = { | |
| "top": top, | |
| "bottom": bottom, | |
| "left": left, | |
| "right": right, | |
| "height": height, | |
| "width": width, | |
| "aspect_ratio": float(width / height) if height else None, | |
| "area": bbox_area, | |
| "area_ratio": float(bbox_area / (TARGET_HEIGHT * TARGET_WIDTH)) if bbox_area else 0.0, | |
| } | |
| else: | |
| bbox_metrics = { | |
| "top": None, | |
| "bottom": None, | |
| "left": None, | |
| "right": None, | |
| "height": 0, | |
| "width": 0, | |
| "aspect_ratio": None, | |
| "area": 0, | |
| "area_ratio": 0.0, | |
| } | |
| density = 0.0 | |
| bbox_area = bbox_metrics.get("area", 0) | |
| if bbox_area: | |
| density = float(total_intensity / bbox_area) | |
| center_offset = None | |
| if cy is not None and cx is not None: | |
| ideal_cy = (TARGET_HEIGHT - 1) / 2.0 | |
| ideal_cx = (TARGET_WIDTH - 1) / 2.0 | |
| center_offset = float(np.sqrt((cy - ideal_cy) ** 2 + (cx - ideal_cx) ** 2)) | |
| standardized_flat = standardized.flatten() | |
| mean_flat = mean_image.flatten() | |
| arr_flat = arr_float.flatten() | |
| std_flat = std_safe.flatten() | |
| norm_input = np.linalg.norm(arr_flat) | |
| norm_mean = np.linalg.norm(mean_flat) | |
| cosine_similarity = None | |
| if norm_input > 0.0 and norm_mean > 0.0: | |
| cosine_similarity = float(np.dot(arr_flat, mean_flat) / (norm_input * norm_mean)) | |
| mean_abs_z = float(np.mean(np.abs(standardized_flat))) | |
| max_abs_z = float(np.max(np.abs(standardized_flat))) | |
| std_of_z = float(np.std(standardized_flat)) | |
| low_var_mask = std_flat <= STD_FLOOR + 1e-12 | |
| activated_low_var = int(np.count_nonzero(low_var_mask & (np.abs(arr_flat - mean_flat) > 1e-3))) | |
| stats = { | |
| "total_intensity": total_intensity, | |
| "mass_fraction": float(total_intensity / (TARGET_HEIGHT * TARGET_WIDTH)), | |
| "center_of_mass": {"row": cy, "col": cx}, | |
| "center_offset": center_offset, | |
| "bbox": bbox_metrics, | |
| "original_canvas_shape": original_shape, | |
| "stroke_density": density, | |
| "warnings": [], | |
| "mean_intensity": float(arr_float.mean()), | |
| "pixel_intensity_range": { | |
| "min": float(arr_float.min()), | |
| "max": float(arr_float.max()), | |
| }, | |
| "cosine_similarity_vs_mean": cosine_similarity, | |
| "mean_abs_z_score": mean_abs_z, | |
| "max_abs_z_score": max_abs_z, | |
| "std_abs_z_score": std_of_z, | |
| "low_variance_pixels_triggered": activated_low_var, | |
| "low_variance_threshold": STD_FLOOR, | |
| "low_variance_pixels_fraction": float(activated_low_var / max(1, int(low_var_mask.sum()))), | |
| } | |
| if mean_image is not None: | |
| stats["distance_from_mean"] = float(np.linalg.norm(arr_float - mean_image)) | |
| metric_checks = {} | |
| for metric_name in ( | |
| "mass_fraction", | |
| "stroke_density", | |
| "center_offset", | |
| "mean_abs_z_score", | |
| "max_abs_z_score", | |
| "std_abs_z_score", | |
| ): | |
| value = stats.get(metric_name) | |
| if value is not None: | |
| value = float(value) | |
| status, target_dict = _metric_status(metric_name, value) | |
| entry = {"value": value, "status": status} | |
| if target_dict is not None: | |
| entry["target"] = target_dict | |
| metric_checks[metric_name] = entry | |
| stats["metric_checks"] = metric_checks | |
| return stats | |
| def enrich_diagnostics(stats, probs): | |
| warnings = [] | |
| bbox = stats.get("bbox", {}) | |
| metric_checks = stats.get("metric_checks", {}) | |
| for name, info in metric_checks.items(): | |
| if info.get("status") == "out_of_range": | |
| target = info.get("target") | |
| value = info.get("value") | |
| value_str = "None" if value is None else f"{value:.4f}" | |
| if target is not None: | |
| warnings.append( | |
| f"{name}: value={value_str}, target=[{target['min']:.4f},{target['max']:.4f}]" | |
| ) | |
| else: | |
| warnings.append(f"{name}: value={value_str}") | |
| aspect_ratio = bbox.get("aspect_ratio") | |
| if aspect_ratio is not None and (aspect_ratio < 1.0 or aspect_ratio > 3.5): | |
| warnings.append(f"aspect_ratio: value={aspect_ratio:.4f}, expected≈[1.00,3.50]") | |
| confidences = np.sort(probs.flatten())[::-1] | |
| if confidences.size >= 2: | |
| margin = confidences[0] - confidences[1] | |
| stats_margin = { | |
| "value": float(margin), | |
| "status": "ok" if margin >= 0.05 else "low_margin", | |
| "target": {"min": 0.05, "max": 1.0}, | |
| } | |
| else: | |
| margin = None | |
| stats_margin = {"value": None, "status": "insufficient_classes"} | |
| if margin is not None and margin < 0.05: | |
| warnings.append(f"prob_margin: value={margin:.4f}, target≥0.0500") | |
| stats = dict(stats) | |
| stats["warnings"] = warnings | |
| stats["top_confidence"] = float(confidences[0]) if confidences.size else None | |
| stats["second_confidence"] = float(confidences[1]) if confidences.size > 1 else None | |
| stats["prob_margin"] = stats_margin | |
| return stats | |
| def predict_number(main_canvas): | |
| ensure_model_loaded() | |
| composed, seg_diag = compose_from_single_canvas(main_canvas) | |
| if composed is None: | |
| blank_probs = {f"{i:02d}": 0.0 for i in range(OUTPUT_CLASSES)} | |
| empty_preview = np.zeros((TARGET_HEIGHT, TARGET_WIDTH), dtype=np.uint8) | |
| empty_diff = np.zeros((TARGET_HEIGHT, TARGET_WIDTH), dtype=np.uint8) | |
| diagnostics = {"warnings": ["Draw two digits to see diagnostics."]} | |
| return None, blank_probs, empty_preview, empty_diff, json.dumps(diagnostics, indent=2) | |
| result = preprocess_composed_28x56( | |
| composed, | |
| extra_diag=seg_diag, | |
| ) | |
| if result is None: | |
| blank_probs = {f"{i:02d}": 0.0 for i in range(OUTPUT_CLASSES)} | |
| empty_preview = np.zeros((TARGET_HEIGHT, TARGET_WIDTH), dtype=np.uint8) | |
| empty_diff = np.zeros((TARGET_HEIGHT, TARGET_WIDTH), dtype=np.uint8) | |
| diagnostics = {"warnings": ["Draw two digits to see diagnostics."]} | |
| return None, blank_probs, empty_preview, empty_diff, json.dumps(diagnostics, indent=2) | |
| standardized_variants, preview, mean_diff, diagnostics = result | |
| variants_matrix = np.concatenate(standardized_variants, axis=1).astype(np.float32, copy=False) | |
| cache, probs_matrix = forward_prop(variants_matrix, params, training=False) | |
| # Average probabilities across variants to reduce domination by any single variant | |
| probs = np.mean(probs_matrix, axis=1, keepdims=True) | |
| pred = int(get_predictions(probs)[0]) | |
| prob_rows = [[f"{i:02d}", float(probs[i, 0])] for i in range(OUTPUT_CLASSES)] | |
| prob_rows.sort(key=lambda r: r[1], reverse=True) | |
| diagnostics = enrich_diagnostics(diagnostics, probs) | |
| diagnostics["variants_used"] = int(probs_matrix.shape[1]) | |
| diagnostics["variant_top_confidences"] = [ | |
| float(probs_matrix[pred, idx]) for idx in range(probs_matrix.shape[1]) | |
| ] | |
| return pred, prob_rows, (preview * 255).astype(np.uint8), mean_diff, json.dumps(diagnostics, indent=2) | |
| with gr.Blocks() as demo: | |
| gr.Markdown( | |
| """ | |
| # Elliot's MNIST-100 Classifier | |
| Draw a two-digit number (00-99) on the single canvas. The app uses the Convolutional Neural Network to predict the number accurately. | |
| """ | |
| ) | |
| with gr.Row(): | |
| with gr.Column(scale=1) as left_col: | |
| main_canvas = gr.Sketchpad(label="Draw Two Digits (00–99)") | |
| with gr.Column(scale=1): | |
| pred_box = gr.Number(label="Predicted Number", precision=0, value=None) | |
| prob_table = gr.Dataframe( | |
| label="Class Probabilities", | |
| headers=["class", "prob"], | |
| datatype=["str", "number"], | |
| interactive=False, | |
| ) | |
| preview = gr.Image(label="Model Input Preview (28x56)", image_mode="L") | |
| mean_diff_view = gr.Image(label="Difference vs Training Mean", image_mode="L") | |
| diagnostics_box = gr.Code(label="Diagnostics (JSON)", language="json") | |
| # Place buttons under the canvas, but wire them to clear outputs as well | |
| with left_col: | |
| with gr.Row(): | |
| predict_btn = gr.Button("Predict", variant="primary") | |
| clear_btn = gr.ClearButton( | |
| [ | |
| main_canvas, | |
| pred_box, | |
| prob_table, | |
| preview, | |
| mean_diff_view, | |
| diagnostics_box, | |
| ] | |
| ) | |
| predict_btn.click( | |
| fn=predict_number, | |
| inputs=[main_canvas], | |
| outputs=[pred_box, prob_table, preview, mean_diff_view, diagnostics_box], | |
| ) | |
| # On Spaces, avoid per-stroke inference to prevent event floods | |
| if not IS_SPACE: | |
| main_canvas.change( | |
| fn=predict_number, | |
| inputs=[main_canvas], | |
| outputs=[pred_box, prob_table, preview, mean_diff_view, diagnostics_box], | |
| ) | |
| if __name__ == "__main__": | |
| space_env = os.getenv("SPACE_ID") | |
| def _queue_app(blocks): | |
| try: | |
| return blocks.queue(concurrency_count=1) | |
| except TypeError: | |
| # Older Gradio versions don't support the argument | |
| try: | |
| return blocks.queue() | |
| except Exception: | |
| return blocks | |
| app_to_launch = _queue_app(demo) | |
| if space_env: | |
| app_to_launch.launch(show_api=False) | |
| else: | |
| app_to_launch.launch(server_name="0.0.0.0", share=True, show_api=False) | |
| def _disable_gradio_api_schema(*_args, **_kwargs): | |
| """Work around Gradio schema bug on Python 3.13 by returning empty metadata.""" | |
| return {} | |
| gr_routes.api_info = _disable_gradio_api_schema | |