# app.py — MNIST-100 two-digit classifier (Gradio front-end).
# (Removed Hugging Face Spaces page-scrape residue that was not valid Python.)
import os
import numpy as np
import gradio as gr
import gradio.routes as gr_routes
from PIL import Image, ImageOps
from pathlib import Path
import importlib.util
import json
OUTPUT_CLASSES = 100  # two-digit labels 00..99
TARGET_HEIGHT, TARGET_WIDTH = 28, 56  # model input: two 28x28 digit tiles side by side
STD_FLOOR = 1e-8  # lower bound on per-pixel std to avoid divide-by-zero in z-scoring
# Acceptable (min, max) ranges for input-quality diagnostics; consumed by _metric_status.
METRIC_TARGETS = {
    "mass_fraction": (0.08, 0.35),
    "stroke_density": (0.12, 0.65),
    "center_offset": (0.0, 8.0),
    "mean_abs_z_score": (0.0, 2.5),
    "max_abs_z_score": (0.0, 6.0),
    "std_abs_z_score": (0.0, 1.5),
}
def _load_training_module():
    """Import training-100.py via importlib (its hyphenated name rules out a plain import)."""
    script = Path(__file__).resolve().parent / "training-100.py"
    spec = importlib.util.spec_from_file_location("mnist100_training", script)
    assert spec.loader is not None
    loaded = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(loaded)
    return loaded
# Execute the training script once at import time and reuse its inference helpers.
training_mod = _load_training_module()
forward_prop = training_mod.forward_prop
get_predictions = training_mod.get_predictions
softmax = training_mod.softmax
# Detect if running on Hugging Face Spaces (SPACE_ID is set there); used to
# trim augmentation work and disable per-stroke inference.
IS_SPACE = bool(os.getenv("SPACE_ID"))
def _metric_status(name, value):
    """Classify *value* against METRIC_TARGETS[name].

    Returns (status, target_dict): status is "not_tracked", "invalid",
    "ok" or "out_of_range"; target_dict is {"min", "max"} or None when
    the metric has no configured range.
    """
    bounds = METRIC_TARGETS.get(name)
    if bounds is None:
        return "not_tracked", None
    lo, hi = bounds
    target_info = {"min": lo, "max": hi}
    if value is None or np.isnan(value):
        return "invalid", target_info
    status = "ok" if lo <= value <= hi else "out_of_range"
    return status, target_info
def load_trained_artifacts(model_path=None):
    """Load trained weights and normalization stats from an .npz archive.

    model_path: optional path, absolute or relative to this file; defaults to
    archive/trained_model_mnist100.npz. Returns (params, mean, std), where
    params maps array names to weights (mean/std are split out separately).
    Raises RuntimeError when the archive is missing.
    """
    base_dir = Path(__file__).resolve().parent
    if model_path is None:
        target = base_dir / "archive" / "trained_model_mnist100.npz"
    else:
        candidate = Path(model_path)
        target = candidate if candidate.is_absolute() else base_dir / candidate
    if not target.exists():
        raise RuntimeError(
            f"Model file '{target}' not found. Train the MNIST-100 model first by running 'python training-100.py'."
        )
    archive = np.load(target)
    skip = {"mean", "std"}
    weights = {name: archive[name] for name in archive.files if name not in skip}
    return weights, archive["mean"], archive["std"]
# Lazily-populated module state: weights dict plus per-pixel mean/std vectors.
params, mean, std = None, None, None
def ensure_model_loaded():
    """Load model artifacts on first use; subsequent calls are no-ops."""
    global params, mean, std
    if params is None or mean is None or std is None:
        params, mean, std = load_trained_artifacts()
def extract_canvas_array(img_input):
    """Normalize the assorted payload shapes Gradio canvases emit to a PIL Image.

    Accepts None, a dict of layers (newer Sketchpad), a PIL Image, or a
    numpy array (uint8 or float). Returns a PIL Image, or None when no
    usable image data is present.
    """
    if img_input is None:
        return None
    if isinstance(img_input, dict):
        # Newer Gradio hands back a layer dict; take the first populated layer.
        payload = next(
            (img_input[key] for key in ("image", "composite", "background", "value")
             if img_input.get(key) is not None),
            None,
        )
        if payload is None:
            return None
        img_input = payload
    if isinstance(img_input, Image.Image):
        return img_input
    if isinstance(img_input, np.ndarray):
        data = img_input
        if data.dtype != np.uint8:
            peak = float(data.max()) if data.size else 1.0
            # Heuristic: max <= 1.5 means a 0-1 float canvas; rescale to 0-255.
            scaled = data * 255.0 if peak <= 1.5 else data
            data = np.clip(scaled, 0, 255).astype(np.uint8)
        if data.ndim == 3 and data.shape[2] == 4:
            return Image.fromarray(data, mode="RGBA")
        return Image.fromarray(data)
    return None
