Spaces:
Sleeping
Sleeping
| import tensorflow as tf | |
| import gradio as gr | |
| import cv2 | |
| import numpy as np | |
| import math | |
# 1. Load the model
# Loaded once at import time so every prediction request reuses the same
# in-memory model instead of reading 'digit_recognizer.keras' per call.
model = tf.keras.models.load_model(
    'digit_recognizer.keras',
)
def preprocess_mnist_style(image):
    """
    Convert a user drawing or photo into the layout expected by MNIST models.

    Steps:
    - Convert to grayscale when the input has a channel axis.
    - Invert when the mean intensity is above 127, so the digit ends up
      light on a dark background (MNIST is white-on-black).
    - Crop to the bounding box of the non-zero pixels.
    - Resize so the longer side is 20 px, preserving aspect ratio.
    - Paste into the middle of a black 28x28 canvas, then translate so the
      intensity center of mass lands on (14, 14).

    Args:
        image: 2-D grayscale array, or 3-D array that is converted with
            cv2.COLOR_RGB2GRAY. Assumed to be uint8 (the inversion is
            ``255 - image``) -- TODO confirm against callers.

    Returns:
        A (28, 28) uint8 array. An input that contains no digit pixels
        yields an all-black canvas, so callers can always reshape the result.
    """
    # 1. Convert to grayscale if needed
    if len(image.shape) == 3:
        image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)

    # 2. Invert if background is light (MNIST is white-on-black)
    if np.mean(image) > 127:
        image = 255 - image

    canvas = np.zeros((28, 28), dtype=np.uint8)

    # 3. Nothing drawn: hand back the blank canvas rather than the raw input,
    #    whose size would not match the 28x28 shape the model expects.
    if not np.any(image):
        return canvas

    # 4. Crop to the bounding box of the digit (drop empty margins)
    x, y, w, h = cv2.boundingRect(cv2.findNonZero(image))
    digit = image[y:y + h, x:x + w]

    # 5. Scale the longer side to 20 px (MNIST digits are 20x20 inside 28x28).
    #    The shorter side is clamped to >= 1 because a very thin stroke can
    #    round down to 0, and cv2.resize rejects an empty target size.
    rows, cols = digit.shape
    if rows > cols:
        factor = 20.0 / rows
        rows = 20
        cols = max(1, int(round(cols * factor)))
    else:
        factor = 20.0 / cols
        cols = 20
        rows = max(1, int(round(rows * factor)))
    digit = cv2.resize(digit, (cols, rows), interpolation=cv2.INTER_AREA)

    # 6. Paste the resized digit into the center of the black canvas
    pad_x = (28 - cols) // 2
    pad_y = (28 - rows) // 2
    canvas[pad_y:pad_y + rows, pad_x:pad_x + cols] = digit

    # 7. Refine the placement by center of mass: translate so the
    #    intensity-weighted centroid sits at (14, 14).
    moments = cv2.moments(canvas)
    if moments['m00'] > 0:
        cx = moments['m10'] / moments['m00']
        cy = moments['m01'] / moments['m00']
        shift = np.float32([[1, 0, 14 - cx], [0, 1, 14 - cy]])
        canvas = cv2.warpAffine(canvas, shift, (28, 28))

    return canvas
def classify_digit(image):
    """
    Predict which digit (0-9) an input image shows.

    Args:
        image: ``None``, an image array, or a dict holding the image under
            the ``'composite'`` key (presumably the Gradio 4.x editor
            payload -- verify against the installed Gradio version).

    Returns:
        A dict mapping the strings "0".."9" to the model's float scores, or
        ``None`` when there is nothing to classify (no input, an empty
        editor payload, or a canvas without any digit pixels).
    """
    if image is None:
        return None

    # Handle Gradio 4.x dictionary input; the composite may be absent or
    # None when nothing has been drawn yet.
    if isinstance(image, dict):
        image = image.get('composite')
        if image is None:
            return None

    image = np.array(image)

    # --- INPUT HANDLING ---
    # Flatten RGBA onto a white background. The ndim check keeps a 2-D
    # grayscale image that happens to be 4 pixels wide out of this branch.
    if image.ndim == 3 and image.shape[-1] == 4:
        alpha = image[:, :, 3:4] / 255.0
        blended = alpha * image[:, :, :3] + (1.0 - alpha) * 255.0
        image = blended.astype(np.uint8)

    # --- APPLY ROBUST PREPROCESSING ---
    processed_image = preprocess_mnist_style(image)

    # A blank input has no digit to classify; skip the model instead of
    # failing in the reshape or scoring an empty canvas.
    if processed_image.shape != (28, 28) or not processed_image.any():
        return None

    # Normalize (0 to 1) and add batch/channel axes
    final_input = processed_image.reshape(1, 28, 28, 1) / 255.0

    # --- PREDICTION ---
    prediction = model.predict(final_input).flatten()
    return {str(i): float(prediction[i]) for i in range(10)}
# --- UI SETUP ---
# Two input tabs (freehand sketch and uploaded photo) feed one shared
# prediction label through classify_digit.
# NOTE(review): nesting was reconstructed from a flattened source -- confirm
# that the label and click handlers sit directly under gr.Blocks.
with gr.Blocks() as demo:
    gr.Markdown("## Handwritten Digit Recognizer")
    gr.Markdown("Draw a digit (0-9) or upload a photo. The robust preprocessing now centers and scales your input like real MNIST data.")

    with gr.Tabs():
        with gr.Tab("Draw Digit"):
            sketchpad = gr.Sketchpad(
                type="numpy",
                label="Draw Here",
                brush=gr.Brush(default_size=20, colors=["#000000"]),
            )
            btn_draw = gr.Button("Predict Drawing", variant="primary")

        with gr.Tab("Upload Photo"):
            upload = gr.Image(
                type="numpy",
                label="Upload Image",
                sources=["upload", "clipboard"],
            )
            btn_upload = gr.Button("Predict Upload", variant="primary")

    # Shared output: shows the three highest-scoring classes.
    label = gr.Label(label="Prediction", num_top_classes=3)

    # Both buttons route through the same classifier.
    btn_draw.click(classify_digit, inputs=sketchpad, outputs=label)
    btn_upload.click(classify_digit, inputs=upload, outputs=label)

if __name__ == "__main__":
    demo.launch()