File size: 5,042 Bytes
dee4be0
cd8e368
92007c5
cd8e368
 
 
 
 
 
 
 
 
 
 
 
92007c5
 
 
 
cd8e368
 
 
 
 
 
92007c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cd8e368
 
 
 
 
8604f36
cd8e368
 
 
 
 
 
 
 
 
 
92007c5
cd8e368
 
92007c5
cd8e368
 
 
 
 
92007c5
cd8e368
 
8604f36
 
 
 
cd8e368
 
 
 
8604f36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cd8e368
8604f36
dee4be0
cd8e368
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
import streamlit as st
import numpy as np
from PIL import Image, ImageOps
import joblib
import os
import plotly.graph_objects as go
from streamlit_drawable_canvas import st_canvas

# Page config: browser-tab title/icon and a single centered layout.
_PAGE_CONFIG = dict(
    page_title="Digit Recognizer",
    page_icon="✏️",
    layout="centered",
)
st.set_page_config(**_PAGE_CONFIG)

# The canvas widget key embeds this counter. The Clear button bumps it so the
# canvas is rebuilt under a new key, which empties it.
if "canvas_key" not in st.session_state:
    st.session_state["canvas_key"] = 0

# Load model
@st.cache_resource
def load_model():
    """Return the trained digit classifier stored next to this script.

    The file is read from ``model/digit_classifier.joblib`` relative to this
    module's directory. ``st.cache_resource`` caches the returned object, so
    reruns of the script reuse it instead of reloading it from disk.
    """
    app_dir = os.path.dirname(__file__)
    model_file = os.path.join(app_dir, "model", "digit_classifier.joblib")
    # joblib.load unpickles the file; only ever point this at a trusted artifact.
    return joblib.load(model_file)

