|
|
import gradio as gr |
|
|
import numpy as np |
|
|
import tensorflow as tf |
|
|
|
|
|
|
|
|
# Saved Keras model file, loaded once when the module runs; predict_digit
# reads this module-level `model` for every prediction.
MODEL_PATH = "mnist.h5"
model = tf.keras.models.load_model(MODEL_PATH)
|
|
|
|
|
|
|
|
|
|
|
def predict_digit(img):
    """Classify a hand-drawn digit from the sketchpad.

    Args:
        img: The sketchpad image as an array holding 28*28 pixel values,
            or ``None`` when nothing has been drawn on the canvas.

    Returns:
        A dict mapping each class index (as a string) to the model's score
        for that class, in the form ``gr.Label`` displays. Returns ``None``
        when there is no image to classify.
    """
    # An empty canvas arrives as None; bail out instead of crashing on
    # ``None.reshape``.
    if img is None:
        return None

    # Add batch and channel axes; float32 is Keras' default float type.
    img = np.asarray(img, dtype=np.float32).reshape((1, 28, 28, 1))

    # Rescale 0-255 pixel intensities into [0, 1] unless already scaled.
    if img.max() > 1:
        img = img / 255.0

    res = model.predict([img])[0]

    # One entry per model output rather than a hard-coded count of 10.
    return {str(i): float(score) for i, score in enumerate(res)}
|
|
|
|
|
|
|
|
with gr.Blocks(
    css="style.css",
    theme=gr.themes.Default(primary_hue="blue", secondary_hue="cyan"),
) as app:
    # Title and short description of the app.
    with gr.Row():
        gr.Markdown(
            """# MNIST Digit Recognizer

This app recognizes handwritten digits. The app uses a sketchpad to get the input image.

Model used is a two layered Convolution network, followed by a fully connected layer and a softmax layer.
""",
        )

    # Headings for the input and output panes in the next row.
    with gr.Row():
        gr.Markdown("## Sketchpad")
        gr.Markdown("## Prediction")

    with gr.Row():
        # 28x28 canvas, matching the size predict_digit reshapes its input to.
        sketchpad = gr.Sketchpad(
            shape=(28, 28),
            brush_radius=2,
            elem_id="sketchpad",
            label="Draw a digit here",
        )
        label = gr.Label(
            num_top_classes=3,
            elem_id="label",
            label="Prediction",
        )

    button = gr.Button("Predict", elem_id="btn_pred")
    button.click(
        predict_digit,
        inputs=sketchpad,
        outputs=label,
    )

    # Reset both the canvas and the last prediction. The handler takes no
    # inputs, so no hidden placeholder component is needed to supply an
    # argument.
    clear_button = gr.Button("Clear", elem_id="btn_clr")
    clear_button.click(
        lambda: (None, None),
        inputs=None,
        outputs=[sketchpad, label],
    )

# NOTE: launches at module import time (no __main__ guard), as before.
app.launch(share=False)
|
|
|