def shift_with_zero_pad(arr, shift_y=0, shift_x=0):
    """Shift a 2-D array by (shift_y, shift_x) pixels, zero-filling vacated cells.

    Positive shift_y moves content down, positive shift_x moves it right.
    Returns the input object unchanged when both shifts are zero.
    """
    if shift_y == 0 and shift_x == 0:
        return arr
    # np.roll wraps around; blank the wrapped-in band afterwards.
    shifted = np.roll(np.roll(arr, shift_y, axis=0), shift_x, axis=1)
    if shift_y:
        rows = slice(None, shift_y) if shift_y > 0 else slice(shift_y, None)
        shifted[rows, :] = 0.0
    if shift_x:
        cols = slice(None, shift_x) if shift_x > 0 else slice(shift_x, None)
        shifted[:, cols] = 0.0
    return shifted
def dilate_binary_like(arr, radius=1):
    """Approximate binary dilation: per-pixel max over the 3x3 neighborhood.

    Only radius=1 is implemented; any other radius silently falls back to 1
    (kept for signature compatibility and speed).
    """
    neighborhood = [
        shift_with_zero_pad(arr, dy, dx)
        for dy in (-1, 0, 1)
        for dx in (-1, 0, 1)
    ]
    return np.max(np.stack(neighborhood, axis=0), axis=0)
def erode_binary_like(arr, radius=1):
    """Approximate binary erosion: per-pixel min over the 3x3 neighborhood.

    Only radius=1 is implemented; any other radius silently falls back to 1
    (kept for signature compatibility and speed).
    """
    neighborhood = [
        shift_with_zero_pad(arr, dy, dx)
        for dy in (-1, 0, 1)
        for dx in (-1, 0, 1)
    ]
    return np.min(np.stack(neighborhood, axis=0), axis=0)
def generate_inference_variants(arr, *, fast: bool = False):
    """Build shifted/morphed copies of *arr* for test-time augmentation.

    fast=True (Spaces): 4 cardinal one-pixel shifts + dilation + erosion (6 variants).
    fast=False: all 8 one-pixel shifts + dilation + erosion (10 variants).
    The identity image is NOT included; callers prepend it themselves.
    """
    if fast:
        offsets = ((-1, 0), (1, 0), (0, -1), (0, 1))
    else:
        offsets = tuple(
            (dy, dx)
            for dy in (-1, 0, 1)
            for dx in (-1, 0, 1)
            if (dy, dx) != (0, 0)
        )
    variants = [shift_with_zero_pad(arr, dy, dx) for dy, dx in offsets]
    variants.append(dilate_binary_like(arr, radius=1))
    variants.append(erode_binary_like(arr, radius=1))
    return variants
def _auto_balance_stroke(arr: np.ndarray, *, target_mass_fraction: float, clamp: tuple[float, float]):
    """Rescale stroke intensity toward a target mean-ink fraction.

    Returns (adjusted_array, applied_scale, resulting_mass_fraction). The
    scale is sqrt(target/current), clamped to *clamp*; a nearly blank image
    is returned untouched with scale 1.0.
    """
    pixel_count = TARGET_HEIGHT * TARGET_WIDTH
    current_fraction = float(arr.sum() / pixel_count)
    if current_fraction <= 1e-6:
        # Nothing drawn — leave the image as-is.
        return arr, 1.0, current_fraction
    lo, hi = clamp
    factor = float(np.clip(np.sqrt(target_mass_fraction / current_fraction), lo, hi))
    balanced = np.clip(arr * factor, 0.0, 1.0)
    return balanced, factor, float(balanced.sum() / pixel_count)
