Scribbler310's picture
Update app.py
79f3351 verified
import tensorflow as tf
import gradio as gr
import cv2
import numpy as np
import math
# 1. Load the model
# Loaded once at import time and kept as a module-level global that
# classify_digit() reads for every prediction. The path is relative, so the
# .keras file is resolved against the current working directory.
model = tf.keras.models.load_model('digit_recognizer.keras')
def preprocess_mnist_style(image):
    """Convert a drawing or photo into an MNIST-style 28x28 frame.

    Pipeline:
      1. Collapse a 3-D (channelled) input to grayscale.
      2. Invert when the mean intensity is above 127, so the digit ends up
         light on a dark background.
      3. Crop to the bounding box of all non-zero pixels.
      4. Scale the crop so its longer side is 20 px, keeping aspect ratio.
      5. Paste the result in the middle of a black 28x28 canvas.
      6. Translate so the intensity centre of mass lands on (14, 14).

    Args:
        image: 2-D grayscale array, or 3-D array in RGB channel order
            (assumed uint8 -- the inversion uses ``255 - image``).

    Returns:
        A ``(28, 28)`` ``uint8`` array. An input with no non-zero pixels
        after inversion yields an all-black canvas, so the result can always
        be reshaped for the model.
    """
    canvas_size = 28  # side of the frame the model consumes
    box_size = 20     # the digit's longer side; leaves a 4 px margin

    # 1. Convert to grayscale if needed
    if image.ndim == 3:
        image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)

    # 2. Invert if background is light (MNIST is white-on-black)
    if np.mean(image) > 127:
        image = 255 - image

    # Nothing drawn: hand back a blank frame of the expected size instead of
    # the original-size input, which the caller could not reshape.
    if not np.any(image):
        return np.zeros((canvas_size, canvas_size), dtype=np.uint8)

    # 3. Crop to the bounding box of the digit.
    # NOTE: every non-zero pixel counts as "digit", so a background that is
    # not exactly 0 after inversion widens the box.
    x, y, w, h = cv2.boundingRect(cv2.findNonZero(image))
    digit = image[y:y + h, x:x + w]

    # 4. Resize so the longer side is box_size, preserving aspect ratio.
    # Clamp to 1 px: an extremely thin stroke would otherwise round its
    # short side to 0 and make cv2.resize raise.
    rows, cols = digit.shape
    scale = float(box_size) / max(rows, cols)
    rows = max(1, int(round(rows * scale)))
    cols = max(1, int(round(cols * scale)))

    # INTER_AREA gives the cleanest result when shrinking.
    digit = cv2.resize(digit, (cols, rows), interpolation=cv2.INTER_AREA)

    # 5. Paste the resized digit into the centre of a black canvas
    new_image = np.zeros((canvas_size, canvas_size), dtype=np.uint8)
    pad_x = (canvas_size - cols) // 2
    pad_y = (canvas_size - rows) // 2
    new_image[pad_y:pad_y + rows, pad_x:pad_x + cols] = digit

    # 6. Refine: shift so the intensity-weighted centre of mass sits at the
    # centre of the canvas.
    moments = cv2.moments(new_image)
    if moments['m00'] > 0:
        cx = moments['m10'] / moments['m00']
        cy = moments['m01'] / moments['m00']
        shift_x = canvas_size / 2 - cx
        shift_y = canvas_size / 2 - cy
        M = np.float32([[1, 0, shift_x], [0, 1, shift_y]])
        new_image = cv2.warpAffine(new_image, M, (canvas_size, canvas_size))

    return new_image
def classify_digit(image):
    """Predict which digit (0-9) an image shows.

    Args:
        image: ``None``, a numpy image array (grayscale, RGB or RGBA), or a
            dict from a Gradio 4.x editor component whose ``"composite"``
            entry holds the flattened drawing.

    Returns:
        A dict mapping ``"0"``..``"9"`` to the model's score for that digit,
        or ``None`` when there is nothing to classify.
    """
    if image is None:
        return None

    # Handle Gradio 4.x dictionary input; the composite is None when the
    # canvas has never been drawn on.
    if isinstance(image, dict):
        image = image['composite']
        if image is None:
            return None

    image = np.asarray(image)

    # --- INPUT HANDLING ---
    # Flatten RGBA onto a white background. The ndim check keeps a 2-D
    # grayscale image that happens to be 4 pixels wide out of this branch.
    if image.ndim == 3 and image.shape[-1] == 4:
        alpha = image[:, :, 3:4] / 255.0
        # Standard "over" compositing against white (255), all three colour
        # channels at once; truncation to uint8 matches integer pixel data.
        image = (alpha * image[:, :, :3] + (1 - alpha) * 255).astype(np.uint8)

    # --- APPLY ROBUST PREPROCESSING ---
    processed_image = preprocess_mnist_style(image)

    # Defensive: the model takes exactly one 28x28 frame, so bail out rather
    # than let reshape raise on any other size.
    if processed_image.shape != (28, 28):
        return None

    # Normalize (0 to 1) and add batch and channel axes
    final_input = processed_image.reshape(1, 28, 28, 1) / 255.0

    # --- PREDICTION ---
    prediction = model.predict(final_input).flatten()
    return {str(i): float(prediction[i]) for i in range(10)}
# --- UI SETUP ---
# Two input tabs (freehand drawing, image upload) feed the same classifier
# and share a single Label output placed below the tabs.
with gr.Blocks() as demo:
    gr.Markdown("## Handwritten Digit Recognizer")
    gr.Markdown("Draw a digit (0-9) or upload a photo. The robust preprocessing now centers and scales your input like real MNIST data.")
    with gr.Tabs():
        with gr.Tab("Draw Digit"):
            # Black-only brush; classify_digit() handles the dict this
            # component may deliver and inverts dark-on-light drawings.
            sketchpad = gr.Sketchpad(
                label="Draw Here",
                type="numpy",
                brush=gr.Brush(colors=["#000000"], default_size=20)
            )
            btn_draw = gr.Button("Predict Drawing", variant="primary")
        with gr.Tab("Upload Photo"):
            # Accepts a file upload or a pasted clipboard image.
            upload = gr.Image(label="Upload Image", sources=["upload", "clipboard"], type="numpy")
            btn_upload = gr.Button("Predict Upload", variant="primary")
    # Shows the three highest-scoring digits from the returned score dict.
    label = gr.Label(num_top_classes=3, label="Prediction")
    # Both buttons run the same function and write to the same output.
    btn_draw.click(fn=classify_digit, inputs=sketchpad, outputs=label)
    btn_upload.click(fn=classify_digit, inputs=upload, outputs=label)

# Start the web server only when executed as a script, not on import.
if __name__ == "__main__":
    demo.launch()