def preprocess_canvas(image_data, *, padding=20, size=8, max_value=16.0):
    """Convert a canvas image array into one feature row for the classifier.

    Steps:

    1. Convert to grayscale.
    2. Crop to the drawing's bounding box plus ``padding`` pixels on each side,
       clipped to the canvas edges.
    3. Centre the crop on a black square.
    4. Downsample to ``size`` x ``size`` with Lanczos resampling.
    5. Rescale intensities linearly from 0-255 to 0-``max_value``.

    The defaults reproduce the sklearn digits format: 8x8 images, values 0-16.

    Args:
        image_data: Array of shape (height, width, channels), as returned by
            ``st_canvas(...).image_data``. Bright pixels are treated as ink,
            matching this app's white-on-black canvas.
        padding: Margin, in canvas pixels, kept around the drawing's bounding
            box before resizing.
        size: Side length of the square output image.
        max_value: Value that a fully white (255) pixel maps to.

    Returns:
        A float64 array of shape ``(1, size * size)``, ready for
        ``model.predict`` / ``model.predict_proba``.
    """
    # Let Pillow infer the mode from the array shape (RGBA for the canvas).
    # Passing ``mode`` explicitly is deprecated in recent Pillow releases.
    # Converting RGBA/RGB to "L" drops alpha and keeps the luminance.
    img = Image.fromarray(np.asarray(image_data).astype(np.uint8)).convert("L")

    # getbbox() returns None for an all-black image; the uncropped canvas is
    # then used as-is.
    bbox = img.getbbox()
    if bbox:
        left, top, right, bottom = bbox
        img = img.crop((
            max(0, left - padding),
            max(0, top - padding),
            min(img.width, right + padding),
            min(img.height, bottom + padding),
        ))

    # Centre the (possibly non-square) crop on a black square so the resize
    # below does not distort the digit's aspect ratio.
    side = max(img.size)
    square = Image.new("L", (side, side), 0)
    square.paste(img, ((side - img.width) // 2, (side - img.height) // 2))

    small = square.resize((size, size), Image.Resampling.LANCZOS)

    # sklearn digits store ink as high values (up to 16) on a 0 background.
    # The canvas already draws white (255) ink on black (0), so a linear
    # rescale is enough; no inversion is needed.
    pixels = np.asarray(small, dtype=np.float64) / 255.0 * max_value
    return pixels.reshape(1, -1)

# Cached by st.cache_resource, so reruns reuse the already-loaded classifier.
model = load_model()

# Title and instructions
_INSTRUCTIONS = """
Draw a digit (0-9) in the canvas below. The prediction updates automatically!

*Tip: Draw the digit large and centered for best results.*
"""
st.title("✏️ Handwritten Digit Recognizer")
st.markdown(_INSTRUCTIONS)

# Two equal-width columns: drawing surface on the left, results on the right.
col1, col2 = st.columns(2)

with col1:
    st.subheader("Draw Here")

    # White freehand strokes on a black 280x280 surface. preprocess_canvas
    # assumes this white-on-black convention when scaling intensities.
    canvas_settings = dict(
        fill_color="black",
        stroke_width=18,
        stroke_color="white",
        background_color="black",
        height=280,
        width=280,
        drawing_mode="freedraw",
    )
    canvas_result = st_canvas(
        **canvas_settings,
        key=f"canvas_{st.session_state.canvas_key}",
    )

    # Bumping the counter changes the canvas key on the next run, so the
    # canvas is rebuilt empty.
    clear_clicked = st.button("🗑️ Clear Canvas", use_container_width=True)
    if clear_clicked:
        st.session_state.canvas_key += 1
        st.rerun()

with col2:
    st.subheader("Prediction")

    # The canvas is considered empty until some RGB channel is non-zero
    # (strokes are white on a black background).
    has_drawing = (
        canvas_result.image_data is not None
        and np.sum(canvas_result.image_data[:, :, :3]) > 0
    )

    if has_drawing:
        # Preprocess the image to match sklearn digits format
        img_flat = preprocess_canvas(canvas_result.image_data)

        # Derive the prediction from the probability vector instead of a
        # separate model.predict call. Column j of predict_proba belongs to
        # model.classes_[j], so indexing the probabilities by the predicted
        # *label* is only correct when the labels are exactly 0..9 in order.
        # Using the argmax position keeps the headline digit, its confidence
        # and the highlighted bar in agreement.
        probabilities = model.predict_proba(img_flat)[0]
        best = int(np.argmax(probabilities))
        # NOTE(review): assumes a scikit-learn-style classifier; falls back to
        # positional labels if the model does not expose classes_.
        labels = [
            str(c) for c in getattr(model, "classes_", range(len(probabilities)))
        ]
        prediction = labels[best]
        confidence = probabilities[best]

        # Display large prediction
        st.markdown(f"""
        <div style="text-align: center; padding: 20px; background-color: #1e1e1e; border-radius: 10px; margin-bottom: 20px;">
            <h1 style="font-size: 72px; margin: 0; color: #4CAF50;">{prediction}</h1>
            <p style="font-size: 18px; color: #888;">Confidence: {confidence*100:.1f}%</p>
        </div>
        """, unsafe_allow_html=True)

        # Probability chart
        st.subheader("Confidence Scores")

        # Horizontal bar per class; the predicted class is highlighted green.
        fig = go.Figure(go.Bar(
            x=probabilities * 100,
            y=labels,
            orientation='h',
            marker_color=[
                '#4CAF50' if i == best else '#2196F3'
                for i in range(len(probabilities))
            ],
            text=[f'{p*100:.1f}%' for p in probabilities],
            textposition='outside'
        ))

        fig.update_layout(
            xaxis_title="Confidence (%)",
            yaxis_title="Digit",
            height=400,
            margin=dict(l=20, r=20, t=20, b=40),
            # A little headroom past 100% so the outside text labels fit.
            xaxis=dict(range=[0, 105]),
            paper_bgcolor='rgba(0,0,0,0)',
            plot_bgcolor='rgba(0,0,0,0)',
            font=dict(color='white')
        )

        st.plotly_chart(fig, use_container_width=True)
    else:
        st.info("👆 Draw a digit on the canvas to see the prediction")

# Footer: a horizontal rule followed by centered, muted credits.
_FOOTER_HTML = """
<div style="text-align: center; color: #888; font-size: 14px;">
    <p>Built with Streamlit | Model trained on sklearn digits dataset (8x8 images)</p>
    <p>The model is a Multi-Layer Perceptron (MLP) with ~97% accuracy</p>
</div>
"""
st.markdown("---")
st.markdown(_FOOTER_HTML, unsafe_allow_html=True)