"""Streamlit app that recognizes a handwritten digit with a Keras model.

A digit can be uploaded as an image or drawn on a canvas. Both inputs go
through ``preprocess_image`` (ink extraction, binarization, bounding-box crop,
fit into a 20x20 box centred on a 28x28 frame) before being classified.
"""

import streamlit as st
from streamlit_drawable_canvas import st_canvas
from keras.models import load_model
import numpy as np
import cv2
from PIL import Image

# Side length, in pixels, of the drawing canvas and of the upload preview.
CANVAS_SIZE = 280
# Model input frame, and the box the digit is fitted into inside that frame.
MODEL_INPUT_SIZE = 28
DIGIT_BOX_SIZE = 20

# Unique title for the app
st.title("Handwritten Digit Recognizer")

# Sidebar controls
drawing_mode = st.sidebar.selectbox(
    "Drawing tool:", ("freedraw", "line", "rect", "circle", "transform")
)
stroke_width = st.sidebar.slider("Stroke width: ", 1, 25, 10)
stroke_color = st.sidebar.color_picker("Stroke color hex: ", "#FFFFFF")
bg_color = st.sidebar.color_picker("Background color hex: ", "#000000")
uploaded_digit_image = st.sidebar.file_uploader(
    "Upload a digit image (for prediction):", type=["png", "jpg"]
)
realtime_update = st.sidebar.checkbox("Update in realtime", True)


@st.cache_resource
def load_mnist_model():
    """Load the trained digit classifier; cached so reruns reuse one instance."""
    return load_model('digit_reco.keras')


model = load_mnist_model()


def _border_median(channel):
    """Return the median value of the outermost rows and columns of a 2-D array."""
    edge = np.concatenate(
        (channel[0, :], channel[-1, :], channel[:, 0], channel[:, -1])
    )
    return float(np.median(edge))


def _ink_image(image_data):
    """Return a 2-D uint8 image in which the digit is bright on a dark background.

    Accepts grayscale (H, W), RGB (H, W, 3) or RGBA (H, W, 4) arrays.
    """
    img = np.asarray(image_data).astype("uint8")
    if img.ndim == 3 and img.shape[2] == 4 and _border_median(img[:, :, 3]) < 128:
        # Transparent surroundings (a canvas drawn on a transparent layer, or a
        # PNG cut-out): opacity itself marks the strokes, whatever their colour.
        # Dropping alpha instead would make e.g. black strokes vanish.
        return np.ascontiguousarray(img[:, :, 3])

    if img.ndim == 2:
        gray = img
    elif img.shape[2] == 4:
        gray = cv2.cvtColor(img, cv2.COLOR_RGBA2GRAY)
    elif img.shape[2] == 3:
        gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    else:
        gray = np.ascontiguousarray(img[:, :, 0])

    # Invert only dark-ink-on-light-background images (paper photos, a white
    # canvas). Light strokes on a dark background, which is the default canvas
    # colouring, are already bright-on-dark. Inverting them unconditionally
    # would hand the model a dark digit on a white frame, the opposite
    # polarity from the upload path.
    if _border_median(gray) > 127:
        gray = cv2.bitwise_not(gray)
    return gray


def preprocess_image(image_data, threshold=50):
    """Convert a digit image into a normalized 28x28 MNIST-style array.

    Args:
        image_data: Image as a numpy array (grayscale, RGB or RGBA), e.g. the
            canvas ``image_data`` or an uploaded image converted to RGBA.
        threshold: Intensity (0-255) above which a pixel of the extracted ink
            image counts as part of the digit.

    Returns:
        A (28, 28) float array in [0, 1] holding a bright digit on a black
        background, or all zeros when no digit pixels are found.
    """
    empty = np.zeros((MODEL_INPUT_SIZE, MODEL_INPUT_SIZE))
    if np.asarray(image_data).size == 0:
        return empty

    _, binary = cv2.threshold(
        _ink_image(image_data), threshold, 255, cv2.THRESH_BINARY
    )
    coords = cv2.findNonZero(binary)
    if coords is None:
        return empty

    x, y, w, h = cv2.boundingRect(coords)
    cropped = binary[y:y + h, x:x + w]

    # Fit the longer side into the digit box, keeping the aspect ratio. Clamp
    # each side to at least 1 px: a thin stroke (e.g. a 1-px vertical "1")
    # would otherwise give cv2.resize a zero-sized dsize, which it rejects.
    scale = DIGIT_BOX_SIZE / max(w, h)
    new_w = max(1, int(round(w * scale)))
    new_h = max(1, int(round(h * scale)))
    resized = cv2.resize(cropped, (new_w, new_h), interpolation=cv2.INTER_AREA)

    # Centre the scaled digit on a black 28x28 frame.
    padded = np.zeros((MODEL_INPUT_SIZE, MODEL_INPUT_SIZE), dtype=np.uint8)
    x_off = (MODEL_INPUT_SIZE - new_w) // 2
    y_off = (MODEL_INPUT_SIZE - new_h) // 2
    padded[y_off:y_off + new_h, x_off:x_off + new_w] = resized
    return padded / 255.0


def _show_prediction(processed_image):
    """Show the 28x28 model input, classify it and display the predicted digit."""
    st.image(processed_image, width=150, caption="Processed Input (28x28)")
    batch = processed_image.reshape(
        1, MODEL_INPUT_SIZE, MODEL_INPUT_SIZE, 1
    ).astype("float32")
    prediction = model.predict(batch, verbose=0)
    predicted_digit = int(np.argmax(prediction))
    st.markdown(f"\nPredicted Digit: {predicted_digit}\n", unsafe_allow_html=True)


def _predict_uploaded(uploaded_file):
    """Preview an uploaded image and, if it contains a digit, predict it."""
    st.subheader("📤 Uploaded Image Prediction")
    try:
        image = Image.open(uploaded_file).convert("RGBA")
    except OSError:  # PIL.UnidentifiedImageError is an OSError subclass
        st.error("Could not read the uploaded file as an image.")
        return

    # Scale the longer side to the preview size, keeping the aspect ratio.
    # A plain resize to 280x280 would stretch non-square photos and distort
    # the digit before it is cropped.
    scale = CANVAS_SIZE / max(image.size)
    image = image.resize(
        (max(1, round(image.width * scale)), max(1, round(image.height * scale)))
    )
    img_array = np.array(image)
    st.image(img_array, caption="Uploaded Image")

    processed_image = preprocess_image(img_array)
    if processed_image.any():
        _show_prediction(processed_image)
    else:
        st.info("No digit found in the uploaded image.")


# Handle uploaded image prediction
if uploaded_digit_image is not None:
    _predict_uploaded(uploaded_digit_image)

# Handle canvas drawing
canvas_result = st_canvas(
    fill_color="rgba(255, 165, 0, 0.3)",
    stroke_width=stroke_width,
    stroke_color=stroke_color,
    background_color=bg_color,
    update_streamlit=realtime_update,
    height=CANVAS_SIZE,
    width=CANVAS_SIZE,
    drawing_mode=drawing_mode,
    key="canvas",
)

if canvas_result.image_data is not None:
    st.subheader("✏️ Canvas Drawing Prediction")
    st.image(canvas_result.image_data, caption="Original Drawing")
    processed_image = preprocess_image(canvas_result.image_data)
    # An empty canvas has no digit; predicting on a blank frame would show an
    # arbitrary class as if it were a result.
    if processed_image.any():
        _show_prediction(processed_image)
    else:
        st.info("Draw a digit on the canvas to get a prediction.")