import tensorflow as tf
import gradio as gr
import cv2
import numpy as np

# 1. Load the trained model once at startup (file must sit next to this script).
model = tf.keras.models.load_model('digit_recognizer.keras')


def preprocess_mnist_style(image):
    """
    Converts a user drawing into the strict format expected by MNIST models:
    - Invert colors (if needed) to get white digit on black background
    - Crop to bounding box (remove empty margins)
    - Resize digit to max 20x20 while preserving aspect ratio
    - Center digit by center-of-mass in a 28x28 image

    Parameters
    ----------
    image : np.ndarray
        2-D grayscale or 3-D RGB uint8 image of arbitrary size.

    Returns
    -------
    np.ndarray
        28x28 uint8 image, white digit on black background.
    """
    # 1. Convert to grayscale if needed
    if len(image.shape) == 3:
        image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)

    # 2. Invert if background is light (MNIST is white-on-black)
    if np.mean(image) > 127:
        image = 255 - image

    # 3. Find the bounding box of the digit (crop empty space).
    # findNonZero returns None when the canvas is completely black (empty).
    coords = cv2.findNonZero(image)
    if coords is None:
        # FIX: return a blank 28x28 canvas instead of the original image,
        # so the caller's reshape(1, 28, 28, 1) cannot fail on an empty drawing.
        return np.zeros((28, 28), dtype=np.uint8)
    x, y, w, h = cv2.boundingRect(coords)
    digit = image[y:y + h, x:x + w]

    # 4. Resize to fit inside a 20x20 box (leaving a 4px buffer):
    # original MNIST digits are 20x20 centered in 28x28.
    rows, cols = digit.shape
    if rows > cols:
        factor = 20.0 / rows
        rows = 20
        # max(1, ...) guards against a 0-width resize for extremely thin strokes
        cols = max(1, int(round(cols * factor)))
    else:
        factor = 20.0 / cols
        cols = 20
        rows = max(1, int(round(rows * factor)))

    # INTER_AREA gives better quality when downscaling
    digit = cv2.resize(digit, (cols, rows), interpolation=cv2.INTER_AREA)

    # 5. Paste the resized digit into the center of a black 28x28 canvas
    new_image = np.zeros((28, 28), dtype=np.uint8)
    pad_x = (28 - cols) // 2
    pad_y = (28 - rows) // 2
    new_image[pad_y:pad_y + rows, pad_x:pad_x + cols] = digit

    # 6. Center by "Center of Mass" (refinement step used in original MNIST):
    # shift the intensity-weighted centroid onto the geometric center (14, 14).
    moments = cv2.moments(new_image)
    if moments['m00'] > 0:
        cx = moments['m10'] / moments['m00']
        cy = moments['m01'] / moments['m00']
        shift_x = 14 - cx
        shift_y = 14 - cy
        M = np.float32([[1, 0, shift_x], [0, 1, shift_y]])
        new_image = cv2.warpAffine(new_image, M, (28, 28))

    return new_image


def classify_digit(image):
    """
    Classify a drawn or uploaded digit image.

    Parameters
    ----------
    image : dict | np.ndarray | None
        Gradio 4.x Sketchpad delivers a dict whose 'composite' key holds the
        flattened drawing; the upload tab delivers a plain numpy array.

    Returns
    -------
    dict | None
        Mapping of digit label ('0'..'9') to probability, or None for no input.
    """
    if image is None:
        return None

    # Handle Gradio 4.x dictionary input (Sketchpad/ImageEditor components)
    if isinstance(image, dict):
        image = image['composite']
    image = np.array(image)

    # --- INPUT HANDLING ---
    # Flatten RGBA (transparent sketchpad layers) onto a white background.
    # FIX: check ndim too, so a 2-D grayscale image that happens to be
    # 4 pixels wide is not mistaken for an RGBA image.
    if image.ndim == 3 and image.shape[-1] == 4:
        background = np.ones((image.shape[0], image.shape[1], 3), dtype=np.uint8) * 255
        alpha = image[:, :, 3] / 255.0
        for c in range(3):
            background[:, :, c] = alpha * image[:, :, c] + (1 - alpha) * background[:, :, c]
        image = background

    # --- APPLY ROBUST PREPROCESSING ---
    processed_image = preprocess_mnist_style(image)

    # Normalize (0 to 1) and reshape to the model's (batch, 28, 28, 1) input
    final_input = processed_image.reshape(1, 28, 28, 1) / 255.0

    # --- PREDICTION ---
    prediction = model.predict(final_input).flatten()
    return {str(i): float(prediction[i]) for i in range(10)}


# --- UI SETUP ---
with gr.Blocks() as demo:
    gr.Markdown("## Handwritten Digit Recognizer")
    gr.Markdown("Draw a digit (0-9) or upload a photo. The robust preprocessing now centers and scales your input like real MNIST data.")

    with gr.Tabs():
        with gr.Tab("Draw Digit"):
            sketchpad = gr.Sketchpad(
                label="Draw Here",
                type="numpy",
                brush=gr.Brush(colors=["#000000"], default_size=20)
            )
            btn_draw = gr.Button("Predict Drawing", variant="primary")
        with gr.Tab("Upload Photo"):
            upload = gr.Image(label="Upload Image", sources=["upload", "clipboard"], type="numpy")
            btn_upload = gr.Button("Predict Upload", variant="primary")

    label = gr.Label(num_top_classes=3, label="Prediction")

    btn_draw.click(fn=classify_digit, inputs=sketchpad, outputs=label)
    btn_upload.click(fn=classify_digit, inputs=upload, outputs=label)

if __name__ == "__main__":
    demo.launch()