| # import streamlit as st | |
| # import cv2 | |
| # from streamlit_drawable_canvas import st_canvas | |
| # from keras.models import load_model | |
| # import numpy as np | |
| # # Sidebar controls | |
| # st.sidebar.title("Canvas Settings") | |
| # drawing_mode = st.sidebar.selectbox("Drawing tool:", ("freedraw", "line", "rect", "circle", "transform")) | |
| # stroke_width = st.sidebar.slider("Stroke width: ", 1, 25, 10) | |
| # stroke_color = st.sidebar.color_picker("Stroke color hex: ", "#000000") # black | |
| # bg_color = st.sidebar.color_picker("Background color hex: ", "#FFFFFF") # white | |
| # bg_image = st.sidebar.file_uploader("Background image:", type=["png", "jpg"]) | |
| # realtime_update = st.sidebar.checkbox("Update in realtime", True) | |
| # # Load model with caching | |
| # @st.cache_resource | |
| # def load_mnist_model(): | |
| # return load_model("mnist_model.keras") | |
| # model = load_mnist_model() | |
| # st.title("ποΈ Mindist: Draw a Number, Predict Instantly") | |
| # # Create a two-column layout | |
| # col1, col2 = st.columns([1, 1]) | |
| # with col1: | |
| # st.subheader("Draw Here π") | |
| # canvas_result = st_canvas( | |
| # fill_color="rgba(255, 165, 0, 0.3)", | |
| # stroke_width=stroke_width, | |
| # stroke_color=stroke_color, | |
| # background_color=bg_color, | |
| # update_streamlit=realtime_update, | |
| # height=280, | |
| # width=280, | |
| # drawing_mode=drawing_mode, | |
| # key="canvas", | |
| # ) | |
| # with col2: | |
| # if canvas_result.image_data is not None: | |
| # st.subheader("Original Drawing") | |
| # st.image(canvas_result.image_data, use_column_width=True) | |
| # # Below the two columns: Show preprocessing and prediction | |
| # if canvas_result.image_data is not None: | |
| # st.markdown("---") | |
| # st.subheader("Preprocessed Image & Prediction") | |
| # img = cv2.cvtColor(canvas_result.image_data.astype("uint8"), cv2.COLOR_RGBA2GRAY) | |
| # img = 255 - img # Invert colors | |
| # img_resized = cv2.resize(img, (28, 28)) | |
| # img_normalized = img_resized / 255.0 | |
| # final_img = img_normalized.reshape(1, 28, 28, 1) | |
| # col3, col4 = st.columns([1, 1]) | |
| # with col3: | |
| # st.image(img_resized, caption="28x28 Preprocessed", clamp=True, channels="GRAY") | |
| # with col4: | |
| # prediction = model.predict(final_img) | |
| # predicted_digit = np.argmax(prediction) | |
| # st.markdown(f"### π§ Predicted Digit: **{predicted_digit}**") | |
| import streamlit as st | |
| import cv2 | |
| import numpy as np | |
| from keras.models import load_model | |
| from streamlit_drawable_canvas import st_canvas | |
# === Load models ===
# Streamlit re-executes this whole script on every widget interaction.
# st.cache_resource keeps one loaded instance of each model across reruns,
# instead of reloading the .keras files from disk every time.
@st.cache_resource
def load_single_digit_model():
    """Load and return the single-digit Keras model stored in ``mnist_model.keras``."""
    return load_model("mnist_model.keras")


@st.cache_resource
def load_multi_digit_model():
    """Load and return the multi-digit Keras model stored in ``best_model.keras``."""
    return load_model("best_model.keras")


single_digit_model = load_single_digit_model()
multi_digit_model = load_multi_digit_model()
# === Helper function to clean prediction ===
def clean_prediction(predicted_digits):
    """Join the valid digit class indices of a decoded prediction into a string.

    Each element of ``predicted_digits`` is treated as a class index (for
    example the ``argmax`` of one output head). Indices in the range 0-9 are
    kept, in order, as their decimal digit; any other value is dropped.

    Real 0s and 1s are NOT removed. Presumably the multi-digit model emits an
    extra class (e.g. 10) for "no digit"/padding, which this filter discards.
    Verify that against the model's training labels.

    Args:
        predicted_digits: Iterable of integer class indices (Python ints or
            NumPy integer scalars).

    Returns:
        The kept digits concatenated, e.g. ``[1, 10, 2]`` -> ``"12"``. Returns
        an empty string when nothing is kept.
    """
    return "".join(str(d) for d in predicted_digits if 0 <= d <= 9)
