# Author: Sreeja6600
# Commit: "Update app.py" (7d4e951, verified)
import streamlit as st
from streamlit_drawable_canvas import st_canvas
from keras.models import load_model
import numpy as np
import cv2
from PIL import Image
# App title shown at the top of the page
st.title("Handwritten Digit Recognizer")

# Sidebar: drawing options for the canvas plus an optional image upload.
DRAWING_TOOLS = ("freedraw", "line", "rect", "circle", "transform")
drawing_mode = st.sidebar.selectbox(label="Drawing tool:", options=DRAWING_TOOLS)
stroke_width = st.sidebar.slider(
    label="Stroke width: ",
    min_value=1,
    max_value=25,
    value=10,
)
stroke_color = st.sidebar.color_picker(label="Stroke color hex: ", value="#FFFFFF")
bg_color = st.sidebar.color_picker(label="Background color hex: ", value="#000000")
uploaded_digit_image = st.sidebar.file_uploader(
    label="Upload a digit image (for prediction):",
    type=["png", "jpg"],
)
# When enabled, the canvas sends its contents back to the script on every stroke.
realtime_update = st.sidebar.checkbox(label="Update in realtime", value=True)
# Load model
MODEL_FILENAME = "digit_reco.keras"


@st.cache_resource
def load_mnist_model(model_path=None):
    """Load the trained Keras digit classifier.

    Decorated with ``st.cache_resource`` so the loaded model is reused across
    Streamlit reruns instead of being read from disk every time.

    Args:
        model_path: Path to the saved model. When omitted, ``digit_reco.keras``
            next to this script is used if it exists (so the app works no
            matter which directory ``streamlit run`` is launched from);
            otherwise the bare file name is tried relative to the working
            directory, as before.

    Returns:
        The object returned by ``keras.models.load_model``.
    """
    import os

    if model_path is None:
        beside_script = os.path.join(
            os.path.dirname(os.path.abspath(__file__)), MODEL_FILENAME
        )
        model_path = beside_script if os.path.exists(beside_script) else MODEL_FILENAME
    return load_model(model_path)


