Scribbler310's picture
Update app.py
99f7345 verified
raw
history blame
1.4 kB
import tensorflow as tf
import gradio as gr
import numpy as np
import cv2
# 1. Load the trained Keras model once, at import time, so that every
#    prediction request reuses the same in-memory model.
MODEL_PATH = 'digit_recognizer.keras'
model = tf.keras.models.load_model(MODEL_PATH)
# 2. Define the classification function
def classify_digit(image):
    """Classify a hand-drawn digit and return the model's score per class.

    Args:
        image: The drawing handed over by the Gradio input component.
            Accepted forms are a NumPy array, either grayscale ``(H, W)``
            or with a trailing channel axis ``(H, W, C)``, or the dict
            produced by newer Gradio ``Sketchpad``/``ImageEditor``
            components, whose ``"composite"`` entry holds the flattened
            drawing. ``None`` means nothing was drawn.

    Returns:
        A dict mapping the labels ``"0"`` .. ``"9"`` to the model's output
        for that class (the shape ``gr.Label`` expects), or ``None`` when
        there is no image to classify.
    """
    # Newer Gradio releases wrap the sketch in a dict instead of passing the
    # array directly; unwrap it so both old and new versions are supported.
    if isinstance(image, dict):
        image = image.get("composite")
    if image is None:
        return None

    image = np.asarray(image)

    # Collapse any channel axis so the network sees a single-channel image.
    if image.ndim == 3:
        channels = image.shape[2]
        if channels == 1:
            # Already single-channel; just drop the trailing axis.
            image = np.ascontiguousarray(image[:, :, 0])
        else:
            # Gradio hands over RGB(A)-ordered arrays, not OpenCV's BGR, so
            # use the RGB conversion codes. The 4-channel case covers RGBA
            # sketches, which the 3-channel conversion would reject.
            conversion = (
                cv2.COLOR_RGBA2GRAY if channels == 4 else cv2.COLOR_RGB2GRAY
            )
            image = cv2.cvtColor(image, conversion)

    # Resize to the 28x28 resolution of the MNIST format.
    image = cv2.resize(image, (28, 28))

    # Reshape to (1, 28, 28, 1): a batch of one single-channel image.
    image = image.reshape(1, 28, 28, 1)

    # Scale pixel values into [0, 1]; per the original author's note this
    # mirrors the normalization used in the training notebook.
    image = image / 255.0

    # Run the model and flatten the (1, 10) output into a 1-D score vector.
    prediction = model.predict(image).flatten()

    # Build the {label: confidence} mapping consumed by gr.Label.
    return {str(i): float(prediction[i]) for i in range(10)}
# 3. Assemble the Gradio UI
# Input: a Sketchpad so users can draw the digit.
digit_canvas = gr.Sketchpad(label="Draw a Digit")
# Output: a Label widget configured with num_top_classes=3.
top_guesses = gr.Label(num_top_classes=3)

interface = gr.Interface(
    fn=classify_digit,
    inputs=digit_canvas,
    outputs=top_guesses,
    title="Handwritten Digit Recognizer",
    description="Draw a digit (0-9) on the canvas to see if the Neural Network recognizes it.",
)

# 4. Start the app only when this file is executed as a script
if __name__ == "__main__":
    interface.launch()