File size: 2,263 Bytes
2fabd6e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3ea7ff1
2fabd6e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3ea7ff1
2fabd6e
 
3ea7ff1
 
2fabd6e
 
 
3ea7ff1
2fabd6e
 
3ea7ff1
2fabd6e
 
 
 
3ea7ff1
2fabd6e
3ea7ff1
2fabd6e
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import streamlit as st
import numpy as np
from PIL import Image
from tensorflow.keras.models import load_model


# Install the streamlit_drawable_canvas package if you haven't already
# !pip install streamlit_drawable_canvas

# Import the st_canvas function
from streamlit_drawable_canvas import st_canvas

# Function to preprocess the drawn image
def preprocess_image(drawing, size=(28, 28)):
    """Turn a raw canvas bitmap into a normalised single-channel array.

    Args:
        drawing: Pixel data accepted by ``PIL.Image.fromarray`` once cast to
            ``uint8`` (here, the canvas' ``image_data``).
        size: Target size handed to ``PIL.Image.resize``; PIL interprets it
            as ``(width, height)``.

    Returns:
        A float array of shape ``(height, width, 1)`` whose values lie in
        ``[0, 1]``.
    """
    pixels = np.uint8(drawing)
    # Resize first, then collapse the colour channels to 8-bit grayscale.
    shrunk = Image.fromarray(pixels).resize(size)
    gray = np.array(shrunk.convert('L'))
    # Map 0..255 intensities onto the unit interval.
    normalized = gray / 255.0
    # Append a trailing channel axis (channels-last, one channel).
    return normalized[..., np.newaxis]

@st.cache_resource
def _load_digit_model(model_path):
    """Load the Keras model at *model_path* once and reuse it across reruns.

    Streamlit re-executes the whole script on every widget interaction, so a
    plain module-level variable or ``functools.lru_cache`` would be rebuilt
    each time; Streamlit's resource cache survives those reruns.
    """
    return load_model(model_path)


def preprocess_and_predict(image, model_path="mnist_cnn_model.h5"):
    """Classify one preprocessed digit image with the saved model.

    Args:
        image: Array holding exactly 28*28 values, e.g. shape ``(28, 28)`` or
            ``(28, 28, 1)`` as returned by ``preprocess_image``.
        model_path: Location of the saved Keras model. Defaults to the file
            this app has always loaded.

    Returns:
        Index of the highest-scoring model output, as a Python ``int``.

    Raises:
        ValueError: From ``np.reshape`` when *image* does not contain exactly
            28*28 elements.
    """
    model = _load_digit_model(model_path)
    # One reshape adds both the batch axis and the channel axis.
    batch = np.reshape(image, (1, 28, 28, 1))
    # verbose=0 keeps Keras' per-call progress bar out of the server log.
    prediction = model.predict(batch, verbose=0)
    return int(np.argmax(prediction))

# Main code
def main():
    """Render the drawing canvas and, on request, show the predicted digit.

    The canvas is 168x168 px and ``preprocess_image`` shrinks it to 28x28
    before classification.
    """
    st.title('Draw Digit')

    # Create a drawing canvas
    drawing = st_canvas(
        fill_color="rgb(0, 0, 0)",  # Fill colour for shapes (presumably unused in freedraw)
        stroke_width=4,  # Brush width in canvas pixels
        # NOTE(review): the 168 -> 28 resize is a 6x downscale, so a 4 px
        # stroke ends up thinner than one pixel; a thicker brush may give
        # MNIST-like strokes and better predictions -- TODO confirm.
        stroke_color="rgb(255, 255, 255)",  # Stroke color
        background_color="#000000",  # Background color of the canvas component
        height=168,  # Height of the canvas
        width=168,  # Width of the canvas
        drawing_mode="freedraw",  # Freehand drawing
        key="canvas",
    )
    predict = st.button('Predict digit')

    # Nothing to do until the button is clicked; Streamlit reruns the script
    # on each interaction, so there is no flag to reset afterwards.
    if not predict:
        return

    # The canvas may not have produced any pixel data yet; preprocessing
    # None would raise inside numpy/PIL.
    if drawing.image_data is None:
        st.warning("The canvas is not ready yet. Please try again.")
        return

    # Check if the user has drawn anything: json_data lists drawn objects.
    drawn_objects = (drawing.json_data or {}).get("objects", [])
    if not drawn_objects:
        st.warning("Please draw a digit before predicting.")
        return

    # Preprocess the drawn image
    processed_image = preprocess_image(drawing.image_data)
    digit_class = preprocess_and_predict(processed_image)
    st.title("Predicted Digit:")
    st.success(digit_class)

# Entry point: build the app when this file is executed as the main script.
if __name__ == "__main__":
    main()