# MNISTify / app.py
# Source: NeonSamurai on Hugging Face — commit f2db3f8 ("Upload 2 files", verified)
import pandas as pd
import numpy as np
import plotly.express as px
import streamlit as st
from streamlit_drawable_canvas import st_canvas
import cv2
from keras.models import load_model
# --- App Configuration ---
# Page-level settings; this must be the first Streamlit command in the script.
st.set_page_config(layout="centered", page_title="Handwritten Digit Recognizer")

# Inject custom CSS so every st.button gets a blue, rounded look.
st.markdown(
    "\n".join([
        "",
        "<style>",
        ".stButton>button {background-color: #4b7bec; color: white; border-radius: 8px; padding: 10px;}",
        ".stButton>button:hover {background-color: #3867d6;}",
        "</style>",
        "",
    ]),
    unsafe_allow_html=True,
)
# --- Load Model ---
@st.cache_resource
def load_digit_model(model_path: str = "mnist_cnn_model_32x32_balanced_dataset.keras"):
    """Load the trained digit-classification model from disk.

    Wrapped in ``st.cache_resource`` so the model is loaded once per distinct
    ``model_path`` and reused across reruns, rather than being re-read from
    disk on every button click.

    Args:
        model_path: Path to the saved Keras model. Defaults to the file name
            the app originally hard-coded.

    Returns:
        The model object returned by ``keras.models.load_model``.
    """
    return load_model(model_path)
# --- App Title and Instructions ---
st.title("✍️ Handwritten Digit Recognizer")
st.markdown("Draw a digit (0–9) in the box below and click **Predict** to see the result!")

# Collapsible usage notes. The list items stay flush-left inside the string
# so Markdown renders them as bullets rather than an indented code block.
with st.expander("ℹ️ How to Use", expanded=False):
    st.markdown("""
- Draw a digit clearly in the center of the canvas.
- Use a thick stroke.
- Click **Predict** to see the result.
- Use the trash-can (reset) icon in the canvas toolbar to start over.
""")
# --- Drawing Canvas ---
st.subheader("Draw Your Digit")

# A 280x280 free-draw surface: thick black strokes over a white background,
# with the built-in toolbar shown so the user can undo/reset their drawing.
canvas_result = st_canvas(
    key="canvas",
    drawing_mode="freedraw",
    # Geometry
    width=280,
    height=280,
    # Colours
    background_color="#FFFFFF",
    stroke_color="#000000",
    fill_color="rgba(0, 0, 0, 0)",
    # Brush
    stroke_width=20,
    # Behaviour
    display_toolbar=True,
    update_streamlit=True,  # send canvas data back to Streamlit as the user draws
)
# --- Predict Button ---
def _preprocess_canvas(image_data):
    """Turn raw canvas pixels into the model's 32x32 single-channel input.

    Args:
        image_data: ``canvas_result.image_data`` from ``st_canvas``. It is
            treated as an (H, W, 4) RGBA array (the original code converted it
            with ``COLOR_RGBA2GRAY``) — TODO confirm against the canvas docs.

    Returns:
        ``(img_resized, input_img)``:
        - ``img_resized`` is the 32x32 uint8 image, kept for display.
        - ``input_img`` is a ``(1, 32, 32, 1)`` float32 tensor scaled to [0, 1].

        Returns ``None`` when the canvas is blank (all white after compositing).
    """
    rgba = np.asarray(image_data).astype(np.float32)
    alpha = rgba[..., 3:4] / 255.0
    # Composite over white before converting to grayscale. COLOR_RGBA2GRAY
    # ignores alpha, so if the canvas hands back transparent (0, 0, 0, 0)
    # background pixels they would otherwise read as solid black ink.
    # Fully opaque pixels (alpha == 255) pass through unchanged.
    rgb = rgba[..., :3] * alpha + 255.0 * (1.0 - alpha)
    gray = cv2.cvtColor(np.rint(rgb).astype(np.uint8), cv2.COLOR_RGB2GRAY)
    if np.all(gray == 255):
        return None

    inverted = 255 - gray  # Inverting the colors to mimic the dataset
    img_resized = cv2.resize(inverted, (32, 32), interpolation=cv2.INTER_AREA)
    input_img = (img_resized.astype("float32") / 255.0).reshape(1, 32, 32, 1)
    return img_resized, input_img


def _render_results(img_resized, pred_probs):
    """Show the processed drawing, the top prediction and a probability chart.

    Args:
        img_resized: The 32x32 uint8 image that was fed to the model.
        pred_probs: Output of ``model.predict`` for a single image. Row 0 is
            read as the per-digit scores, labelled 0..N-1 (presumably softmax
            probabilities — confirm against the model).
    """
    probs = pred_probs[0]
    pred_class = int(np.argmax(probs))
    confidence = float(probs[pred_class])

    st.subheader("Prediction Results")
    col_img, col_result = st.columns([1, 2])
    with col_img:
        st.image(img_resized, caption="Processed Drawing", width=100, clamp=True)
    with col_result:
        st.success(f"🧠 Predicted Digit: **{pred_class}**")
        st.info(f"🔍 Confidence: **{confidence * 100:.2f}%**")

    # --- Plot probabilities ---
    probs_df = pd.DataFrame({
        "Digit": list(range(len(probs))),
        "Probability": probs * 100,
    })
    fig = px.bar(
        probs_df,
        x="Digit",
        y="Probability",
        title="Prediction Probabilities",
        color="Probability",
        color_continuous_scale="Blues",
        height=300,
    )
    fig.update_layout(
        xaxis_title="Digit",
        yaxis_title="Probability (%)",
        xaxis=dict(tickmode="linear"),
    )
    st.plotly_chart(fig, use_container_width=True)


predict_clicked = st.button("🔍 Predict", use_container_width=True, key="predict_button")
if predict_clicked:
    processed = (
        _preprocess_canvas(canvas_result.image_data)
        if canvas_result.image_data is not None
        else None
    )
    if processed is None:
        # One message for both "no canvas data yet" and "canvas is blank".
        st.warning("⚠️ Please draw something before predicting!")
    else:
        img_resized, input_img = processed
        model = load_digit_model()
        # verbose=0 keeps Keras' per-call progress bar out of the server logs.
        pred_probs = model.predict(input_img, verbose=0)
        _render_results(img_resized, pred_probs)