def _valley_split(mask: np.ndarray) -> int | None:
# Find a vertical seam (column) with minimal foreground to split two digits
H, W = mask.shape
if W < 8:
return None
col_sums = mask.sum(axis=0)
start = max(1, int(W * 0.25))
end = min(W - 1, int(W * 0.75))
if end <= start:
start, end = 1, W - 1
idx = int(np.argmin(col_sums[start:end])) + start
left_mass = int(col_sums[:idx].sum())
right_mass = int(col_sums[idx:].sum())
if left_mass > 50 and right_mass > 50:
return idx
return None
def _connected_components(mask: np.ndarray):
H, W = mask.shape
visited = np.zeros_like(mask, dtype=bool)
comps = []
for y in range(H):
row = mask[y]
for x in range(W):
if row[x] and not visited[y, x]:
stack = [(y, x)]
visited[y, x] = True
ys, xs = [], []
while stack:
cy, cx = stack.pop()
ys.append(cy)
xs.append(cx)
# 4-connectivity
if cy > 0 and mask[cy - 1, cx] and not visited[cy - 1, cx]:
visited[cy - 1, cx] = True
stack.append((cy - 1, cx))
if cy + 1 < H and mask[cy + 1, cx] and not visited[cy + 1, cx]:
visited[cy + 1, cx] = True
stack.append((cy + 1, cx))
if cx > 0 and mask[cy, cx - 1] and not visited[cy, cx - 1]:
visited[cy, cx - 1] = True
stack.append((cy, cx - 1))
if cx + 1 < W and mask[cy, cx + 1] and not visited[cy, cx + 1]:
visited[cy, cx + 1] = True
stack.append((cy, cx + 1))
y1, y2 = min(ys), max(ys) + 1
x1, x2 = min(xs), max(xs) + 1
comps.append({"bbox": (y1, y2, x1, x2), "size": len(ys)})
return comps
def canonicalize_digit_28x28(arr: np.ndarray) -> np.ndarray:
    """Crop, rescale and center one digit into a square 28x28 tile.

    arr: float intensities in [0, 1], arbitrary HxW. The digit is cropped to
    its ink bbox (+2 px pad), scaled so its longer side is 20 px, pasted into
    the tile center, then nudged (at most +/-2 px) so its center of mass sits
    at the tile center — mirroring classic MNIST preprocessing.
    """
    # Input arr: float32 in [0,1], arbitrary HxW; output: 28x28 centered tile.
    # NOTE: TARGET_HEIGHT is used for BOTH tile dimensions — the tile is square.
    if arr.size == 0:
        return np.zeros((TARGET_HEIGHT, TARGET_HEIGHT), dtype=np.float32)
    thr = arr > 0.05
    if not thr.any():
        return np.zeros((TARGET_HEIGHT, TARGET_HEIGHT), dtype=np.float32)
    ys, xs = np.where(thr)
    y1, y2 = ys.min(), ys.max() + 1
    x1, x2 = xs.min(), xs.max() + 1
    # small padding around the ink bbox
    pad = 2
    y1 = max(0, y1 - pad)
    x1 = max(0, x1 - pad)
    y2 = min(arr.shape[0], y2 + pad)
    x2 = min(arr.shape[1], x2 + pad)
    crop = arr[y1:y2, x1:x2]
    h, w = crop.shape
    if h == 0 or w == 0:
        return np.zeros((TARGET_HEIGHT, TARGET_HEIGHT), dtype=np.float32)
    # Scale so the LONGER side becomes 20 px, preserving aspect ratio.
    # (The code picks the larger of h/w as the 20-px side.)
    if h >= w:
        new_h = 20
        new_w = max(1, int(round(w * (20.0 / h))))
    else:
        new_w = 20
        new_h = max(1, int(round(h * (20.0 / w))))
    small = Image.fromarray((crop * 255.0).astype(np.uint8)).resize(
        (new_w, new_h), Image.Resampling.LANCZOS
    )
    tile = Image.new("L", (TARGET_HEIGHT, TARGET_HEIGHT), color=0)
    # paste centered in the square tile
    top = (TARGET_HEIGHT - new_h) // 2
    left = (TARGET_HEIGHT - new_w) // 2
    tile.paste(small, (left, top))
    tile_arr = np.array(tile, dtype=np.float32) / 255.0
    # Shift by center-of-mass (clamped to +/-2 px) to land on the exact center.
    mass = tile_arr
    tot = float(mass.sum())
    if tot > 1e-6:
        gy, gx = np.indices(mass.shape)
        cy = float((gy * mass).sum() / tot)
        cx = float((gx * mass).sum() / tot)
        ideal = (TARGET_HEIGHT - 1) / 2.0
        dy = int(np.clip(round(ideal - cy), -2, 2))
        dx = int(np.clip(round(ideal - cx), -2, 2))
        if dy != 0 or dx != 0:
            tile_arr = shift_with_zero_pad(tile_arr, dy, dx)
    return tile_arr.astype(np.float32, copy=False)
