Vincimus Claude Opus 4.5 committed on
Commit
cd8e368
·
1 Parent(s): 36fc4af

Add handwritten digit recognizer with MLP classifier

Browse files

- Drawable canvas for user input (streamlit-drawable-canvas)
- MLP model trained on sklearn digits dataset (97.78% accuracy)
- Real-time prediction with confidence visualization
- Plotly bar chart showing probabilities for all 10 digits

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

Dockerfile CHANGED
@@ -6,6 +6,8 @@ RUN apt-get update && apt-get install -y \
6
  build-essential \
7
  curl \
8
  git \
 
 
9
  && rm -rf /var/lib/apt/lists/*
10
 
11
  COPY requirements.txt ./
@@ -17,4 +19,4 @@ EXPOSE 8501
17
 
18
  HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health
19
 
20
- ENTRYPOINT ["streamlit", "run", "src/streamlit_app.py", "--server.port=8501", "--server.address=0.0.0.0"]
 
6
  build-essential \
7
  curl \
8
  git \
9
+ libgl1-mesa-glx \
10
+ libglib2.0-0 \
11
  && rm -rf /var/lib/apt/lists/*
12
 
13
  COPY requirements.txt ./
 
19
 
20
  HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health
21
 
22
+ ENTRYPOINT ["streamlit", "run", "src/streamlit_app.py", "--server.port=8501", "--server.address=0.0.0.0"]
README.md CHANGED
@@ -1,19 +1,39 @@
1
  ---
2
- title: Another Demo
3
- emoji: 🚀
4
- colorFrom: red
5
- colorTo: red
6
  sdk: docker
7
  app_port: 8501
8
  tags:
9
  - streamlit
 
 
10
  pinned: false
11
- short_description: Streamlit template space
12
  ---
13
 
14
- # Welcome to Streamlit!
15
 
16
- Edit `/src/streamlit_app.py` to customize this app to your heart's desire. :heart:
17
 
18
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
19
- forums](https://discuss.streamlit.io).
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: Digit Recognizer
3
+ emoji: ✏️
4
+ colorFrom: blue
5
+ colorTo: purple
6
  sdk: docker
7
  app_port: 8501
8
  tags:
9
  - streamlit
10
+ - machine-learning
11
+ - digit-recognition
12
  pinned: false
13
+ short_description: Draw a digit and watch AI recognize it!
14
  ---
15
 
16
+ # ✏️ Handwritten Digit Recognizer
17
 
18
+ An interactive machine learning demo where you can draw digits and watch an AI model recognize them in real time!
19
 
20
+ ## Features
21
+
22
+ - **Drawing Canvas**: Draw digits (0-9) with your mouse or touchscreen
23
+ - **Real-time Prediction**: See the model's prediction instantly
24
+ - **Confidence Visualization**: View probability scores for all 10 digits
25
+ - **Easy Reset**: Clear the canvas and try again
26
+
27
+ ## How It Works
28
+
29
+ 1. Draw a digit (0-9) on the canvas
30
+ 2. Click "Predict" to see the model's prediction
31
+ 3. View the confidence chart showing probabilities for each digit
32
+ 4. Click "Clear Canvas" to draw another digit
33
+
34
+ ## Technical Details
35
+
36
+ - **Model**: MLP Neural Network trained on sklearn's digits dataset
37
+ - **Input**: 8x8 grayscale images (scaled from canvas)
38
+ - **Accuracy**: ~97% on test set
39
+ - **Framework**: Streamlit with streamlit-drawable-canvas
requirements.txt CHANGED
@@ -1,3 +1,8 @@
1
- altair
2
  pandas
3
- streamlit
 
 
 
 
 
 
1
+ streamlit
2
  pandas
3
+ numpy
4
+ scikit-learn
5
+ joblib
6
+ plotly
7
+ streamlit-drawable-canvas
8
+ Pillow
src/model/digit_classifier.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2182dbf208c3ea6c5e2384366e5496bea047e4ad286a2e300afb14f431c52376
3
+ size 560992
src/streamlit_app.py CHANGED
@@ -1,40 +1,136 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
4
  import streamlit as st
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
 
 
 
1
  import streamlit as st
2
+ import numpy as np
3
+ from PIL import Image
4
+ import joblib
5
+ import os
6
+ import plotly.graph_objects as go
7
+ from streamlit_drawable_canvas import st_canvas
8
+
9
+ # Page config
10
+ st.set_page_config(
11
+ page_title="Digit Recognizer",
12
+ page_icon="✏️",
13
+ layout="centered"
14
+ )
15
+
16
+ # Load model
17
+ @st.cache_resource
18
+ def load_model():
19
+ model_path = os.path.join(os.path.dirname(__file__), "model", "digit_classifier.joblib")
20
+ return joblib.load(model_path)
21
+
22
+ model = load_model()
23
+
24
+ # Title and instructions
25
+ st.title("✏️ Handwritten Digit Recognizer")
26
+ st.markdown("""
27
+ Draw a digit (0-9) in the canvas below and click **Predict** to see what the AI thinks it is!
28
+
29
+ *Tip: Draw the digit large and centered for best results.*
30
+ """)
31
+
32
+ # Create two columns for layout
33
+ col1, col2 = st.columns([1, 1])
34
+
35
+ with col1:
36
+ st.subheader("Draw Here")
37
+
38
+ # Drawing canvas
39
+ canvas_result = st_canvas(
40
+ fill_color="black",
41
+ stroke_width=20,
42
+ stroke_color="white",
43
+ background_color="black",
44
+ height=280,
45
+ width=280,
46
+ drawing_mode="freedraw",
47
+ key="canvas",
48
+ )
49
+
50
+ # Buttons
51
+ btn_col1, btn_col2 = st.columns(2)
52
+ with btn_col1:
53
+ predict_btn = st.button("🔮 Predict", type="primary", use_container_width=True)
54
+ with btn_col2:
55
+ if st.button("🗑️ Clear Canvas", use_container_width=True):
56
+ st.rerun()
57
+
58
+ with col2:
59
+ st.subheader("Prediction")
60
+
61
+ # Placeholder for results
62
+ result_container = st.container()
63
+
64
+ if predict_btn and canvas_result.image_data is not None:
65
+ # Process the canvas image
66
+ img_array = canvas_result.image_data
67
+
68
+ # Check if canvas has any drawing (not all black)
69
+ if np.sum(img_array[:, :, :3]) > 0:
70
+ # Convert to PIL Image and process
71
+ img = Image.fromarray(img_array.astype('uint8'), 'RGBA')
72
+ img = img.convert('L') # Convert to grayscale
73
+
74
+ # Resize to 8x8 (sklearn digits format)
75
+ img = img.resize((8, 8), Image.Resampling.LANCZOS)
76
+
77
+ # Convert to numpy array and normalize to 0-16 range (sklearn format)
78
+ img_array = np.array(img)
79
+ img_array = (img_array / 255.0) * 16
80
+
81
+ # Flatten for prediction
82
+ img_flat = img_array.flatten().reshape(1, -1)
83
+
84
+ # Get prediction and probabilities
85
+ prediction = model.predict(img_flat)[0]
86
+ probabilities = model.predict_proba(img_flat)[0]
87
+
88
+ with result_container:
89
+ # Display large prediction
90
+ st.markdown(f"""
91
+ <div style="text-align: center; padding: 20px; background-color: #1e1e1e; border-radius: 10px; margin-bottom: 20px;">
92
+ <h1 style="font-size: 72px; margin: 0; color: #4CAF50;">{prediction}</h1>
93
+ <p style="font-size: 18px; color: #888;">Confidence: {probabilities[prediction]*100:.1f}%</p>
94
+ </div>
95
+ """, unsafe_allow_html=True)
96
+
97
+ # Probability chart
98
+ st.subheader("Confidence Scores")
99
+
100
+ # Create horizontal bar chart with Plotly
101
+ fig = go.Figure(go.Bar(
102
+ x=probabilities * 100,
103
+ y=[str(i) for i in range(10)],
104
+ orientation='h',
105
+ marker_color=['#4CAF50' if i == prediction else '#2196F3' for i in range(10)],
106
+ text=[f'{p*100:.1f}%' for p in probabilities],
107
+ textposition='outside'
108
+ ))
109
+
110
+ fig.update_layout(
111
+ xaxis_title="Confidence (%)",
112
+ yaxis_title="Digit",
113
+ height=400,
114
+ margin=dict(l=20, r=20, t=20, b=40),
115
+ xaxis=dict(range=[0, 105]),
116
+ paper_bgcolor='rgba(0,0,0,0)',
117
+ plot_bgcolor='rgba(0,0,0,0)',
118
+ font=dict(color='white')
119
+ )
120
+
121
+ st.plotly_chart(fig, use_container_width=True)
122
+ else:
123
+ with result_container:
124
+ st.info("👆 Draw a digit on the canvas first!")
125
+ else:
126
+ with result_container:
127
+ st.info("👆 Draw a digit on the canvas and click **Predict**")
128
 
129
+ # Footer
130
+ st.markdown("---")
131
+ st.markdown("""
132
+ <div style="text-align: center; color: #888; font-size: 14px;">
133
+ <p>Built with Streamlit | Model trained on sklearn digits dataset (8x8 images)</p>
134
+ <p>The model is a Multi-Layer Perceptron (MLP) with ~97% accuracy</p>
135
+ </div>
136
+ """, unsafe_allow_html=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
train_model.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Train a digit classifier on sklearn's digits dataset.
3
+ Run this script locally to generate the model file.
4
+
5
+ Usage:
6
+ python train_model.py
7
+ """
8
+
9
+ import os
10
+ from sklearn.datasets import load_digits
11
+ from sklearn.model_selection import train_test_split
12
+ from sklearn.neural_network import MLPClassifier
13
+ from sklearn.metrics import accuracy_score, classification_report
14
+ import joblib
15
+
16
+
17
+ def train_digit_classifier():
18
+ """Train an MLP classifier on the sklearn digits dataset (8x8 images)."""
19
+
20
+ print("Loading digits dataset...")
21
+ digits = load_digits()
22
+ X, y = digits.data, digits.target
23
+
24
+ print(f"Dataset shape: {X.shape}")
25
+ print(f"Number of classes: {len(set(y))}")
26
+ print(f"Image size: 8x8 (64 features)")
27
+
28
+ # Split data
29
+ X_train, X_test, y_train, y_test = train_test_split(
30
+ X, y, test_size=0.2, random_state=42, stratify=y
31
+ )
32
+
33
+ print(f"\nTraining samples: {len(X_train)}")
34
+ print(f"Test samples: {len(X_test)}")
35
+
36
+ # Train MLP classifier
37
+ print("\nTraining MLP classifier...")
38
+ model = MLPClassifier(
39
+ hidden_layer_sizes=(128, 64),
40
+ activation='relu',
41
+ max_iter=500,
42
+ random_state=42,
43
+ verbose=True
44
+ )
45
+
46
+ model.fit(X_train, y_train)
47
+
48
+ # Evaluate
49
+ y_pred = model.predict(X_test)
50
+ accuracy = accuracy_score(y_test, y_pred)
51
+
52
+ print(f"\n{'='*50}")
53
+ print(f"Test Accuracy: {accuracy:.4f} ({accuracy*100:.2f}%)")
54
+ print(f"{'='*50}")
55
+ print("\nClassification Report:")
56
+ print(classification_report(y_test, y_pred))
57
+
58
+ # Save model
59
+ model_dir = os.path.join(os.path.dirname(__file__), "src", "model")
60
+ os.makedirs(model_dir, exist_ok=True)
61
+ model_path = os.path.join(model_dir, "digit_classifier.joblib")
62
+
63
+ joblib.dump(model, model_path)
64
+ print(f"\nModel saved to: {model_path}")
65
+
66
+ # Check file size
67
+ file_size = os.path.getsize(model_path) / 1024
68
+ print(f"Model file size: {file_size:.2f} KB")
69
+
70
+ return model
71
+
72
+
73
+ if __name__ == "__main__":
74
+ train_digit_classifier()