Spaces:
Sleeping
Sleeping
| """ | |
| A small Streamlit app that loads a Keras model trained on the MNIST dataset and allows the user to draw a digit on a canvas and get a predicted digit from the model. | |
| """ | |
| import streamlit as st | |
| from PIL import Image | |
| from streamlit_drawable_canvas import st_canvas | |
| import os | |
| import numpy as np | |
| from keras import models | |
| import keras.datasets.mnist as mnist | |
| import matplotlib.pyplot as plt | |
| import pandas as pd | |
| import time | |
| import onnx | |
| import onnxruntime | |
| from scipy.special import softmax | |
def load_picture():
    """
    Show a 3x3 grid of the first 9 MNIST training images in Streamlit.

    The grid is displayed from ``img/show.png``. The MNIST dataset is only
    loaded and plotted when that file does not exist yet, so ordinary reruns
    of the app skip the dataset load and never draw onto matplotlib's shared
    global figure.
    """
    image_path = "img/show.png"
    if not os.path.exists(image_path):
        # Only the training images are needed to build the preview.
        (x_train, _), _ = mnist.load_data()
        figure, axes = plt.subplots(3, 3)
        for axis, digit in zip(axes.flat, x_train[:9]):
            axis.imshow(digit / 255.0, cmap=plt.get_cmap("gray"))
        os.makedirs(os.path.dirname(image_path), exist_ok=True)
        figure.savefig(image_path)
        # Close the figure so it is not kept open after it has been saved.
        plt.close(figure)
    st.image(image_path, width=250, caption="First 9 images from the MNIST dataset")
def keras_prediction(final, model_path):
    """
    Load a Keras model from disk and run it on a single image.

    Args:
        final: Image array to classify. A leading batch axis is added before
            it is handed to the model.
        model_path: Path to the saved model, resolved relative to the
            directory that contains this file.

    Returns:
        A tuple ``(prediction, predict_seconds, load_seconds)``: the value
        returned by ``model.predict`` for the one-image batch, the time spent
        in ``model.predict``, and the time spent loading the model. Both
        durations are in seconds.
    """
    absolute_path = os.path.abspath(
        os.path.join(os.path.dirname(__file__), model_path)
    )

    # perf_counter is a monotonic, high-resolution clock, so the measured
    # durations are not distorted by system clock adjustments.
    load_start = time.perf_counter()
    model = models.load_model(absolute_path)
    load_seconds = time.perf_counter() - load_start

    predict_start = time.perf_counter()
    prediction = model.predict(final[None, ...])
    predict_seconds = time.perf_counter() - predict_start

    return prediction, predict_seconds, load_seconds
def onnx_prediction(final, model_path):
    """
    Load an ONNX model with onnxruntime and run it on a single image.

    Args:
        final: Image array to classify. A batch axis and a channel axis are
            prepended and the data is cast to float32 before inference.
        model_path: Path to the ``.onnx`` file. Like ``keras_prediction``, a
            relative path is resolved against the directory that contains
            this file; an absolute path is used as given.

    Returns:
        A tuple ``(prediction, predict_seconds, load_seconds)``: the softmax
        of the model's first output, the time spent running the session and
        computing the softmax, and the time spent creating the session. Both
        durations are in seconds.
    """
    # Prepend the batch and channel axes in one step.
    model_input = np.asarray(final, dtype="float32")[None, None, ...]
    absolute_path = os.path.abspath(
        os.path.join(os.path.dirname(__file__), model_path)
    )

    # perf_counter is a monotonic, high-resolution clock, so the measured
    # durations are not distorted by system clock adjustments.
    load_start = time.perf_counter()
    session = onnxruntime.InferenceSession(absolute_path, None)
    input_name = session.get_inputs()[0].name
    output_name = session.get_outputs()[0].name
    load_seconds = time.perf_counter() - load_start

    predict_start = time.perf_counter()
    raw_output = session.run([output_name], {input_name: model_input})
    # The raw model output is normalised with softmax so it can be shown as
    # a probability per class.
    prediction = softmax(np.array(raw_output).squeeze(), axis=0)
    predict_seconds = time.perf_counter() - predict_start

    return prediction, predict_seconds, load_seconds
# Model file for each entry offered in the model selectbox, in display order.
_MODEL_PATHS = {
    "Onnx": "models/mnist_12.onnx",
    "Autokeras": "models/autokeras_model.keras",
    "Basic": "models/mnist_model.keras",
}


