Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import numpy as np | |
| from PIL import Image, ImageOps | |
| import joblib | |
| import os | |
| import plotly.graph_objects as go | |
| from streamlit_drawable_canvas import st_canvas | |
# Browser-tab title/icon and a centered single-column page layout.
st.set_page_config(page_title="Digit Recognizer", page_icon="✏️", layout="centered")

# The drawing canvas's widget key is built from this counter; the Clear button
# increments it so a brand-new, empty canvas is created. Seed it once per session.
if "canvas_key" not in st.session_state:
    st.session_state["canvas_key"] = 0
# Load model once and reuse it across reruns and sessions.
@st.cache_resource
def load_model():
    """Load the trained digit classifier from ``model/digit_classifier.joblib``.

    The path is resolved relative to this script's directory, so the app works
    regardless of the current working directory.

    Streamlit re-executes the whole script on every widget interaction.
    ``st.cache_resource`` keeps a single loaded model per server process, so the
    file is read from disk only once instead of on every rerun.

    Returns:
        The deserialized model object (used below via ``predict``,
        ``predict_proba`` and ``classes_``).
    """
    model_path = os.path.join(os.path.dirname(__file__), "model", "digit_classifier.joblib")
    # NOTE: joblib.load unpickles the file; only ever point this at a trusted model.
    return joblib.load(model_path)
def preprocess_canvas(image_data, *, padding=20, size=8):
    """Convert canvas pixels into a feature row in sklearn ``load_digits`` format.

    The drawing is cropped to its bounding box (plus ``padding`` pixels of
    margin) and centred on a square black background. It is then downsampled to
    ``size`` x ``size`` with Lanczos resampling and rescaled from 0-255 to 0-16.
    0-16 is the value range of the sklearn digits dataset, where larger values
    mean more ink.

    Args:
        image_data: Canvas pixels as an (H, W, 4) RGBA array; an (H, W, 3) RGB
            array is accepted too. Light strokes on a black background are
            expected, so bright pixels are treated as ink.
        padding: Margin, in source pixels, kept around the drawing's bounding box.
        size: Side length of the output grid; the default 8 matches sklearn digits.

    Returns:
        A float64 array of shape (1, size * size) with values in [0, 16].

    NOTE(review): load_digits was built by counting set pixels in 4x4 blocks,
    which ``Image.Resampling.BOX`` mirrors more closely than LANCZOS. LANCZOS is
    kept so predictions stay identical to the current behaviour.
    """
    # Let Pillow infer the mode (RGBA/RGB) from the array shape; passing
    # mode= to Image.fromarray is deprecated as of Pillow 11.3.
    img = Image.fromarray(np.asarray(image_data).astype("uint8"))
    # Collapse to single-channel 8-bit grayscale: ink is bright, background is 0.
    img = img.convert("L")

    # Bounding box of the non-zero (inked) pixels; None when nothing is drawn.
    bbox = img.getbbox()
    if bbox:
        left, top, right, bottom = bbox
        img = img.crop((
            max(0, left - padding),
            max(0, top - padding),
            min(img.width, right + padding),
            min(img.height, bottom + padding),
        ))

    # Centre on a square black canvas so downsampling preserves the aspect ratio.
    side = max(img.size)
    square = Image.new("L", (side, side), 0)
    square.paste(img, ((side - img.width) // 2, (side - img.height) // 2))

    small = square.resize((size, size), Image.Resampling.LANCZOS)

    # Scale 0-255 grayscale to the dataset's 0-16 range (255 = full ink = 16).
    pixels = np.array(small, dtype=np.float64)
    pixels = (pixels / 255.0) * 16.0
    return pixels.reshape(1, -1)
# Load the classifier before anything else is rendered.
model = load_model()

# Page heading and usage instructions.
st.title("✏️ Handwritten Digit Recognizer")
intro_text = """
Draw a digit (0-9) in the canvas below. The prediction updates automatically!
*Tip: Draw the digit large and centered for best results.*
"""
st.markdown(intro_text)
# Two equal-width columns: the drawing canvas on the left, results on the right.
col1, col2 = st.columns([1, 1])

with col1:
    st.subheader("Draw Here")
    # White freehand strokes on a black 280x280 canvas. The widget key embeds
    # st.session_state.canvas_key, so changing that counter yields a fresh,
    # empty canvas on the next run.
    canvas_result = st_canvas(
        fill_color="black",
        stroke_width=18,
        stroke_color="white",
        background_color="black",
        height=280,
        width=280,
        drawing_mode="freedraw",
        key=f"canvas_{st.session_state.canvas_key}",
    )
    # Clear = switch to a new canvas key and rerun immediately.
    if st.button("🗑️ Clear Canvas", use_container_width=True):
        st.session_state.canvas_key += 1
        st.rerun()

with col2:
    st.subheader("Prediction")
    image_data = canvas_result.image_data
    # Strokes are white on a black background, so any non-zero RGB value means
    # something has been drawn.
    has_drawing = image_data is not None and np.any(image_data[:, :, :3])
    if has_drawing:
        # Preprocess the image to match the sklearn digits feature format.
        img_flat = preprocess_canvas(image_data)
        prediction = model.predict(img_flat)[0]
        probabilities = model.predict_proba(img_flat)[0]
        # predict_proba columns are ordered like model.classes_. Look up the
        # predicted label's column instead of using the label itself as an index;
        # the two only coincide when the classes are exactly 0, 1, ..., n-1.
        classes = model.classes_
        best = int(np.flatnonzero(classes == prediction)[0])

        # Display large prediction
        st.markdown(f"""
<div style="text-align: center; padding: 20px; background-color: #1e1e1e; border-radius: 10px; margin-bottom: 20px;">
<h1 style="font-size: 72px; margin: 0; color: #4CAF50;">{prediction}</h1>
<p style="font-size: 18px; color: #888;">Confidence: {probabilities[best]*100:.1f}%</p>
</div>
""", unsafe_allow_html=True)

        st.subheader("Confidence Scores")
        # One horizontal bar per class, labelled from classes_; the predicted
        # class is highlighted in green.
        fig = go.Figure(go.Bar(
            x=probabilities * 100,
            y=[str(c) for c in classes],
            orientation='h',
            marker_color=['#4CAF50' if i == best else '#2196F3' for i in range(len(classes))],
            text=[f'{p*100:.1f}%' for p in probabilities],
            textposition='outside'
        ))
        fig.update_layout(
            xaxis_title="Confidence (%)",
            yaxis_title="Digit",
            height=400,
            margin=dict(l=20, r=20, t=20, b=40),
            xaxis=dict(range=[0, 105]),  # headroom for the outside text labels
            paper_bgcolor='rgba(0,0,0,0)',
            plot_bgcolor='rgba(0,0,0,0)',
            font=dict(color='white')
        )
        st.plotly_chart(fig, use_container_width=True)
    else:
        st.info("👆 Draw a digit on the canvas to see the prediction")
# Footer: a horizontal rule, then centered, muted attribution text rendered as HTML.
st.markdown("---")
footer_html = """
<div style="text-align: center; color: #888; font-size: 14px;">
<p>Built with Streamlit | Model trained on sklearn digits dataset (8x8 images)</p>
<p>The model is a Multi-Layer Perceptron (MLP) with ~97% accuracy</p>
</div>
"""
st.markdown(footer_html, unsafe_allow_html=True)