834188divi's picture
Duplicate from hari31416/digit_recognizer
3b418e4
import gradio as gr
import numpy as np
import tensorflow as tf
# Load the model
model = tf.keras.models.load_model("mnist.h5")
# Define a function to predict the digit
def predict_digit(img):
# Resize the image to 28x28 pixels
img = img.reshape((1, 28, 28, 1))
# Normalize the image
if img.max() > 1:
img = img / 255.0
# Predict the class
res = model.predict([img])[0]
# create result dictionary
return {str(i): float(res[i]) for i in range(10)}
with gr.Blocks(
css="style.css",
theme=gr.themes.Default(primary_hue="blue", secondary_hue="cyan"),
) as app:
# create a header and a para
with gr.Row():
gr.Markdown(
"""# MNIST Digit Recognizer
This app recognizes handwritten digits. The app uses a sketchpad to get the input image.
Model used is a two layered Convolution network, followed by a fully connected layer and a softmax layer.
""",
)
with gr.Row():
gr.Markdown("## Sketchpad")
gr.Markdown("## Prediction")
# create a row
with gr.Row():
# create a sketchpad
sketchpad = gr.Sketchpad(
shape=(28, 28),
brush_radius=2,
elem_id="sketchpad",
label="Draw a digit here",
)
blank_sketchpad = gr.Sketchpad(
invert_colors=True, brush_radius=2, visible=False
)
# create a label
label = gr.Label(
num_top_classes=3,
elem_id="label",
label="Prediction",
)
# create a button
button = gr.Button("Predict", elem_id="btn_pred")
# bind the button to predict the digit
button.click(
predict_digit,
inputs=sketchpad,
outputs=label,
)
# create a clear button for sketchpad
clear_button = gr.Button("Clear", elem_id="btn_clr")
clear_button.click(lambda a: None, inputs=blank_sketchpad, outputs=sketchpad)
app.launch(share=False)