# === Sidebar controls ===
# Every widget below renders in the sidebar. Most of the returned values
# configure the st_canvas call further down.
with st.sidebar:
    st.title("Canvas Settings")
    drawing_mode = st.selectbox(
        label="Drawing tool:",
        options=("freedraw", "line", "rect", "circle", "transform"),
    )
    stroke_width = st.slider(label="Stroke width: ", min_value=1, max_value=25, value=10)
    stroke_color = st.color_picker(label="Stroke color hex: ", value="#000000")  # black
    bg_color = st.color_picker(label="Background color hex: ", value="#FFFFFF")  # white
    # NOTE(review): bg_image is collected here but never passed to st_canvas.
    bg_image = st.file_uploader(label="Background image:", type=["png", "jpg"])
    realtime_update = st.checkbox(label="Update in realtime", value=True)
# === Title ===
st.title("🖍️ Multi-Digit and Single-Digit Drawing: Predict Instantly")
# === Create a two-column layout ===
col1, col2 = st.columns([1, 1])
# === Canvas for drawing ===
with col1:
    st.subheader("Draw Here 👇")
    # The prediction code inverts the grayscale canvas image, so dark strokes
    # on a light background become bright-on-dark model input.
    # NOTE(review): the sidebar's bg_image upload is never passed here.
    # st_canvas accepts a background_image argument; wire it up or drop the
    # uploader.
    canvas_result = st_canvas(
        fill_color="rgba(255, 165, 0, 0.3)",
        stroke_width=stroke_width,
        stroke_color=stroke_color,
        background_color=bg_color,
        update_streamlit=realtime_update,
        height=280,
        width=280,
        drawing_mode=drawing_mode,
        key="canvas",
    )
# === Display original drawing ===
# Once the canvas has produced image data, mirror it in the right-hand column.
with col2:
    drawing = canvas_result.image_data
    if drawing is not None:
        st.subheader("Original Drawing")
        st.image(drawing, use_column_width=True)
# === Image preprocessing and prediction ===
# Ink width, in pixels of the 28x80 multi-digit input, below which the drawing
# is treated as a single digit. Arbitrary threshold; tune as needed.
SINGLE_DIGIT_MAX_INK_WIDTH = 50


def _ink_width(gray_img):
    """Return the horizontal span of a 2-D image's columns containing non-zero pixels, or 0 if none do."""
    ink_columns = np.flatnonzero(gray_img.max(axis=0) > 0)
    if ink_columns.size == 0:
        return 0
    return int(ink_columns[-1] - ink_columns[0] + 1)


def _to_model_input(gray_img, width, height):
    """Resize a 2-D image and return it plus a (1, height, width, 1) batch scaled to [0, 1]."""
    resized = cv2.resize(gray_img, (width, height))  # cv2 dsize is (width, height)
    batch = (resized / 255.0).reshape(1, height, width, 1)
    return resized, batch


if canvas_result.image_data is not None:
    st.markdown("---")
    st.subheader("Preprocessed Image & Prediction")

    # Grayscale, then invert so the strokes become the bright pixels.
    img = cv2.cvtColor(canvas_result.image_data.astype("uint8"), cv2.COLOR_RGBA2GRAY)
    img = 255 - img

    # Decide which model to use from the width of the drawn ink, measured on
    # the multi-digit input. The old check looked at the resized image's own
    # width, which is always 80, so the single-digit branch could never run.
    multi_img, multi_input = _to_model_input(img, 80, 28)
    ink_width = _ink_width(multi_img)

    if ink_width == 0:
        # Nothing drawn yet: don't show a meaningless prediction.
        st.info("Draw something on the canvas to get a prediction.")
    else:
        if ink_width < SINGLE_DIGIT_MAX_INK_WIDTH:
            # The single-digit model needs its own 28x28 input; feeding it the
            # 28x80 tensor would fail on a shape mismatch.
            shown_img, model_input = _to_model_input(img, 28, 28)
            preds = single_digit_model.predict(model_input, verbose=0)
            predicted_str = str(np.argmax(preds))
        else:
            shown_img, model_input = multi_img, multi_input
            preds = multi_digit_model.predict(model_input, verbose=0)
            # NOTE(review): assumes best_model.keras has one output head per
            # digit position, each shaped (1, num_classes). Confirm against the
            # model definition.
            predicted_digits = [np.argmax(p[0]) for p in preds]
            predicted_str = clean_prediction(predicted_digits)

        col3, col4 = st.columns([1, 1])
        with col3:
            out_h, out_w = shown_img.shape
            st.image(shown_img, caption=f"{out_w}x{out_h} Preprocessed")
        with col4:
            st.markdown(f"### 🧠 Predicted Number: **{predicted_str}**")