def _canvas_to_model_input(image_data):
    """Convert the canvas pixel data to a 28x28 grayscale float32 array in [0, 1]."""
    grayscale = Image.fromarray(image_data.astype("uint8")).convert("L")
    # 28x28 is the size of the MNIST dataset images.
    return np.array(grayscale.resize((28, 28)), dtype=np.float32) / 255.0


def _show_prediction(model_choice, prediction, pred_time, load_time):
    """Render the predicted digit, both timings, and the per-digit probabilities."""
    st.header(f"Using model: {model_choice}")
    st.write(f"Prediction: {np.argmax(prediction)}")
    st.write(f"Load time (in ms): {(load_time) * 1000:.2f}")
    st.write(f"Prediction time (in ms): {(pred_time) * 1000:.2f}")

    # One row per digit with the probability the model assigned to it.
    probabilities = pd.DataFrame(
        {"Digit": list(range(10)), "Probability": np.ravel(prediction)}
    )
    chart_col, table_col = st.columns([0.8, 0.2], gap="small")
    with chart_col:
        st.bar_chart(probabilities, x="Digit", y="Probability", height=500)
    with table_col:
        # The table shows the same numbers formatted as percentages.
        formatted = probabilities.assign(
            Probability=probabilities["Probability"].map("{:.2%}".format)
        )
        st.dataframe(formatted, hide_index=True)


def main():
    """
    The main function/primary entry point of the app.

    Builds the page: an introduction next to a sample of MNIST digits, a
    drawing canvas with its controls, and, once something has been drawn,
    the selected model's prediction together with timing information and
    the probability of each digit.
    """
    st.title("MNIST Digit Recognizer")

    intro_col, picture_col = st.columns([0.8, 0.2], gap="small")
    with intro_col:
        # The model descriptions are indented so that they render as
        # sub-items of the "Choose what model" bullet.
        st.markdown(
            """
            This Streamlit app loads a Keras neural network trained on the MNIST dataset to predict handwritten digits. Draw a digit in the canvas below and see the model's prediction. You can:
            - Change the stroke width of the digit using the slider
            - Choose what model you use for predictions
                - Onnx: The mnist-12 Onnx model from <a href="https://xethub.com/XetHub/onnx-models/src/branch/main/vision/classification/mnist">Onnx's pre-trained MNIST models</a>
                - Autokeras: A model generated using the <a href="https://autokeras.com/image_classifier/">Autokeras image classifier class</a>
                - Basic: A simple two layer neural net where each layer has 300 nodes
            Like any machine learning model, this model is a function of the data it was fed during training. As you can see in the picture, the numbers in the images have a specific shape, location, and size. By playing around with the stroke width and where you draw the digit, you can see how the model's prediction changes.""",
            unsafe_allow_html=True,
        )
    with picture_col:
        # Show the first 9 images from the MNIST dataset.
        load_picture()

    canvas_col, controls_col = st.columns(2, gap="small")
    with controls_col:
        # Starts at 10 because that's reasonably close to the width of the
        # MNIST digits.
        stroke_width = st.slider("Stroke width: ", 1, 25, 10)
        model_choice = st.selectbox(
            "Choose what model to use for predictions:", tuple(_MODEL_PATHS)
        )
    # A direct lookup guarantees that every selectable option has a path.
    model_path = _MODEL_PATHS[model_choice]

    with canvas_col:
        canvas_result = st_canvas(
            stroke_width=stroke_width,
            stroke_color="#FFF",
            fill_color="#000",
            background_color="#000",
            background_image=None,
            update_streamlit=True,
            height=200,
            width=200,
            drawing_mode="freedraw",
            point_display_radius=0,
            key="canvas",
        )

    if canvas_result is None or canvas_result.image_data is None:
        return

    final = _canvas_to_model_input(canvas_result.image_data)
    if not final.any():
        # Nothing has been drawn yet, so there is nothing to predict.
        return

    if model_choice == "Onnx":
        prediction, pred_time, load_time = onnx_prediction(final, model_path)
    else:
        prediction, pred_time, load_time = keras_prediction(final, model_path)

    _show_prediction(model_choice, prediction, pred_time, load_time)
# Launch the app only when this file is executed as a script, not when it is
# imported as a module.
if __name__ == "__main__":
    main()