def compose_from_single_canvas(img_input):
    """Split a single drawing into two digits and compose a 28x56 model input.

    Segmentation strategy, in order of preference: a vertical low-ink
    "valley" seam, then the two largest 4-connected components, then a
    plain center split. Returns (composed_28x56_float_or_None, diag_dict).
    """
    img = extract_canvas_array(img_input)
    if img is None:
        return None, {"warnings": ["No image provided."]}
    try:
        bands = img.getbands()
    except Exception:
        bands = ()
    if "A" in bands:
        # Flatten transparency onto white before grayscale conversion.
        rgba = img.convert("RGBA")
        white_bg = Image.new("RGBA", rgba.size, (255, 255, 255, 255))
        img = Image.alpha_composite(white_bg, rgba).convert("RGB")
    gray = img.convert("L")
    # Invert: canvas is dark-on-light, model expects light-on-dark (MNIST style).
    inv = ImageOps.invert(gray)
    arr_u8 = np.array(inv, dtype=np.uint8)
    mask = arr_u8 > 10
    if not mask.any():
        return None, {"warnings": ["Empty drawing detected."]}
    # Global bbox trim (+4 px pad) for speed before segmentation.
    ys, xs = np.where(mask)
    y1, y2 = ys.min(), ys.max() + 1
    x1, x2 = xs.min(), xs.max() + 1
    pad = 4
    y1 = max(0, y1 - pad)
    x1 = max(0, x1 - pad)
    y2 = min(arr_u8.shape[0], y2 + pad)
    x2 = min(arr_u8.shape[1], x2 + pad)
    arr_u8 = arr_u8[y1:y2, x1:x2]
    mask = mask[y1:y2, x1:x2]
    method = "valley"
    split = _valley_split(mask)
    left_arr = right_arr = None
    if split is not None:
        # Valley found: crop each side to its own ink bbox.
        left_area = arr_u8[:, :split]
        right_area = arr_u8[:, split:]
        if (left_area > 10).any():
            l_ys, l_xs = np.where(left_area > 10)
            ly1, ly2 = l_ys.min(), l_ys.max() + 1
            lx1, lx2 = l_xs.min(), l_xs.max() + 1
            left_arr = left_area[ly1:ly2, lx1:lx2]
        if (right_area > 10).any():
            r_ys, r_xs = np.where(right_area > 10)
            ry1, ry2 = r_ys.min(), r_ys.max() + 1
            rx1, rx2 = r_xs.min(), r_xs.max() + 1
            right_arr = right_area[ry1:ry2, rx1:rx2]
    else:
        method = "components"
        comps = _connected_components(mask)
        if len(comps) >= 2:
            # Keep the two largest components, ordered left/right by x1.
            comps.sort(key=lambda c: c["size"], reverse=True)
            a, b = comps[0], comps[1]
            if a["bbox"][2] <= b["bbox"][2]:
                left_bbox, right_bbox = a["bbox"], b["bbox"]
            else:
                left_bbox, right_bbox = b["bbox"], a["bbox"]
            ly1, ly2, lx1, lx2 = left_bbox
            ry1, ry2, rx1, rx2 = right_bbox
            left_arr = arr_u8[ly1:ly2, lx1:lx2]
            right_arr = arr_u8[ry1:ry2, rx1:rx2]
        else:
            # Fallback: split the single bbox in half down the middle.
            method = "fallback_center_split"
            W = arr_u8.shape[1]
            split = W // 2
            left_arr = arr_u8[:, :split]
            right_arr = arr_u8[:, split:]
    # Convert each half to float [0,1] and canonicalize to a 28x28 tile.
    left_tile = canonicalize_digit_28x28((left_arr.astype(np.float32) / 255.0) if left_arr is not None else np.zeros((1, 1), dtype=np.float32))
    right_tile = canonicalize_digit_28x28((right_arr.astype(np.float32) / 255.0) if right_arr is not None else np.zeros((1, 1), dtype=np.float32))
    composed = np.concatenate([left_tile, right_tile], axis=1)
    diag = {
        "segmentation": {
            "method": method,
            "canvas_crop": {"top": int(y1), "bottom": int(y2), "left": int(x1), "right": int(x2)},
        }
    }
    return composed.astype(np.float32, copy=False), diag
def preprocess_composed_28x56(arr_28x56: np.ndarray, stroke_scale: float = 1.0, *, extra_diag: dict | None = None):
    """Standardize a composed 28x56 image and build inference variants + diagnostics.

    Returns (standardized_variants, preview_array, mean_diff_uint8, diagnostics),
    or None when arr_28x56 is None. extra_diag entries are merged into the
    diagnostics dict.
    """
    ensure_model_loaded()
    if arr_28x56 is None:
        return None
    arr_resized = np.clip(arr_28x56.astype(np.float32), 0.0, 1.0)
    mean_image = mean.reshape(TARGET_HEIGHT, TARGET_WIDTH)
    std_safe = np.maximum(std, STD_FLOOR)
    # Clamp the caller-supplied stroke multiplier to a sane range.
    stroke_scale = float(stroke_scale)
    stroke_scale = max(0.3, min(stroke_scale, 1.5))
    arr_resized = np.clip(arr_resized * stroke_scale, 0.0, 1.0)
    auto_balance_scale = 1.0
    pre_balance_mass_fraction = float(arr_resized.mean())
    # Target the training set's global mean ink intensity.
    target_mass = float(mean.mean())
    arr_resized, auto_balance_scale, balanced_mass_fraction = _auto_balance_stroke(
        arr_resized,
        target_mass_fraction=target_mass,
        clamp=(0.6, 1.6),
    )
    # We already centered per 28x28 tile; skip whole-image recentering here
    arr_centered = arr_resized
    # Fewer augmentation variants on Spaces (IS_SPACE) to keep latency down.
    augmented_arrays = [arr_centered, *generate_inference_variants(arr_centered, fast=IS_SPACE)]
    augmented_standardized = []
    for arr in augmented_arrays:
        # Per-pixel z-score, clipped to tame outliers.
        z = (arr.reshape(TARGET_HEIGHT * TARGET_WIDTH, 1) - mean) / std_safe
        z = np.clip(z, -8.0, 8.0)
        augmented_standardized.append(z.astype(np.float32, copy=False))
    mean_diff = np.abs(arr_centered - mean_image)
    mean_diff_uint8 = (mean_diff / (mean_diff.max() + 1e-8) * 255.0).astype(np.uint8)
    diagnostics = compute_diagnostics(
        arr_centered,
        None,
        arr_centered.shape,
        mean_image,
        augmented_standardized[0],
        std_safe,
    )
    diagnostics["applied_auto_balance"] = {
        "enabled": True,
        "scale": float(auto_balance_scale),
        "mass_fraction_after": float(balanced_mass_fraction),
        "mass_fraction_before": float(pre_balance_mass_fraction),
        "target_mass_fraction": float(target_mass),
    }
    if extra_diag:
        diagnostics.update(extra_diag)
    return augmented_standardized, arr_centered, mean_diff_uint8, diagnostics
