Spaces:
Sleeping
Sleeping
File size: 5,042 Bytes
dee4be0 cd8e368 92007c5 cd8e368 92007c5 cd8e368 92007c5 cd8e368 8604f36 cd8e368 92007c5 cd8e368 92007c5 cd8e368 92007c5 cd8e368 8604f36 cd8e368 8604f36 cd8e368 8604f36 dee4be0 cd8e368 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 | import streamlit as st
import numpy as np
from PIL import Image, ImageOps
import joblib
import os
import plotly.graph_objects as go
from streamlit_drawable_canvas import st_canvas
# Browser-tab title and icon, and a centered page layout for the app.
st.set_page_config(page_title="Digit Recognizer", page_icon="✏️", layout="centered")

# Counter that is baked into the drawing canvas's widget key; the
# "Clear Canvas" button increments it to obtain a brand-new canvas.
if "canvas_key" not in st.session_state:
    st.session_state["canvas_key"] = 0
@st.cache_resource
def load_model():
    """Load the trained digit classifier bundled with the app.

    The file ``model/digit_classifier.joblib`` is resolved relative to the
    directory containing this script, so the working directory does not
    matter. Because of ``st.cache_resource`` the model is deserialized once
    and the same object is reused on subsequent reruns.
    """
    app_dir = os.path.dirname(__file__)
    return joblib.load(os.path.join(app_dir, "model", "digit_classifier.joblib"))
