Spaces:
Sleeping
Sleeping
| import tensorflow as tf | |
| import gradio as gr | |
| import numpy as np | |
| import cv2 | |
# 1. Load the trained Keras model once, at import time, so that every call to
#    the classification function reuses the same in-memory model.
_MODEL_PATH = "digit_recognizer.keras"
model = tf.keras.models.load_model(_MODEL_PATH)
# 2. Define the classification function
def _sketch_to_array(image):
    """Return the canvas contents as a NumPy array, or None if nothing was sent.

    Gradio 3.x ``Sketchpad`` passes a NumPy array directly. Gradio 4.x+ passes
    a dict with ``"background"``, ``"layers"`` and ``"composite"`` keys, where
    ``"composite"`` is the flattened image the user actually sees.
    """
    if isinstance(image, dict):
        image = image.get("composite")
    if image is None:
        return None
    return np.asarray(image)


def _to_mnist_grayscale(image):
    """Convert a 2-D gray, RGB or RGBA image to light-on-dark float32 gray in [0, 255]."""
    image = image.astype(np.float32)
    if image.ndim == 3:
        channels = image.shape[2]
        if channels == 4:
            # Flatten onto a white page using the alpha channel. On a transparent
            # canvas the unpainted pixels have RGB == 0, so simply dropping alpha
            # would turn the whole background black and hide the strokes.
            alpha = image[:, :, 3:4] / 255.0
            image = image[:, :, :3] * alpha + 255.0 * (1.0 - alpha)
        if channels >= 3:
            # Gradio delivers RGB(A) arrays (not OpenCV's BGR order); these are
            # the same BT.601 luma weights OpenCV's RGB2GRAY uses.
            image = image[:, :, :3] @ np.array([0.299, 0.587, 0.114], dtype=np.float32)
        else:
            image = image[:, :, 0]
    # MNIST digits are light strokes on a dark background. Heuristic: a canvas
    # that is mostly light (dark ink on a white page) gets inverted to match.
    if image.mean() > 127.5:
        image = 255.0 - image
    return image


def classify_digit(image):
    """Classify a hand-drawn digit coming from the Gradio Sketchpad.

    Args:
        image: The Sketchpad value. This is either a NumPy array (Gradio 3.x)
            or an editor dict with a ``"composite"`` key (Gradio 4.x+). It may
            be None when nothing was submitted.

    Returns:
        A ``{"0": s0, ..., "9": s9}`` mapping from digit label to the model's
        output score for that digit, in the form ``gr.Label`` expects. Returns
        None if there is no image.
    """
    image = _sketch_to_array(image)
    if image is None:
        return None
    # Preprocessing to match MNIST data format: one channel, light-on-dark, 28x28.
    image = _to_mnist_grayscale(image)
    # INTER_AREA avoids aliasing away thin strokes when shrinking a large canvas.
    image = cv2.resize(image, (28, 28), interpolation=cv2.INTER_AREA)
    # Normalize pixel values (0 to 1) just like in the training notebook, then
    # reshape to (batch=1, 28, 28, channels=1) to match the model input shape.
    batch = (image / 255.0).reshape(1, 28, 28, 1).astype(np.float32)
    # verbose=0 stops Keras from printing a progress bar on every request.
    scores = model.predict(batch, verbose=0).flatten()
    return {str(digit): float(scores[digit]) for digit in range(10)}
# 3. Build the Gradio Interface.
#    A freehand Sketchpad is the input, so users can draw the digit. A Label
#    output shows the three highest-scoring classes returned by classify_digit.
_sketch_input = gr.Sketchpad(label="Draw a Digit")
_top3_output = gr.Label(num_top_classes=3)
interface = gr.Interface(
    fn=classify_digit,
    inputs=_sketch_input,
    outputs=_top3_output,
    title="Handwritten Digit Recognizer",
    description="Draw a digit (0-9) on the canvas to see if the Neural Network recognizes it.",
)
# 4. Launch
def _main():
    """Start the Gradio server that serves the digit-recognizer interface."""
    interface.launch()


# Only start the server when this file is run as a script, not when imported.
if __name__ == "__main__":
    _main()