def preprocess_image(img_input, stroke_scale: float = 1.0):
    """Full single-image pipeline: raw canvas -> standardized 28x56 variants.

    Flattens alpha onto white, inverts to MNIST-style light-on-dark, crops
    to the ink bbox, pads to the 2:1 aspect ratio, resizes to 28x56,
    balances stroke intensity, recenters by center of mass, then z-scores
    each augmentation variant. Returns
    (standardized_variants, preview, mean_diff_uint8, diagnostics) or None
    on empty input.
    """
    ensure_model_loaded()
    img = extract_canvas_array(img_input)
    if img is None:
        return None
    try:
        bands = img.getbands()
    except Exception:
        bands = ()
    if "A" in bands:
        # Flatten transparency onto a white background before grayscaling.
        rgba = img.convert("RGBA")
        white_bg = Image.new("RGBA", rgba.size, (255, 255, 255, 255))
        img = Image.alpha_composite(white_bg, rgba).convert("RGB")
    img = img.convert("L")
    img = ImageOps.invert(img)
    arr_u8 = np.array(img, dtype=np.uint8)
    original_canvas_shape = arr_u8.shape
    coords = np.column_stack(np.where(arr_u8 > 10))
    bbox = None
    if coords.size > 0:
        # Crop to ink bbox with a 4-px margin.
        y_min, x_min = coords.min(axis=0)
        y_max, x_max = coords.max(axis=0) + 1
        pad = 4
        y_min = max(0, y_min - pad)
        x_min = max(0, x_min - pad)
        y_max = min(arr_u8.shape[0], y_max + pad)
        x_max = min(arr_u8.shape[1], x_max + pad)
        bbox = (int(y_min), int(y_max), int(x_min), int(x_max))
        arr_u8 = arr_u8[y_min:y_max, x_min:x_max]
    if arr_u8.size == 0:
        return None
    h, w = arr_u8.shape
    target_ratio = TARGET_WIDTH / TARGET_HEIGHT
    if h == 0 or w == 0:
        return None
    current_ratio = w / h if h else target_ratio
    # Pad (never crop) symmetrically to reach the exact 2:1 aspect ratio.
    if current_ratio > target_ratio:
        new_height = int(round(w / target_ratio))
        pad_total = max(new_height - h, 0)
        pad_top = pad_total // 2
        pad_bottom = pad_total - pad_top
        pad_left = pad_right = 0
    else:
        new_width = int(round(h * target_ratio))
        pad_total = max(new_width - w, 0)
        pad_left = pad_total // 2
        pad_right = pad_total - pad_left
        pad_top = pad_bottom = 0
    arr_padded = np.pad(
        arr_u8,
        ((pad_top, pad_bottom), (pad_left, pad_right)),
        mode="constant",
        constant_values=0,
    )
    resized = Image.fromarray(arr_padded).resize(
        (TARGET_WIDTH, TARGET_HEIGHT), Image.Resampling.LANCZOS
    )
    arr_resized = np.array(resized, dtype=np.float32) / 255.0
    mean_image = mean.reshape(TARGET_HEIGHT, TARGET_WIDTH)
    std_safe = np.maximum(std, STD_FLOOR)
    # Clamp the caller-supplied stroke multiplier to a sane range.
    stroke_scale = float(stroke_scale)
    stroke_scale = max(0.3, min(stroke_scale, 1.5))
    arr_resized = np.clip(arr_resized * stroke_scale, 0.0, 1.0)
    auto_balance_scale = 1.0
    # Match the dataset's global mean intensity (more faithful than a fixed midpoint)
    pre_balance_mass_fraction = float(arr_resized.mean())
    target_mass = float(mean.mean())
    arr_resized, auto_balance_scale, balanced_mass_fraction = _auto_balance_stroke(
        arr_resized,
        target_mass_fraction=target_mass,
        clamp=(0.6, 1.6),
    )
    # Light recentering by center-of-mass to reduce sensitivity to placement
    mass = arr_resized
    total_intensity = float(mass.sum())
    arr_centered = arr_resized
    if total_intensity > 1e-6:
        gy, gx = np.indices(mass.shape)
        cy = float((gy * mass).sum() / total_intensity)
        cx = float((gx * mass).sum() / total_intensity)
        ideal_cy = (TARGET_HEIGHT - 1) / 2.0
        ideal_cx = (TARGET_WIDTH - 1) / 2.0
        dy = int(np.clip(round(ideal_cy - cy), -2, 2))
        dx = int(np.clip(round(ideal_cx - cx), -2, 2))
        if dy != 0 or dx != 0:
            arr_centered = shift_with_zero_pad(arr_resized, dy, dx)
    augmented_arrays = [arr_centered, *generate_inference_variants(arr_centered, fast=IS_SPACE)]
    # Standardize each variant and clip to tame outliers for stable inference
    augmented_standardized = []
    for arr in augmented_arrays:
        z = (arr.reshape(TARGET_HEIGHT * TARGET_WIDTH, 1) - mean) / std_safe
        z = np.clip(z, -8.0, 8.0)
        augmented_standardized.append(z.astype(np.float32, copy=False))
    mean_diff = np.abs(arr_centered - mean_image)
    mean_diff_uint8 = (mean_diff / (mean_diff.max() + 1e-8) * 255.0).astype(np.uint8)
    diagnostics = compute_diagnostics(
        arr_centered,
        bbox,
        original_canvas_shape,
        mean_image,
        augmented_standardized[0],
        std_safe,
    )
    diagnostics["applied_auto_balance"] = {
        "enabled": True,
        "scale": float(auto_balance_scale),
        "mass_fraction_after": float(balanced_mass_fraction),
        "mass_fraction_before": float(pre_balance_mass_fraction),
        "target_mass_fraction": float(target_mass),
    }
    return augmented_standardized, arr_centered, mean_diff_uint8, diagnostics
