# Space status at capture time: Runtime error
| import tensorflow as tf | |
| import matplotlib.pyplot as plt | |
| from PIL import Image, ImageOps | |
| from tensorflow.keras.utils import img_to_array | |
| from streamlit_drawable_canvas import st_canvas | |
| import streamlit as st | |
# st.set_page_config(layout="wide")
# Page header: each entry is passed to st.write, one call per line, in order.
for _header_markdown in (
    '# MNIST Digit Recognition',
    '## Using trained CNN `Keras` model',
    'To view how this model was trained go to the `Files and Versions` tab'
    ' and download the `Week1.ipynb` notebook',
):
    st.write(_header_markdown)
# Import Pre-trained Model
# Streamlit re-executes this whole script on every widget interaction (and,
# with "Update in realtime", on every canvas stroke). Caching the loader means
# mnist.h5 is read from disk once per server process instead of on every rerun.
@st.cache_resource
def _load_model(path):
    """Load and return the Keras model stored at *path* (cached across reruns)."""
    return tf.keras.models.load_model(path)


model = _load_model('mnist.h5')
# Larger default font for the prediction bar chart drawn further down.
plt.rcParams.update({'font.size': 18})
# Sidebar widgets that control how the user draws.
stroke_width = st.sidebar.slider("Stroke width: ", 1, 25, 9)
realtime_update = st.sidebar.checkbox("Update in realtime", True)

# The drawing surface is a 28x28 grid scaled up 9x; the prediction code
# resizes the drawing back down to 28x28 before inference.
_CANVAS_SIDE = 28 * 9
_canvas_options = dict(
    fill_color="rgba(255, 165, 0, 0.3)",  # fixed fill colour with some opacity
    stroke_width=stroke_width,
    stroke_color='#FFFFFF',      # white strokes ...
    background_color='#000000',  # ... on a black background
    update_streamlit=realtime_update,
    height=_CANVAS_SIDE,
    width=_CANVAS_SIDE,
    drawing_mode='freedraw',
    key="canvas",
)
canvas_result = st_canvas(**_canvas_options)
def _canvas_to_model_input(image_data):
    """Convert raw canvas pixels into the preview image and the model input.

    Args:
        image_data: pixel array returned by ``st_canvas``; it is cast to uint8
            and read as an RGBA image.

    Returns:
        ``(im, data)``: ``im`` is the 28x28 grayscale PIL image (shown to the
        user), ``data`` is a float32 array of shape ``(1, 28, 28, 1)`` with
        pixel values scaled from 0-255 down to 0-1.
    """
    im = ImageOps.grayscale(
        Image.fromarray(image_data.astype('uint8'), mode="RGBA")
    ).resize((28, 28))
    data = img_to_array(im) / 255
    data = data.reshape(1, 28, 28, 1).astype('float32')
    return im, data


def _plot_probabilities(probabilities):
    """Return a new bar-chart Figure of the ten per-digit scores.

    The caller owns the returned figure and is responsible for closing it.
    """
    fig, ax = plt.subplots(figsize=(12, 3))
    ax.bar(range(10), probabilities)
    ax.set_xticks(range(10))
    ax.set_xlabel('Digit')
    ax.set_ylabel('Probability')
    ax.set_title('Drawing Prediction')
    ax.set_ylim(0, 1)
    return fig


if canvas_result.image_data is not None:
    im, data = _canvas_to_model_input(canvas_result.image_data)

    st.write('### Predicted Digit')
    if im.getbbox() is None:
        # Every pixel is zero (nothing drawn yet): a prediction would be noise.
        st.info('Draw a digit on the canvas to see a prediction.')
    else:
        # tf.device only takes effect when used as a context manager.
        with tf.device('/cpu:0'):
            # verbose=0 stops Keras printing a progress bar on every stroke.
            prediction = model.predict(data, verbose=0)
        fig = _plot_probabilities(prediction[0])
        st.pyplot(fig)
        # The script reruns on every interaction; closing the figure stops
        # pyplot from keeping every past figure alive (unbounded memory).
        plt.close(fig)

    # Show resized image
    with st.expander('Show Resized Image'):
        st.write(
            "The image needs to be resized, because it can only input 28x28 images")
        st.image(im, caption='Resized Image', width=28*9)