| import tensorflow as tf | |
| import numpy as np | |
| from urllib.request import urlretrieve | |
| import gradio as gr | |
| import numpy as np | |
# Remote location of the pretrained MNIST classifier, and the local file it is
# cached in (relative to the current working directory).
MODEL_URL = "https://huggingface.co/guiwitz/mnist2023/resolve/main/mnist_model.keras"
MODEL_PATH = "mnist_model.keras"

# Download only when there is no local copy, so restarting the app does not
# need the network. The download goes to a temporary ".part" name and is
# renamed afterwards: an interrupted transfer therefore never leaves a
# truncated MODEL_PATH behind that later runs would skip re-downloading.
# NOTE: an existing local copy is reused as-is; delete it to fetch a newer
# upstream model.
if not tf.io.gfile.exists(MODEL_PATH):
    _partial_path = MODEL_PATH + ".part"
    urlretrieve(MODEL_URL, _partial_path)
    tf.io.gfile.rename(_partial_path, MODEL_PATH, overwrite=True)

model = tf.keras.models.load_model(MODEL_PATH)
def recognize_digit(image, *, classifier=None):
    """Classify a hand-drawn digit and return per-class scores.

    Args:
        image: The drawing as a 2-D array (rows x columns), or ``None`` when
            the canvas is empty or has just been cleared. Presumably 28x28
            grayscale, as produced by the Gradio 3.x ``"sketchpad"`` input --
            TODO confirm against the installed Gradio version.
        classifier: Optional object with a Keras-style
            ``predict(batch, verbose=...)`` method. Defaults to the
            module-level ``model``; pass another model (or a stub) to
            override it.

    Returns:
        dict mapping each class label (``"0"``, ``"1"``, ...) to the score the
        model assigned it, one entry per model output. This is the format
        ``gr.Label`` consumes. Returns an empty dict when ``image`` is
        ``None``.
    """
    if image is None:
        # With live=True the function re-runs on every canvas change,
        # including "clear"; there is nothing to classify then.
        return {}
    if classifier is None:
        classifier = model
    # Add a leading batch axis and a trailing channel axis: (H, W) -> (1, H, W, 1).
    batch = np.asarray(image)[np.newaxis, :, :, np.newaxis]
    # verbose=0 stops Keras from printing a progress bar on every live update.
    scores = classifier.predict(batch, verbose=0)[0].tolist()
    return {str(digit): score for digit, score in enumerate(scores)}
# Live drawing demo: every sketchpad change is sent to recognize_digit, and the
# label shows the three highest-scoring digits.
# NOTE(review): what the "sketchpad" shortcut passes to the function may differ
# between Gradio releases -- confirm against the pinned version.
demo = gr.Interface(
    fn=recognize_digit,
    inputs="sketchpad",
    outputs=gr.Label(num_top_classes=3),
    live=True,
    description="Live MNIST.",
)
demo.launch();  # trailing ";" keeps notebooks from echoing launch()'s return value