Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -2,58 +2,114 @@ import tensorflow as tf
|
|
| 2 |
import gradio as gr
|
| 3 |
import cv2
|
| 4 |
import numpy as np
|
|
|
|
| 5 |
|
| 6 |
# 1. Load the model
|
| 7 |
model = tf.keras.models.load_model('digit_recognizer.keras')
|
| 8 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
def classify_digit(image):
|
| 10 |
if image is None:
|
| 11 |
return None
|
| 12 |
|
| 13 |
-
#
|
| 14 |
if isinstance(image, dict):
|
| 15 |
image = image['composite']
|
| 16 |
-
|
| 17 |
-
image = np.array(image)
|
| 18 |
|
| 19 |
-
|
| 20 |
-
|
|
|
|
|
|
|
| 21 |
if image.shape[-1] == 4:
|
| 22 |
-
#
|
| 23 |
background = np.ones((image.shape[0], image.shape[1], 3), dtype=np.uint8) * 255
|
| 24 |
alpha = image[:, :, 3] / 255.0
|
| 25 |
for c in range(3):
|
| 26 |
background[:, :, c] = alpha * image[:, :, c] + (1 - alpha) * background[:, :, c]
|
| 27 |
-
image =
|
| 28 |
-
elif len(image.shape) == 3 and image.shape[-1] == 3:
|
| 29 |
-
# RGB: Convert to Gray
|
| 30 |
-
image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
|
| 31 |
-
|
| 32 |
-
# 2. Resize to 28x28 (Model Requirement)
|
| 33 |
-
image = cv2.resize(image, (28, 28), interpolation=cv2.INTER_AREA)
|
| 34 |
-
|
| 35 |
-
# 3. Invert Colors (Critical)
|
| 36 |
-
# MNIST expects white digit on black background.
|
| 37 |
-
# If image is mostly bright (white paper/canvas), invert it.
|
| 38 |
-
if np.mean(image) > 127:
|
| 39 |
-
image = 255 - image
|
| 40 |
|
| 41 |
-
#
|
| 42 |
-
|
|
|
|
|
|
|
|
|
|
| 43 |
|
| 44 |
# --- PREDICTION ---
|
| 45 |
-
prediction = model.predict(
|
| 46 |
return {str(i): float(prediction[i]) for i in range(10)}
|
| 47 |
|
| 48 |
# --- UI SETUP ---
|
| 49 |
with gr.Blocks() as demo:
|
| 50 |
gr.Markdown("## Handwritten Digit Recognizer")
|
| 51 |
-
gr.Markdown("Draw a digit (0-9) or upload a photo
|
| 52 |
|
| 53 |
with gr.Tabs():
|
| 54 |
-
# Tab 1: Drawing Interface
|
| 55 |
with gr.Tab("Draw Digit"):
|
| 56 |
-
# FIX: Use 'default_size' and 'colors' instead of 'thickness' and 'color'
|
| 57 |
sketchpad = gr.Sketchpad(
|
| 58 |
label="Draw Here",
|
| 59 |
type="numpy",
|
|
@@ -61,16 +117,12 @@ with gr.Blocks() as demo:
|
|
| 61 |
)
|
| 62 |
btn_draw = gr.Button("Predict Drawing", variant="primary")
|
| 63 |
|
| 64 |
-
# Tab 2: Upload Interface
|
| 65 |
with gr.Tab("Upload Photo"):
|
| 66 |
-
# FIX: Use 'sources=["upload", "clipboard"]' to avoid the source list error
|
| 67 |
upload = gr.Image(label="Upload Image", sources=["upload", "clipboard"], type="numpy")
|
| 68 |
btn_upload = gr.Button("Predict Upload", variant="primary")
|
| 69 |
|
| 70 |
-
# Output is shared
|
| 71 |
label = gr.Label(num_top_classes=3, label="Prediction")
|
| 72 |
|
| 73 |
-
# Connect both buttons to the same function
|
| 74 |
btn_draw.click(fn=classify_digit, inputs=sketchpad, outputs=label)
|
| 75 |
btn_upload.click(fn=classify_digit, inputs=upload, outputs=label)
|
| 76 |
|
|
|
|
| 2 |
import gradio as gr
|
| 3 |
import cv2
|
| 4 |
import numpy as np
|
| 5 |
+
import math
|
| 6 |
|
| 7 |
# 1. Load the model
|
| 8 |
model = tf.keras.models.load_model('digit_recognizer.keras')
|
| 9 |
|
| 10 |
+
def preprocess_mnist_style(image):
|
| 11 |
+
"""
|
| 12 |
+
Converts a user drawing into the strict format expected by MNIST models:
|
| 13 |
+
- Invert colors (if needed) to get white digit on black background
|
| 14 |
+
- Crop to bounding box (remove empty margins)
|
| 15 |
+
- Resize digit to max 20x20 while preserving aspect ratio
|
| 16 |
+
- Center digit by center-of-mass in a 28x28 image
|
| 17 |
+
"""
|
| 18 |
+
# 1. Convert to grayscale if needed
|
| 19 |
+
if len(image.shape) == 3:
|
| 20 |
+
image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
|
| 21 |
+
|
| 22 |
+
# 2. Invert if background is light (MNIST is white-on-black)
|
| 23 |
+
if np.mean(image) > 127:
|
| 24 |
+
image = 255 - image
|
| 25 |
+
|
| 26 |
+
# 3. Find the bounding box of the digit (crop empty space)
|
| 27 |
+
# Find all non-zero points (pixels that are part of the digit)
|
| 28 |
+
coords = cv2.findNonZero(image)
|
| 29 |
+
if coords is None:
|
| 30 |
+
return image # Return original if empty
|
| 31 |
+
|
| 32 |
+
x, y, w, h = cv2.boundingRect(coords)
|
| 33 |
+
# Crop the digit
|
| 34 |
+
digit = image[y:y+h, x:x+w]
|
| 35 |
+
|
| 36 |
+
# 4. Resize to fit inside a 20x20 box (leaving 4px buffer)
|
| 37 |
+
# MNIST digits are 20x20 centered in 28x28
|
| 38 |
+
rows, cols = digit.shape
|
| 39 |
+
if rows > cols:
|
| 40 |
+
factor = 20.0 / rows
|
| 41 |
+
rows = 20
|
| 42 |
+
cols = int(round(cols * factor))
|
| 43 |
+
else:
|
| 44 |
+
factor = 20.0 / cols
|
| 45 |
+
cols = 20
|
| 46 |
+
rows = int(round(rows * factor))
|
| 47 |
+
|
| 48 |
+
# Resize using INTER_AREA for better quality downscaling
|
| 49 |
+
digit = cv2.resize(digit, (cols, rows), interpolation=cv2.INTER_AREA)
|
| 50 |
+
|
| 51 |
+
# 5. Paste the resized digit into the center of a black 28x28 canvas
|
| 52 |
+
new_image = np.zeros((28, 28), dtype=np.uint8)
|
| 53 |
+
|
| 54 |
+
# Calculate center offset
|
| 55 |
+
pad_x = (28 - cols) // 2
|
| 56 |
+
pad_y = (28 - rows) // 2
|
| 57 |
+
|
| 58 |
+
new_image[pad_y:pad_y+rows, pad_x:pad_x+cols] = digit
|
| 59 |
+
|
| 60 |
+
# 6. Center by "Center of Mass" (Refinement step used in original MNIST)
|
| 61 |
+
# Calculate moments to find the weighted center
|
| 62 |
+
moments = cv2.moments(new_image)
|
| 63 |
+
if moments['m00'] > 0:
|
| 64 |
+
cx = moments['m10'] / moments['m00']
|
| 65 |
+
cy = moments['m01'] / moments['m00']
|
| 66 |
+
|
| 67 |
+
# Shift to align center of mass (cx, cy) to image center (14, 14)
|
| 68 |
+
shift_x = 14 - cx
|
| 69 |
+
shift_y = 14 - cy
|
| 70 |
+
|
| 71 |
+
M = np.float32([[1, 0, shift_x], [0, 1, shift_y]])
|
| 72 |
+
new_image = cv2.warpAffine(new_image, M, (28, 28))
|
| 73 |
+
|
| 74 |
+
return new_image
|
| 75 |
+
|
| 76 |
def classify_digit(image):
|
| 77 |
if image is None:
|
| 78 |
return None
|
| 79 |
|
| 80 |
+
# Handle Gradio 4.x dictionary input
|
| 81 |
if isinstance(image, dict):
|
| 82 |
image = image['composite']
|
|
|
|
|
|
|
| 83 |
|
| 84 |
+
image = np.array(image)
|
| 85 |
+
|
| 86 |
+
# --- INPUT HANDLING ---
|
| 87 |
+
# Handle RGBA (Transparent)
|
| 88 |
if image.shape[-1] == 4:
|
| 89 |
+
# Create white background
|
| 90 |
background = np.ones((image.shape[0], image.shape[1], 3), dtype=np.uint8) * 255
|
| 91 |
alpha = image[:, :, 3] / 255.0
|
| 92 |
for c in range(3):
|
| 93 |
background[:, :, c] = alpha * image[:, :, c] + (1 - alpha) * background[:, :, c]
|
| 94 |
+
image = background
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 95 |
|
| 96 |
+
# --- APPLY ROBUST PREPROCESSING ---
|
| 97 |
+
processed_image = preprocess_mnist_style(image)
|
| 98 |
+
|
| 99 |
+
# Normalize (0 to 1) and Reshape
|
| 100 |
+
final_input = processed_image.reshape(1, 28, 28, 1) / 255.0
|
| 101 |
|
| 102 |
# --- PREDICTION ---
|
| 103 |
+
prediction = model.predict(final_input).flatten()
|
| 104 |
return {str(i): float(prediction[i]) for i in range(10)}
|
| 105 |
|
| 106 |
# --- UI SETUP ---
|
| 107 |
with gr.Blocks() as demo:
|
| 108 |
gr.Markdown("## Handwritten Digit Recognizer")
|
| 109 |
+
gr.Markdown("Draw a digit (0-9) or upload a photo. The robust preprocessing now centers and scales your input like real MNIST data.")
|
| 110 |
|
| 111 |
with gr.Tabs():
|
|
|
|
| 112 |
with gr.Tab("Draw Digit"):
|
|
|
|
| 113 |
sketchpad = gr.Sketchpad(
|
| 114 |
label="Draw Here",
|
| 115 |
type="numpy",
|
|
|
|
| 117 |
)
|
| 118 |
btn_draw = gr.Button("Predict Drawing", variant="primary")
|
| 119 |
|
|
|
|
| 120 |
with gr.Tab("Upload Photo"):
|
|
|
|
| 121 |
upload = gr.Image(label="Upload Image", sources=["upload", "clipboard"], type="numpy")
|
| 122 |
btn_upload = gr.Button("Predict Upload", variant="primary")
|
| 123 |
|
|
|
|
| 124 |
label = gr.Label(num_top_classes=3, label="Prediction")
|
| 125 |
|
|
|
|
| 126 |
btn_draw.click(fn=classify_digit, inputs=sketchpad, outputs=label)
|
| 127 |
btn_upload.click(fn=classify_digit, inputs=upload, outputs=label)
|
| 128 |
|