# Digital_recognizer / src/streamlit_app.py
# Last commit fc32d28 by satya11 (verified): "Update src/streamlit_app.py"
import streamlit as st
from streamlit_drawable_canvas import st_canvas
from keras.models import load_model
import numpy as np
import cv2
import os
# Browser tab title and centered page layout.
st.set_page_config(layout="centered", page_title="Digit AI")

# CSS classes used by the canvas wrapper, the prediction text and the emoji badge.
_CUSTOM_CSS = """
<style>
.canvas-wrapper {
border: 2px dashed #aaa;
padding: 10px;
margin-bottom: 10px;
}
.prediction-box {
font-size: 28px;
font-weight: bold;
margin-top: 10px;
}
.emoji {
font-size: 48px;
}
</style>
"""

# Heading and instructions rendered above the drawing canvas.
_HEADER_HTML = (
    "<h1>Digit Recognizer</h1>",
    "<p>Draw a digit (0–9) below and see what the AI thinks it is!</p>",
)

# Emit the stylesheet first, then the header lines, in order.
for _html_snippet in (_CUSTOM_CSS, *_HEADER_HTML):
    st.markdown(_html_snippet, unsafe_allow_html=True)
# Sidebar controls; their values configure the drawing canvas.
with st.sidebar:
    st.markdown("### ✏️ Drawing Settings")
    drawing_mode = st.selectbox(
        "Tool",
        options=("freedraw", "line", "rect", "circle", "transform"),
    )
    stroke_width = st.slider("Stroke Width", min_value=1, max_value=25, value=10)
    stroke_color = st.color_picker("Stroke Color", value="#FFFFFF")
    bg_color = st.color_picker("Background Color", value="#000000")
    # When unchecked, the canvas only sends its data back on explicit submit.
    realtime_update = st.checkbox("Update Realtime", value=True)
# Load the model safely
@st.cache_resource
def load_mnist_model():
    """Load the trained digit classifier, cached by ``st.cache_resource``.

    The model file is first looked up next to this script, so the app works
    regardless of the directory ``streamlit run`` is launched from. The
    CWD-relative ``src/digit_recognization.keras`` path used previously is
    kept as a fallback.

    Returns:
        The model object returned by ``keras.models.load_model``.

    If the file cannot be found or loaded, an error is rendered and the
    script run is halted with ``st.stop()``.
    """
    filename = "digit_recognization.keras"
    candidates = (
        os.path.join(os.path.dirname(os.path.abspath(__file__)), filename),
        os.path.join("src", filename),
    )
    model_path = next((p for p in candidates if os.path.exists(p)), None)
    if model_path is None:
        st.error(f"❌ Model file not found at: {', '.join(candidates)}")
        st.stop()
    try:
        return load_model(model_path)
    except Exception as e:  # UI boundary: show any load failure instead of a traceback
        st.error(f"❌ Failed to load model: {e}")
        st.stop()


model = load_mnist_model()
# Canvas for user input.
# Each st.markdown call renders as its own self-contained element, so an HTML
# <div> opened in one call and closed in another cannot wrap the canvas. Use a
# real container to group it instead.
try:
    _canvas_box = st.container(border=True)
except TypeError:  # Streamlit versions whose st.container has no `border` argument
    _canvas_box = st.container()

with _canvas_box:
    canvas_result = st_canvas(
        fill_color="rgba(255, 255, 255, 0.05)",
        stroke_width=stroke_width,
        stroke_color=stroke_color,
        background_color=bg_color,
        update_streamlit=realtime_update,
        height=280,
        width=280,
        drawing_mode=drawing_mode,
        key="canvas",
    )
# Process the drawn image
if canvas_result.image_data is not None:
    # The canvas returns RGBA pixels; collapse them to a single grayscale channel.
    img = cv2.cvtColor(canvas_result.image_data.astype("uint8"), cv2.COLOR_RGBA2GRAY)
    # INTER_AREA averages all source pixels when shrinking to 28x28. The default
    # bilinear interpolation samples only a few pixels per output cell, so thin
    # strokes can vanish from the downscaled image.
    img_resized = cv2.resize(img, (28, 28), interpolation=cv2.INTER_AREA)

    # Treat a (near-)blank image as "nothing drawn yet".
    if np.sum(img_resized) > 10:
        # Scale to [0, 1] and add batch and channel axes -> (1, 28, 28, 1).
        # NOTE(review): assumes the model was trained on light digits on a dark
        # background (MNIST-style), matching the default white-on-black canvas.
        img_input = (img_resized / 255.0).reshape((1, 28, 28, 1))
        prediction = model.predict(img_input, verbose=0)
        predicted_digit = int(np.argmax(prediction))
        emoji_digits = ['0️⃣', '1️⃣', '2️⃣', '3️⃣', '4️⃣', '5️⃣', '6️⃣', '7️⃣', '8️⃣', '9️⃣']
        st.markdown(f"<div class='prediction-box'>Prediction: {predicted_digit}</div>", unsafe_allow_html=True)
        st.markdown(f"<div class='emoji'>{emoji_digits[predicted_digit]}</div>", unsafe_allow_html=True)
    else:
        st.warning("Please draw a digit before predicting.")