def compute_diagnostics(arr_float, bbox, original_shape, mean_image, standardized, std_safe):
    """Collect input-quality metrics for a preprocessed image.

    arr_float: 28x56 float image in [0,1]; bbox: optional caller-provided
    (top, bottom, left, right); standardized: the z-scored identity variant.
    Returns a JSON-serializable stats dict, including per-metric pass/fail
    entries ("metric_checks") against METRIC_TARGETS.
    """
    mass = arr_float
    total_intensity = float(mass.sum())
    mass_threshold = mass > 0.05
    if mass_threshold.any():
        # Bbox measured on the final image (exclusive bottom/right bounds).
        ys, xs = np.where(mass_threshold)
        bbox_est = (int(ys.min()), int(ys.max()) + 1, int(xs.min()), int(xs.max()) + 1)
    else:
        bbox_est = None
    cy = cx = None
    if total_intensity > 1e-6:
        # Intensity-weighted center of mass.
        grid_y, grid_x = np.indices(mass.shape)
        weighted_sum = mass.sum()
        cy = float((grid_y * mass).sum() / weighted_sum)
        cx = float((grid_x * mass).sum() / weighted_sum)
    # Prefer the bbox measured here over the caller-provided one.
    bbox_use = bbox_est or bbox
    if bbox_use:
        top, bottom, left, right = bbox_use
        height = bottom - top
        width = right - left
        bbox_area = height * width
        bbox_metrics = {
            "top": top,
            "bottom": bottom,
            "left": left,
            "right": right,
            "height": height,
            "width": width,
            "aspect_ratio": float(width / height) if height else None,
            "area": bbox_area,
            "area_ratio": float(bbox_area / (TARGET_HEIGHT * TARGET_WIDTH)) if bbox_area else 0.0,
        }
    else:
        bbox_metrics = {
            "top": None,
            "bottom": None,
            "left": None,
            "right": None,
            "height": 0,
            "width": 0,
            "aspect_ratio": None,
            "area": 0,
            "area_ratio": 0.0,
        }
    # Average ink per bbox pixel.
    density = 0.0
    bbox_area = bbox_metrics.get("area", 0)
    if bbox_area:
        density = float(total_intensity / bbox_area)
    # Euclidean distance of the center of mass from the image center.
    center_offset = None
    if cy is not None and cx is not None:
        ideal_cy = (TARGET_HEIGHT - 1) / 2.0
        ideal_cx = (TARGET_WIDTH - 1) / 2.0
        center_offset = float(np.sqrt((cy - ideal_cy) ** 2 + (cx - ideal_cx) ** 2))
    standardized_flat = standardized.flatten()
    mean_flat = mean_image.flatten()
    arr_flat = arr_float.flatten()
    std_flat = std_safe.flatten()
    norm_input = np.linalg.norm(arr_flat)
    norm_mean = np.linalg.norm(mean_flat)
    cosine_similarity = None
    if norm_input > 0.0 and norm_mean > 0.0:
        cosine_similarity = float(np.dot(arr_flat, mean_flat) / (norm_input * norm_mean))
    mean_abs_z = float(np.mean(np.abs(standardized_flat)))
    max_abs_z = float(np.max(np.abs(standardized_flat)))
    std_of_z = float(np.std(standardized_flat))
    # Pixels whose training-set variance hit the floor yet differ from the
    # mean here: their z-scores are unreliable.
    low_var_mask = std_flat <= STD_FLOOR + 1e-12
    activated_low_var = int(np.count_nonzero(low_var_mask & (np.abs(arr_flat - mean_flat) > 1e-3)))
    stats = {
        "total_intensity": total_intensity,
        "mass_fraction": float(total_intensity / (TARGET_HEIGHT * TARGET_WIDTH)),
        "center_of_mass": {"row": cy, "col": cx},
        "center_offset": center_offset,
        "bbox": bbox_metrics,
        "original_canvas_shape": original_shape,
        "stroke_density": density,
        "warnings": [],
        "mean_intensity": float(arr_float.mean()),
        "pixel_intensity_range": {
            "min": float(arr_float.min()),
            "max": float(arr_float.max()),
        },
        "cosine_similarity_vs_mean": cosine_similarity,
        "mean_abs_z_score": mean_abs_z,
        "max_abs_z_score": max_abs_z,
        "std_abs_z_score": std_of_z,
        "low_variance_pixels_triggered": activated_low_var,
        "low_variance_threshold": STD_FLOOR,
        "low_variance_pixels_fraction": float(activated_low_var / max(1, int(low_var_mask.sum()))),
    }
    if mean_image is not None:
        stats["distance_from_mean"] = float(np.linalg.norm(arr_float - mean_image))
    # Evaluate each tracked metric against its configured range.
    metric_checks = {}
    for metric_name in (
        "mass_fraction",
        "stroke_density",
        "center_offset",
        "mean_abs_z_score",
        "max_abs_z_score",
        "std_abs_z_score",
    ):
        value = stats.get(metric_name)
        if value is not None:
            value = float(value)
        status, target_dict = _metric_status(metric_name, value)
        entry = {"value": value, "status": status}
        if target_dict is not None:
            entry["target"] = target_dict
        metric_checks[metric_name] = entry
    stats["metric_checks"] = metric_checks
    return stats
