Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import numpy as np | |
| import tensorflow as tf | |
| # Load the model | |
| model = tf.keras.models.load_model("mnist.h5") | |
| # Define a function to predict the digit | |
| def predict_digit(img): | |
| # Resize the image to 28x28 pixels | |
| img = img.reshape((1, 28, 28, 1)) | |
| # Normalize the image | |
| if img.max() > 1: | |
| img = img / 255.0 | |
| # Predict the class | |
| res = model.predict([img])[0] | |
| # create result dictionary | |
| return {str(i): float(res[i]) for i in range(10)} | |
| with gr.Blocks( | |
| css="style.css", | |
| theme=gr.themes.Default(primary_hue="blue", secondary_hue="cyan"), | |
| ) as app: | |
| # create a header and a para | |
| with gr.Row(): | |
| gr.Markdown( | |
| """# MNIST Digit Recognizer | |
| This app recognizes handwritten digits. The app uses a sketchpad to get the input image. | |
| Model used is a two layered Convolution network, followed by a fully connected layer and a softmax layer. | |
| """, | |
| ) | |
| with gr.Row(): | |
| gr.Markdown("## Sketchpad") | |
| gr.Markdown("## Prediction") | |
| # create a row | |
| with gr.Row(): | |
| # create a sketchpad | |
| sketchpad = gr.Sketchpad( | |
| shape=(28, 28), | |
| brush_radius=2, | |
| elem_id="sketchpad", | |
| label="Draw a digit here", | |
| ) | |
| blank_sketchpad = gr.Sketchpad( | |
| invert_colors=True, brush_radius=2, visible=False | |
| ) | |
| # create a label | |
| label = gr.Label( | |
| num_top_classes=3, | |
| elem_id="label", | |
| label="Prediction", | |
| ) | |
| # create a button | |
| button = gr.Button("Predict", elem_id="btn_pred") | |
| # bind the button to predict the digit | |
| button.click( | |
| predict_digit, | |
| inputs=sketchpad, | |
| outputs=label, | |
| ) | |
| # create a clear button for sketchpad | |
| clear_button = gr.Button("Clear", elem_id="btn_clr") | |
| clear_button.click(lambda a: None, inputs=blank_sketchpad, outputs=sketchpad) | |
| app.launch(share=False) | |