model = load_mnist_model()
# Preprocessing: convert a digit image into a 28x28 MNIST-style model input.
def preprocess_image(image_data, invert=None):
    """Convert a digit image into a normalized 28x28 array for the model.

    The image is converted to grayscale, binarized (threshold 50), cropped
    to the bounding box of its ink, scaled so the longer side is 20 px
    (aspect ratio kept) and centred on a 28x28 black canvas, giving a light
    digit on a dark background.

    Args:
        image_data: Image as an array: 2-D grayscale, 3-channel RGB or
            4-channel RGBA (the alpha channel is ignored).
        invert: Whether to invert intensities before thresholding. ``None``
            (default) decides automatically: invert only when the background,
            estimated by the median gray level, is light. Dark-on-light
            uploads and light-on-dark canvas drawings therefore both end up
            as a light digit on a dark background.

    Returns:
        A ``(28, 28)`` float array with values in ``[0, 1]``. It is all zeros
        when no pixel survives thresholding.
    """
    img = np.asarray(image_data).astype("uint8")
    if img.ndim == 2:
        gray = img
    elif img.shape[2] == 4:
        gray = cv2.cvtColor(img, cv2.COLOR_RGBA2GRAY)
    else:
        gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)

    if invert is None:
        # A digit usually covers only a small part of the frame, so the median
        # gray level approximates the background. Always inverting (the old
        # behaviour) turned the default white-on-black canvas into a
        # black-on-white image whose whole background counted as "ink".
        invert = np.median(gray) > 127
    if invert:
        gray = cv2.bitwise_not(gray)

    _, binary = cv2.threshold(gray, 50, 255, cv2.THRESH_BINARY)
    coords = cv2.findNonZero(binary)
    if coords is None:
        # Nothing above the threshold: there is no digit to crop.
        return np.zeros((28, 28))
    x, y, w, h = cv2.boundingRect(coords)
    cropped = binary[y:y + h, x:x + w]

    # Fit the digit into a 20x20 box, keeping its aspect ratio. Integer math
    # avoids float truncation (e.g. 19.999... -> 19). The shorter side is
    # clamped to 1 px, because a very thin stroke (e.g. a vertical "1") would
    # otherwise give a zero size and make cv2.resize fail.
    if h > w:
        new_h, new_w = 20, max(1, w * 20 // h)
    else:
        new_w, new_h = 20, max(1, h * 20 // w)
    resized = cv2.resize(cropped, (new_w, new_h), interpolation=cv2.INTER_AREA)

    # Centre the scaled digit on a 28x28 black canvas.
    padded = np.zeros((28, 28), dtype=np.uint8)
    x_offset = (28 - new_w) // 2
    y_offset = (28 - new_h) // 2
    padded[y_offset:y_offset + new_h, x_offset:x_offset + new_w] = resized
    return padded / 255.0
# Handle uploaded image prediction
if uploaded_digit_image is not None:
    st.subheader("📤 Uploaded Image Prediction")
    try:
        image = Image.open(uploaded_digit_image).convert("RGBA")
    except Image.UnidentifiedImageError:
        # A corrupt or mislabeled upload should not crash the whole app.
        st.error("Could not read the uploaded file as an image.")
    else:
        # Fit the image into a 280x280 box while keeping its aspect ratio.
        # A plain resize to 280x280 squashed non-square photos and distorted
        # the digit before preprocessing.
        scale = 280 / max(image.size)
        image = image.resize(
            (max(1, round(image.width * scale)), max(1, round(image.height * scale)))
        )
        img_array = np.array(image)
        st.image(img_array, caption="Uploaded Image")
        processed_image = preprocess_image(img_array)
        st.image(processed_image, width=150, caption="Processed Input (28x28)")
        if not processed_image.any():
            # preprocess_image returns all zeros when it finds no digit pixels;
            # a prediction on that would be an arbitrary number.
            st.info("No digit found in the uploaded image.")
        else:
            img_reshaped = processed_image.reshape(1, 28, 28, 1)
            # verbose=0 keeps Keras from printing a progress bar on every rerun.
            prediction = model.predict(img_reshaped, verbose=0)
            predicted_digit = int(np.argmax(prediction))
            st.markdown(
                f"""
                <div style='text-align: center; font-size: 60px; font-weight: bold; color: #2E86C1;
                text-shadow: 2px 2px 4px #aaa; margin-top: 20px;'>
                Predicted Digit: {predicted_digit}
                </div>
                """,
                unsafe_allow_html=True
            )
# Handle canvas drawing
canvas_result = st_canvas(
    fill_color="rgba(255, 165, 0, 0.3)",
    stroke_width=stroke_width,
    stroke_color=stroke_color,
    background_color=bg_color,
    update_streamlit=realtime_update,
    height=280,
    width=280,
    drawing_mode=drawing_mode,
    key="canvas",
)
if canvas_result.image_data is not None:
    st.subheader("✏️ Canvas Drawing Prediction")
    st.image(canvas_result.image_data, caption="Original Drawing")
    processed_image = preprocess_image(canvas_result.image_data)
    st.image(processed_image, width=150, caption="Processed Input (28x28)")
    if not processed_image.any():
        # preprocess_image returns all zeros when it finds no digit pixels
        # (e.g. an untouched canvas). Predicting on that would show an
        # arbitrary "Predicted Digit".
        st.info("Draw a digit on the canvas to get a prediction.")
    else:
        img_reshaped = processed_image.reshape(1, 28, 28, 1)
        # verbose=0 keeps Keras from printing a progress bar on every rerun
        # (with realtime updates, that is every stroke).
        prediction = model.predict(img_reshaped, verbose=0)
        predicted_digit = int(np.argmax(prediction))
        st.markdown(
            f"""
            <div style='text-align: center; font-size: 60px; font-weight: bold; color: #2E86C1;
            text-shadow: 2px 2px 4px #aaa; margin-top: 20px;'>
            Predicted Digit: {predicted_digit}
            </div>
            """,
            unsafe_allow_html=True
        )