def preprocess_canvas(image_data, *, padding=20, size=8):
    """Convert a canvas drawing into one feature row for the digit classifier.

    The drawing is reduced to the layout of the scikit-learn ``digits``
    dataset: a ``size`` x ``size`` grayscale image (8x8 by default) whose
    intensities are scaled to the 0-16 range, flattened into a single row.

    Args:
        image_data: Pixel array from the drawing canvas. An RGBA array of
            shape (H, W, 4) is what the canvas produces; RGB (H, W, 3) and
            grayscale (H, W) arrays are accepted as well. Values are cast to
            ``uint8``.
        padding: Margin, in canvas pixels, kept around the bounding box of
            the strokes before the crop is squared and resized.
        size: Side length of the output image (the model expects 8).

    Returns:
        A ``float64`` array of shape ``(1, size * size)`` suitable for
        ``model.predict`` / ``model.predict_proba``.
    """
    # Let Pillow infer the mode from the array's shape/dtype: passing
    # ``mode='RGBA'`` to ``Image.fromarray`` was deprecated in Pillow 11.3,
    # and inference also copes with RGB or grayscale input. Alpha is dropped
    # by the conversion to grayscale, exactly as before.
    img = Image.fromarray(np.asarray(image_data).astype(np.uint8)).convert("L")

    # Strokes are white on a black background, so the bounding box of the
    # non-zero pixels is the bounding box of the digit (None if blank).
    bbox = img.getbbox()
    if bbox:
        left, top, right, bottom = bbox
        img = img.crop((
            max(0, left - padding),
            max(0, top - padding),
            min(img.width, right + padding),
            min(img.height, bottom + padding),
        ))

        # Pad the shorter side with black so the digit keeps its aspect
        # ratio when shrunk, and center the crop in the square.
        side = max(img.size)
        square = Image.new("L", (side, side), 0)
        square.paste(img, ((side - img.width) // 2, (side - img.height) // 2))
        img = square

    img = img.resize((size, size), Image.Resampling.LANCZOS)

    # In sklearn digits larger values mean more ink (0 = background, 16 =
    # full ink). Here bright pixels are ink, so map 0-255 linearly to 0-16.
    pixels = np.asarray(img, dtype=np.float64) / 255.0 * 16.0
    return pixels.reshape(1, -1)
# Obtain the classifier up front; load_model is wrapped in st.cache_resource,
# so this only reads the file the first time.
model = load_model()

st.title("✏️ Handwritten Digit Recognizer")

intro_md = """
Draw a digit (0-9) in the canvas below. The prediction updates automatically!
*Tip: Draw the digit large and centered for best results.*
"""
st.markdown(intro_md)
# Side-by-side layout: drawing canvas on the left, prediction on the right.
col1, col2 = st.columns([1, 1])

with col1:
    st.subheader("Draw Here")

    # The widget key embeds the session-state counter; once the counter
    # changes, the canvas is created under a new key and starts out empty.
    widget_key = f"canvas_{st.session_state.canvas_key}"
    canvas_result = st_canvas(
        drawing_mode="freedraw",
        width=280,
        height=280,
        background_color="black",
        fill_color="black",
        stroke_color="white",
        stroke_width=18,
        key=widget_key,
    )

    clear_clicked = st.button("🗑️ Clear Canvas", use_container_width=True)
    if clear_clicked:
        st.session_state.canvas_key += 1
        st.rerun()
with col2:
    st.subheader("Prediction")

    # The canvas is drawn white-on-black, so a blank canvas has all-zero RGB
    # channels; any non-zero value means something has been drawn.
    has_drawing = (
        canvas_result.image_data is not None
        and canvas_result.image_data[:, :, :3].any()
    )

    if has_drawing:
        # Reduce the drawing to the sklearn digits layout (8x8, values 0-16).
        img_flat = preprocess_canvas(canvas_result.image_data)

        # Predicted label and per-class probabilities for the single row.
        prediction = model.predict(img_flat)[0]
        probabilities = model.predict_proba(img_flat)[0]

        # Column j of predict_proba belongs to the label model.classes_[j].
        # Locate the predicted label's column there instead of using the label
        # itself as an index, which is only right when labels are exactly
        # 0..n-1 in order. Without classes_, fall back to positional labels.
        classes = np.asarray(
            getattr(model, "classes_", np.arange(len(probabilities)))
        )
        pred_idx = int(np.flatnonzero(classes == prediction)[0])
        confidence = probabilities[pred_idx]

        # Display large prediction
        st.markdown(f"""
<div style="text-align: center; padding: 20px; background-color: #1e1e1e; border-radius: 10px; margin-bottom: 20px;">
<h1 style="font-size: 72px; margin: 0; color: #4CAF50;">{prediction}</h1>
<p style="font-size: 18px; color: #888;">Confidence: {confidence*100:.1f}%</p>
</div>
""", unsafe_allow_html=True)

        # Probability chart: one horizontal bar per class, with the predicted
        # class highlighted in green.
        st.subheader("Confidence Scores")
        fig = go.Figure(go.Bar(
            x=probabilities * 100,
            y=[str(label) for label in classes],
            orientation='h',
            marker_color=[
                '#4CAF50' if i == pred_idx else '#2196F3'
                for i in range(len(classes))
            ],
            text=[f'{p*100:.1f}%' for p in probabilities],
            textposition='outside'
        ))
        fig.update_layout(
            xaxis_title="Confidence (%)",
            yaxis_title="Digit",
            height=400,
            margin=dict(l=20, r=20, t=20, b=40),
            xaxis=dict(range=[0, 105]),
            paper_bgcolor='rgba(0,0,0,0)',
            plot_bgcolor='rgba(0,0,0,0)',
            font=dict(color='white')
        )
        st.plotly_chart(fig, use_container_width=True)
    else:
        st.info("👆 Draw a digit on the canvas to see the prediction")
# Footer: a horizontal rule followed by a centered, muted credits note.
st.markdown("---")

footer_html = """
<div style="text-align: center; color: #888; font-size: 14px;">
<p>Built with Streamlit | Model trained on sklearn digits dataset (8x8 images)</p>
<p>The model is a Multi-Layer Perceptron (MLP) with ~97% accuracy</p>
</div>
"""
st.markdown(footer_html, unsafe_allow_html=True)
|