def enrich_diagnostics(stats, probs):
    """Augment a diagnostics dict with warnings and prediction-confidence stats.

    Builds human-readable warnings for out-of-range metrics, an odd aspect
    ratio, and a thin top-2 probability margin; records the top two
    confidences and the margin. Returns a shallow copy — *stats* is not
    mutated.
    """
    notes = []
    for metric, info in stats.get("metric_checks", {}).items():
        if info.get("status") != "out_of_range":
            continue
        val = info.get("value")
        rendered = "None" if val is None else f"{val:.4f}"
        tgt = info.get("target")
        if tgt is not None:
            notes.append(
                f"{metric}: value={rendered}, target=[{tgt['min']:.4f},{tgt['max']:.4f}]"
            )
        else:
            notes.append(f"{metric}: value={rendered}")
    ratio = stats.get("bbox", {}).get("aspect_ratio")
    if ratio is not None and (ratio < 1.0 or ratio > 3.5):
        notes.append(f"aspect_ratio: value={ratio:.4f}, expected≈[1.00,3.50]")
    ranked = np.sort(probs.flatten())[::-1]
    if ranked.size < 2:
        gap = None
        margin_info = {"value": None, "status": "insufficient_classes"}
    else:
        gap = ranked[0] - ranked[1]
        margin_info = {
            "value": float(gap),
            "status": "ok" if gap >= 0.05 else "low_margin",
            "target": {"min": 0.05, "max": 1.0},
        }
    if gap is not None and gap < 0.05:
        notes.append(f"prob_margin: value={gap:.4f}, target≥0.0500")
    enriched = dict(stats)
    enriched["warnings"] = notes
    enriched["top_confidence"] = float(ranked[0]) if ranked.size else None
    enriched["second_confidence"] = float(ranked[1]) if ranked.size > 1 else None
    enriched["prob_margin"] = margin_info
    return enriched
