another-demo / src /streamlit_app.py
Vincimus's picture
Auto-predict on canvas draw instead of button click
8604f36
import streamlit as st
import numpy as np
from PIL import Image, ImageOps
import joblib
import os
import plotly.graph_objects as go
from streamlit_drawable_canvas import st_canvas
# Page setup: browser-tab title and icon, single centered layout.
st.set_page_config(
    layout="centered",
    page_icon="✏️",
    page_title="Digit Recognizer",
)

# Counter embedded in the canvas widget key further down. Incrementing it gives
# the canvas a new key, so Streamlit renders a fresh, empty canvas ("clear").
st.session_state.setdefault("canvas_key", 0)
# Model loading (cached by st.cache_resource, so reruns reuse the same object)
@st.cache_resource
def load_model():
    """Return the trained digit classifier stored next to this script.

    The file ``model/digit_classifier.joblib`` is resolved relative to this
    module's directory and deserialised with ``joblib.load``. Note that joblib
    unpickles the file, so it must come from a trusted source.
    """
    app_dir = os.path.dirname(__file__)
    model_file = os.path.join(app_dir, "model", "digit_classifier.joblib")
    return joblib.load(model_file)
def preprocess_canvas(image_data, padding=20):
    """Turn a canvas drawing into one sample in sklearn ``load_digits`` format.

    sklearn's digits are 8x8 grids with intensities 0-16, where 0 is
    background and 16 is full ink. The drawing is cropped to its bounding box
    (plus ``padding``), padded to a centred square so the aspect ratio is
    kept, downsampled to 8x8, and its 0-255 grey levels are rescaled to 0-16.

    Args:
        image_data: Canvas pixels as an array, light ink on a dark
            background. ``st_canvas`` supplies ``(H, W, 4)`` RGBA; ``(H, W, 3)``
            RGB and 2-D greyscale arrays are accepted as well.
        padding: Pixels of margin kept around the drawing's bounding box
            before squaring (clipped at the canvas edges). Must be >= 0.

    Returns:
        ``np.ndarray`` of shape ``(1, 64)``, dtype float64, values in [0, 16].

    Raises:
        ValueError: if ``padding`` is negative.
    """
    if padding < 0:
        raise ValueError(f"padding must be >= 0, got {padding}")

    # Let Pillow infer the mode from the array shape (RGBA for 4 uint8
    # channels). Passing ``mode`` explicitly is deprecated in recent Pillow.
    # It would also reject RGB/greyscale input.
    img = Image.fromarray(np.asarray(image_data).astype(np.uint8)).convert("L")

    bbox = img.getbbox()  # None when every pixel is 0, i.e. a blank canvas
    if bbox is not None:
        left, top, right, bottom = bbox
        img = img.crop((
            max(0, left - padding),
            max(0, top - padding),
            min(img.width, right + padding),
            min(img.height, bottom + padding),
        ))

        # Pad the shorter side with background so the 8x8 resize does not
        # stretch tall digits (e.g. "1") sideways.
        side = max(img.size)
        square = Image.new("L", (side, side), 0)
        square.paste(img, ((side - img.width) // 2, (side - img.height) // 2))
        img = square

    # Antialiased downsample to sklearn's 8x8 grid.
    img = img.resize((8, 8), Image.Resampling.LANCZOS)

    # Canvas ink is white (255) on black (0), the same polarity as sklearn
    # digits (ink = high value), so a linear 0-255 -> 0-16 rescale suffices.
    return (np.asarray(img, dtype=np.float64) / 255.0 * 16.0).reshape(1, -1)
model = load_model()

# Page heading plus short usage instructions shown above the canvas.
INTRO_MARKDOWN = """
Draw a digit (0-9) in the canvas below. The prediction updates automatically!
*Tip: Draw the digit large and centered for best results.*
"""
st.title("✏️ Handwritten Digit Recognizer")
st.markdown(INTRO_MARKDOWN)
# Two equal-width columns: drawing surface on the left, results on the right.
col1, col2 = st.columns(2)

with col1:
    st.subheader("Draw Here")

    # White freehand strokes on a black 280x280 surface.
    canvas_options = dict(
        drawing_mode="freedraw",
        stroke_width=18,
        stroke_color="white",
        fill_color="black",
        background_color="black",
        width=280,
        height=280,
    )
    # The key includes st.session_state.canvas_key. When that counter changes,
    # Streamlit sees a new widget and shows an empty canvas.
    canvas_result = st_canvas(
        key=f"canvas_{st.session_state.canvas_key}",
        **canvas_options,
    )

    clear_clicked = st.button("🗑️ Clear Canvas", use_container_width=True)
    if clear_clicked:
        # New key -> new, blank canvas; rerun immediately so it shows now.
        st.session_state.canvas_key += 1
        st.rerun()
with col2:
    st.subheader("Prediction")

    # The canvas background is black, so any non-zero RGB value means at
    # least one stroke has been drawn. .any() avoids summing the whole array.
    has_drawing = (
        canvas_result.image_data is not None
        and canvas_result.image_data[:, :, :3].any()
    )

    if has_drawing:
        # 1x64 feature row in sklearn-digits format.
        img_flat = preprocess_canvas(canvas_result.image_data)

        # Single inference pass. predict_proba columns follow model.classes_,
        # so map the arg-max column back to its label instead of indexing the
        # probabilities by the label itself (only valid if classes_ == 0..9).
        # This also keeps the headline digit, its confidence and the
        # highlighted bar consistent with one another.
        probabilities = model.predict_proba(img_flat)[0]
        classes = getattr(model, "classes_", np.arange(len(probabilities)))
        best = int(np.argmax(probabilities))
        prediction = classes[best]
        confidence = probabilities[best]

        # Large headline with the predicted digit and its confidence.
        st.markdown(
            "\n"
            '<div style="text-align: center; padding: 20px; background-color: #1e1e1e; border-radius: 10px; margin-bottom: 20px;">\n'
            f'<h1 style="font-size: 72px; margin: 0; color: #4CAF50;">{prediction}</h1>\n'
            f'<p style="font-size: 18px; color: #888;">Confidence: {confidence*100:.1f}%</p>\n'
            "</div>\n",
            unsafe_allow_html=True,
        )

        # Per-class confidence as a horizontal bar chart; winner in green.
        st.subheader("Confidence Scores")
        labels = [str(c) for c in classes]
        fig = go.Figure(go.Bar(
            x=probabilities * 100,
            y=labels,
            orientation='h',
            marker_color=['#4CAF50' if i == best else '#2196F3' for i in range(len(labels))],
            text=[f'{p*100:.1f}%' for p in probabilities],
            textposition='outside',
        ))
        fig.update_layout(
            xaxis_title="Confidence (%)",
            yaxis_title="Digit",
            height=400,
            margin=dict(l=20, r=20, t=20, b=40),
            xaxis=dict(range=[0, 105]),  # headroom for the outside labels
            paper_bgcolor='rgba(0,0,0,0)',
            plot_bgcolor='rgba(0,0,0,0)',
            font=dict(color='white'),
        )
        st.plotly_chart(fig, use_container_width=True)
    else:
        st.info("👆 Draw a digit on the canvas to see the prediction")
# Footer: horizontal rule followed by a centred, muted credits block.
FOOTER_HTML = """
<div style="text-align: center; color: #888; font-size: 14px;">
<p>Built with Streamlit | Model trained on sklearn digits dataset (8x8 images)</p>
<p>The model is a Multi-Layer Perceptron (MLP) with ~97% accuracy</p>
</div>
"""
st.markdown("---")
st.markdown(FOOTER_HTML, unsafe_allow_html=True)