File size: 3,356 Bytes
f2db3f8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import pandas as pd
import numpy as np
import plotly.express as px

import streamlit as st
from streamlit_drawable_canvas import st_canvas

import cv2

from keras.models import load_model

# --- App Configuration ---
# Custom look for every st.button: blue fill, white text, rounded corners,
# and a darker blue on hover. Injected as raw HTML, hence unsafe_allow_html.
_BUTTON_CSS = """

    <style>

    .stButton>button {background-color: #4b7bec; color: white; border-radius: 8px; padding: 10px;}

    .stButton>button:hover {background-color: #3867d6;}

    </style>

"""

st.set_page_config(page_title="Handwritten Digit Recognizer", layout="centered")
st.markdown(_BUTTON_CSS, unsafe_allow_html=True)

# --- Load Model ---
@st.cache_resource
def load_digit_model():
    """Load the trained Keras digit classifier and keep it cached.

    ``st.cache_resource`` makes Streamlit reuse the returned model object
    across script reruns instead of reloading it from disk each time.
    The file name is a relative path, so it is resolved against the
    process's current working directory.
    """
    model_path = "mnist_cnn_model_32x32_balanced_dataset.keras"
    return load_model(model_path)

# --- App Title and Instructions ---
st.title("✍️ Handwritten Digit Recognizer")
st.markdown("Draw a digit (0–9) in the box below and click **Predict** to see the result!")

# Collapsed help panel; content is written directly through the expander
# container rather than inside a ``with`` block.
how_to_panel = st.expander("ℹ️ How to Use", expanded=False)
how_to_panel.markdown("""

    - Draw a digit clearly in the center of the canvas.

    - Use a thick stroke.

    - Click **Predict** to see the result.

    - Use **Clear Canvas** to start over.

    """)

# --- Drawing Canvas ---
st.subheader("Draw Your Digit")

# 280x280 free-draw pad: thick black strokes on a white background, with a
# transparent fill so only strokes are painted. update_streamlit=True reruns
# the script as the drawing changes.
_CANVAS_OPTIONS = dict(
    fill_color="rgba(0, 0, 0, 0)",
    stroke_width=20,
    stroke_color="#000000",
    background_color="#FFFFFF",
    update_streamlit=True,
    height=280,
    width=280,
    drawing_mode="freedraw",
    key="canvas",
    display_toolbar=True,
)
canvas_result = st_canvas(**_CANVAS_OPTIONS)

# --- Prediction helpers ---
# Side length of the square model input; the model file name indicates 32x32.
MODEL_INPUT_SIZE = 32


def _canvas_to_grayscale(rgba):
    """Flatten the canvas's RGBA pixels onto white and return uint8 grayscale.

    Depending on how the canvas exports its background, pixels the user has
    not drawn on may come back transparent (NOTE(review): confirm for the
    installed streamlit-drawable-canvas version). A transparent pixel has
    RGB 0, which a plain RGBA->gray conversion would treat as black ink.
    Weighting by alpha turns such pixels white instead. Where every pixel is
    fully opaque, the result equals a plain RGBA->gray conversion.

    Args:
        rgba: 4-channel RGBA image as returned by ``st_canvas(...).image_data``.

    Returns:
        2-D uint8 array where 255 is the white background and darker is ink.
    """
    rgba = np.asarray(rgba).astype(np.uint8)
    gray = cv2.cvtColor(rgba, cv2.COLOR_RGBA2GRAY).astype(np.float32)
    alpha = rgba[..., 3].astype(np.float32) / 255.0
    # "Over" compositing against a white (255) backdrop.
    flattened = gray * alpha + 255.0 * (1.0 - alpha)
    return np.clip(np.rint(flattened), 0, 255).astype(np.uint8)


def _prepare_model_input(gray):
    """Turn a white-background grayscale drawing into the model's input batch.

    Args:
        gray: 2-D uint8 drawing, 255 = background.

    Returns:
        ``(preview, batch)``: ``preview`` is the inverted, resized uint8 image
        shown to the user. ``batch`` is a float32 array of shape
        ``(1, MODEL_INPUT_SIZE, MODEL_INPUT_SIZE, 1)`` scaled to [0, 1].
    """
    # Invert the colors (strokes become bright on dark) to mimic the dataset.
    inverted = 255 - gray
    preview = cv2.resize(
        inverted, (MODEL_INPUT_SIZE, MODEL_INPUT_SIZE), interpolation=cv2.INTER_AREA
    )
    batch = (preview.astype("float32") / 255.0).reshape(
        1, MODEL_INPUT_SIZE, MODEL_INPUT_SIZE, 1
    )
    return preview, batch


def _render_prediction(preview, probs):
    """Show the processed drawing, the top prediction and a probability chart.

    Args:
        preview: uint8 image displayed next to the result.
        probs: 1-D array of per-class probabilities; index i is digit i.
    """
    pred_class = int(np.argmax(probs))
    confidence = float(probs[pred_class])

    st.subheader("Prediction Results")
    col_img, col_result = st.columns([1, 2])

    with col_img:
        st.image(preview, caption="Processed Drawing", width=100, clamp=True)

    with col_result:
        st.success(f"🧠 Predicted Digit: **{pred_class}**")
        st.info(f"🔍 Confidence: **{confidence * 100:.2f}%**")

    # --- Plot probabilities ---
    probs_df = pd.DataFrame({
        "Digit": list(range(len(probs))),
        "Probability": probs * 100,
    })
    fig = px.bar(probs_df, x="Digit", y="Probability",
                 title="Prediction Probabilities",
                 color="Probability",
                 color_continuous_scale="Blues",
                 height=300)
    fig.update_layout(xaxis_title="Digit", yaxis_title="Probability (%)", xaxis=dict(tickmode="linear"))
    st.plotly_chart(fig, use_container_width=True)


# --- Predict Button ---
predict_clicked = st.button("🔍 Predict", use_container_width=True, key="predict_button")

if predict_clicked:
    image_data = canvas_result.image_data
    gray = None if image_data is None else _canvas_to_grayscale(image_data)

    # An ink-free canvas is uniformly white once transparency is flattened,
    # so this single check covers both "no image yet" and "nothing drawn".
    if gray is None or np.all(gray == 255):
        st.warning("⚠️ Please draw something before predicting!")
    else:
        preview, model_input = _prepare_model_input(gray)
        model = load_digit_model()
        # verbose=0 stops Keras printing a progress bar on every click;
        # [0] takes the single sample's probability row.
        probs = model.predict(model_input, verbose=0)[0]
        _render_prediction(preview, probs)