def predict_number(main_canvas):
    """Gradio handler: segment two digits, run the CNN, and build UI outputs.

    Returns (prediction_or_None, prob_rows, preview_uint8, mean_diff_uint8,
    diagnostics_json). Probabilities are averaged across augmentation
    variants before taking the argmax.
    """
    ensure_model_loaded()
    composed, seg_diag = compose_from_single_canvas(main_canvas)
    if composed is None:
        # Empty canvas: return blank placeholders for every output widget.
        blank_probs = {f"{i:02d}": 0.0 for i in range(OUTPUT_CLASSES)}
        empty_preview = np.zeros((TARGET_HEIGHT, TARGET_WIDTH), dtype=np.uint8)
        empty_diff = np.zeros((TARGET_HEIGHT, TARGET_WIDTH), dtype=np.uint8)
        diagnostics = {"warnings": ["Draw two digits to see diagnostics."]}
        return None, blank_probs, empty_preview, empty_diff, json.dumps(diagnostics, indent=2)
    result = preprocess_composed_28x56(
        composed,
        extra_diag=seg_diag,
    )
    if result is None:
        blank_probs = {f"{i:02d}": 0.0 for i in range(OUTPUT_CLASSES)}
        empty_preview = np.zeros((TARGET_HEIGHT, TARGET_WIDTH), dtype=np.uint8)
        empty_diff = np.zeros((TARGET_HEIGHT, TARGET_WIDTH), dtype=np.uint8)
        diagnostics = {"warnings": ["Draw two digits to see diagnostics."]}
        return None, blank_probs, empty_preview, empty_diff, json.dumps(diagnostics, indent=2)
    standardized_variants, preview, mean_diff, diagnostics = result
    # Batch all variants as columns of one matrix for a single forward pass.
    variants_matrix = np.concatenate(standardized_variants, axis=1).astype(np.float32, copy=False)
    cache, probs_matrix = forward_prop(variants_matrix, params, training=False)
    # Average probabilities across variants to reduce domination by any single variant
    probs = np.mean(probs_matrix, axis=1, keepdims=True)
    pred = int(get_predictions(probs)[0])
    # Probability table rows, highest confidence first.
    prob_rows = [[f"{i:02d}", float(probs[i, 0])] for i in range(OUTPUT_CLASSES)]
    prob_rows.sort(key=lambda r: r[1], reverse=True)
    diagnostics = enrich_diagnostics(diagnostics, probs)
    diagnostics["variants_used"] = int(probs_matrix.shape[1])
    diagnostics["variant_top_confidences"] = [
        float(probs_matrix[pred, idx]) for idx in range(probs_matrix.shape[1])
    ]
    return pred, prob_rows, (preview * 255).astype(np.uint8), mean_diff, json.dumps(diagnostics, indent=2)
# UI layout: canvas + buttons on the left, prediction/diagnostics on the right.
with gr.Blocks() as demo:
    gr.Markdown(
        """
# Elliot's MNIST-100 Classifier
Draw a two-digit number (00-99) on the single canvas. The app uses the Convolutional Neural Network to predict the number accurately.
"""
    )
    with gr.Row():
        with gr.Column(scale=1) as left_col:
            main_canvas = gr.Sketchpad(label="Draw Two Digits (00–99)")
        with gr.Column(scale=1):
            pred_box = gr.Number(label="Predicted Number", precision=0, value=None)
            prob_table = gr.Dataframe(
                label="Class Probabilities",
                headers=["class", "prob"],
                datatype=["str", "number"],
                interactive=False,
            )
            preview = gr.Image(label="Model Input Preview (28x56)", image_mode="L")
            mean_diff_view = gr.Image(label="Difference vs Training Mean", image_mode="L")
            diagnostics_box = gr.Code(label="Diagnostics (JSON)", language="json")
    # Place buttons under the canvas, but wire them to clear outputs as well
    with left_col:
        with gr.Row():
            predict_btn = gr.Button("Predict", variant="primary")
            # ClearButton resets the canvas AND every output widget.
            clear_btn = gr.ClearButton(
                [
                    main_canvas,
                    pred_box,
                    prob_table,
                    preview,
                    mean_diff_view,
                    diagnostics_box,
                ]
            )
    predict_btn.click(
        fn=predict_number,
        inputs=[main_canvas],
        outputs=[pred_box, prob_table, preview, mean_diff_view, diagnostics_box],
    )
    # On Spaces, avoid per-stroke inference to prevent event floods
    if not IS_SPACE:
        main_canvas.change(
            fn=predict_number,
            inputs=[main_canvas],
            outputs=[pred_box, prob_table, preview, mean_diff_view, diagnostics_box],
        )
def _disable_gradio_api_schema(*_args, **_kwargs):
    """Work around Gradio schema bug on Python 3.13 by returning empty metadata."""
    return {}
# BUGFIX: this patch originally sat AFTER the __main__ block. launch() blocks,
# so when the file ran as a script the workaround was only applied after the
# server had already shut down. Apply it at import time, before launching.
gr_routes.api_info = _disable_gradio_api_schema
if __name__ == "__main__":
    space_env = os.getenv("SPACE_ID")
    def _queue_app(blocks):
        """Enable request queuing with one worker, degrading gracefully across Gradio versions."""
        try:
            # Gradio 3.x accepted concurrency_count; Gradio 4 removed it (TypeError).
            return blocks.queue(concurrency_count=1)
        except TypeError:
            # Older Gradio versions don't support the argument
            try:
                return blocks.queue()
            except Exception:
                # Queuing unavailable — launch the raw Blocks app.
                return blocks
    app_to_launch = _queue_app(demo)
    if space_env:
        # Spaces supplies its own host/port; just hide the API docs.
        app_to_launch.launch(show_api=False)
    else:
        # Local run: bind all interfaces and open a share tunnel.
        app_to_launch.launch(server_name="0.0.0.0", share=True, show_api=False)