---
license: mit
language:
- en
library_name: sklearn
tags:
- mnist
- image-classification
- digits
- handwritten
- computer-vision
- logistic-regression
- machine-learning
datasets:
- ylecun/mnist
metrics:
- accuracy
- f1
- precision
- recall
pipeline_tag: image-classification
---

# MNIST Handwritten Digit Classifier

A classical machine learning approach to handwritten digit recognition using Logistic Regression on the MNIST dataset.

## Model Description

This model classifies 28x28 grayscale images of handwritten digits (0-9) using a simple yet effective Logistic Regression classifier. The project serves as an introduction to image classification and the MNIST dataset.

### Intended Uses

- **Educational**: Learning image classification fundamentals
- **Benchmarking**: Baseline for comparing more complex models
- **Research**: Exploring classical ML on image data
- **Prototyping**: Quick digit recognition experiments

## Training Data

**Dataset**: [ylecun/mnist](https://huggingface.co/datasets/ylecun/mnist)

| Split | Images |
|-------|--------|
| Train | 60,000 |
| Test | 10,000 |
| **Total** | **70,000** |

### Data Characteristics

| Property | Value |
|----------|-------|
| Image Size | 28 x 28 pixels |
| Channels | 1 (Grayscale) |
| Classes | 10 (digits 0-9) |
| Pixel Range | 0-255 (raw), 0-1 (normalized) |
| Format | PNG/NumPy arrays |

### Class Distribution

The dataset is relatively balanced across all 10 digit classes.

## Model Architecture

### Preprocessing Pipeline

```
Raw Image (28x28, uint8)
        ↓
Normalize to [0, 1] (divide by 255)
        ↓
Flatten to vector (784 dimensions)
        ↓
Logistic Regression Classifier
        ↓
Softmax Probabilities (10 classes)
```
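
The steps above can also be expressed as a single scikit-learn `Pipeline`, which keeps the normalization, flattening, and classifier together in one object. A minimal sketch (the smoke-test data here is synthetic, not MNIST):

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import FunctionTransformer

# Normalize uint8 pixels to [0, 1] and flatten 28x28 images to 784-vectors
preprocess = FunctionTransformer(
    lambda X: X.astype("float32").reshape(len(X), -1) / 255.0
)

pipeline = Pipeline([
    ("preprocess", preprocess),
    ("clf", LogisticRegression(max_iter=100, solver="lbfgs", n_jobs=-1)),
])

# Smoke test on tiny random "images" with two fake classes
rng = np.random.default_rng(0)
X_demo = rng.integers(0, 256, size=(20, 28, 28), dtype=np.uint8)
y_demo = np.arange(20) % 2
pipeline.fit(X_demo, y_demo)
print(pipeline.predict(X_demo[:3]).shape)  # (3,)
```

On real data, `pipeline.fit(X_train, y_train)` would accept the raw 28x28 image stack directly, since the transformer handles flattening.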

### Classifier Configuration

```python
LogisticRegression(
    max_iter=100,
    solver='lbfgs',
    multi_class='multinomial',
    n_jobs=-1
)
```

| Parameter | Value | Description |
|-----------|-------|-------------|
| max_iter | 100 | Maximum iterations for convergence |
| solver | lbfgs | L-BFGS optimization algorithm |
| multi_class | multinomial | True multiclass (not OvR) |
| n_jobs | -1 | Use all CPU cores |

Note: the `multi_class` argument is deprecated in recent scikit-learn releases; the `lbfgs` solver uses the multinomial loss by default, so the argument can simply be omitted there.

## Performance

### Test Set Results

| Metric | Score |
|--------|-------|
| Accuracy | ~92% |
| Macro F1 | ~92% |
| Macro Precision | ~92% |
| Macro Recall | ~92% |

### Per-Class Performance

| Digit | Precision | Recall | F1-Score |
|-------|-----------|--------|----------|
| 0 | ~0.95 | ~0.97 | ~0.96 |
| 1 | ~0.95 | ~0.97 | ~0.96 |
| 2 | ~0.91 | ~0.89 | ~0.90 |
| 3 | ~0.89 | ~0.90 | ~0.90 |
| 4 | ~0.92 | ~0.92 | ~0.92 |
| 5 | ~0.88 | ~0.87 | ~0.87 |
| 6 | ~0.94 | ~0.95 | ~0.94 |
| 7 | ~0.93 | ~0.91 | ~0.92 |
| 8 | ~0.88 | ~0.87 | ~0.88 |
| 9 | ~0.89 | ~0.90 | ~0.90 |

*Note: Performance varies slightly between runs.*
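
Metrics like those tabulated above can be reproduced with `sklearn.metrics.classification_report`. A sketch with stand-in labels (in practice, `y_test` and `y_pred` would be the real test labels and `model.predict(X_test_flat)`):

```python
import numpy as np
from sklearn.metrics import accuracy_score, classification_report

# Stand-in labels for illustration only
y_test = np.array([0, 1, 2, 2, 1, 0, 1, 2])
y_pred = np.array([0, 1, 2, 1, 1, 0, 1, 2])

print(f"Accuracy: {accuracy_score(y_test, y_pred):.2%}")   # Accuracy: 87.50%
print(classification_report(y_test, y_pred, digits=3))      # per-class P/R/F1
```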

### Common Confusion Pairs

- 4 ↔ 9 (similar upper loops)
- 3 ↔ 8 (curved shapes)
- 5 ↔ 3 (similar strokes)
- 7 ↔ 1 (vertical strokes)
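
Pairs like these can be read off programmatically by ranking the largest off-diagonal entries of a confusion matrix. A minimal sketch on a made-up 4-class matrix:

```python
import numpy as np

# Toy confusion matrix (rows: true label, cols: predicted label)
cm = np.array([
    [50,  0,  1,  2],
    [ 0, 48,  5,  0],
    [ 1,  6, 44,  1],
    [ 3,  0,  1, 49],
])

# Zero the diagonal, then sort the remaining cells by count
off_diag = cm.copy()
np.fill_diagonal(off_diag, 0)
pairs = sorted(
    ((off_diag[i, j], i, j) for i in range(len(cm)) for j in range(len(cm))),
    reverse=True,
)
for count, true_label, pred_label in pairs[:3]:
    print(f"{true_label} -> {pred_label}: {count} misclassifications")
```

On MNIST, running this on the 10x10 matrix from the Visualization section below surfaces the same pairs listed above.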

## Usage

### Installation

```bash
pip install scikit-learn pandas numpy matplotlib seaborn pillow huggingface_hub
```

(`huggingface_hub` provides the `hf://` filesystem used by `pd.read_parquet` below.)

### Load and Preprocess Data

```python
import pandas as pd
import numpy as np
from io import BytesIO
from PIL import Image

# Load from Hugging Face
df_train = pd.read_parquet("hf://datasets/ylecun/mnist/mnist/train-00000-of-00001.parquet")
df_test = pd.read_parquet("hf://datasets/ylecun/mnist/mnist/test-00000-of-00001.parquet")

def extract_image(row):
    """Extract the 'image' entry of a row as a 28x28 numpy array."""
    img_data = row['image']
    if isinstance(img_data, dict) and 'bytes' in img_data:
        # Parquet stores the image as encoded bytes; decode with PIL
        return np.array(Image.open(BytesIO(img_data['bytes'])))
    # Already a PIL Image or array-like
    return np.array(img_data)

# Prepare data
X_train = np.array([extract_image(row) for _, row in df_train.iterrows()])
y_train = df_train['label'].values

# Normalize and flatten
X_train_flat = X_train.astype('float32').reshape(-1, 784) / 255.0
```

### Train Model

```python
from sklearn.linear_model import LogisticRegression

model = LogisticRegression(
    max_iter=100,
    solver='lbfgs',
    multi_class='multinomial',
    n_jobs=-1
)
model.fit(X_train_flat, y_train)
```
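
The inference snippet below loads `mnist_model.pkl`, so the trained model has to be saved first. A minimal round-trip sketch with `joblib` (a tiny stand-in model keeps the example self-contained; in practice, dump the `model` fitted above):

```python
import joblib
import numpy as np
from sklearn.linear_model import LogisticRegression

# Stand-in model on toy data; replace with the MNIST-fitted `model`
X = np.array([[0.0, 0.1], [0.9, 1.0], [0.1, 0.0], [1.0, 0.8]])
y = np.array([0, 1, 0, 1])
model = LogisticRegression().fit(X, y)

joblib.dump(model, 'mnist_model.pkl')      # save to disk
restored = joblib.load('mnist_model.pkl')  # reload
print((restored.predict(X) == model.predict(X)).all())  # True
```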

### Inference

```python
import joblib
import numpy as np
from PIL import Image

# Load model
model = joblib.load('mnist_model.pkl')

# Predict single image
def predict_digit(image):
    """
    image: 28x28 numpy array or PIL Image
    returns: (predicted digit 0-9, array of 10 class probabilities)
    """
    if isinstance(image, Image.Image):
        image = np.array(image)

    # Preprocess: normalize and flatten, matching training
    image_flat = image.astype('float32').reshape(1, 784) / 255.0

    # Predict
    prediction = model.predict(image_flat)[0]
    probabilities = model.predict_proba(image_flat)[0]

    return prediction, probabilities

# Example (test_image: any 28x28 grayscale array, e.g. one extracted from df_test)
digit, probs = predict_digit(test_image)
print(f"Predicted: {digit} (confidence: {probs[digit]:.2%})")
```

### Visualization

```python
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix

# Confusion matrix on the test set
# (X_test_flat and y_test are prepared from df_test the same way as the training data)
y_pred = model.predict(X_test_flat)
cm = confusion_matrix(y_test, y_pred)

plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=range(10), yticklabels=range(10))
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('Confusion Matrix - MNIST')
plt.show()
```

### Average Digit Visualization

```python
# Compute and plot the mean image for each digit class
fig, axes = plt.subplots(2, 5, figsize=(12, 5))
for digit in range(10):
    ax = axes[digit // 5, digit % 5]
    mask = y_train == digit
    mean_img = X_train[mask].mean(axis=0)
    ax.imshow(mean_img, cmap='hot')
    ax.set_title(f'Digit: {digit}')
    ax.axis('off')
plt.tight_layout()
plt.show()
```

## Limitations

- **Simple Model**: Logistic Regression doesn't capture spatial relationships
- **No Data Augmentation**: Sensitive to rotation, scaling, translation
- **Grayscale Only**: Won't work with color images
- **Fixed Size**: Requires exactly 28x28 input
- **Clean Data**: Struggles with noisy or poorly centered digits
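
To feed images that are not already 28x28 grayscale, they must first be converted and resized. A minimal preprocessing sketch with Pillow (this assumes MNIST-style dark-background, light-digit images; real-world photos typically also need inversion and centering):

```python
import numpy as np
from PIL import Image

def to_mnist_format(img: Image.Image) -> np.ndarray:
    """Convert an arbitrary PIL image to a 28x28 grayscale uint8 array."""
    gray = img.convert('L')                        # collapse to 1 channel
    small = gray.resize((28, 28), Image.BILINEAR)  # force fixed size
    return np.array(small, dtype=np.uint8)

# Example with a synthetic 100x80 RGB image
rgb = Image.new('RGB', (100, 80), color=(30, 30, 30))
arr = to_mnist_format(rgb)
print(arr.shape, arr.dtype)  # (28, 28) uint8
```

The resulting array can then go straight into `predict_digit` above.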

## Comparison with Other Approaches

| Model | MNIST Accuracy |
|-------|----------------|
| **Logistic Regression** | **~92%** |
| Random Forest | ~97% |
| SVM (RBF kernel) | ~98% |
| MLP (2 hidden layers) | ~98% |
| CNN (LeNet-5) | ~99% |
| Modern CNNs | ~99.7% |

## Technical Specifications

### Dependencies

```
scikit-learn>=1.0.0
pandas>=1.3.0
numpy>=1.20.0
matplotlib>=3.4.0
seaborn>=0.11.0
pillow>=8.0.0
huggingface_hub
```

### Hardware Requirements

| Task | Hardware | Time |
|------|----------|------|
| Training | CPU | ~2-5 min |
| Inference | CPU | < 1 ms per image |

Peak memory usage is roughly 500 MB of RAM.

## Files

```
MNIST/
├── README_HF.md             # This model card
├── mnist_exploration.ipynb  # Full exploration notebook
├── mnist_model.pkl          # Trained model (generated)
└── figures/                 # Visualizations (generated)
```

## Citation

```bibtex
@article{lecun1998mnist,
  title={Gradient-based learning applied to document recognition},
  author={LeCun, Yann and Bottou, L{\'e}on and Bengio, Yoshua and Haffner, Patrick},
  journal={Proceedings of the IEEE},
  volume={86},
  number={11},
  pages={2278--2324},
  year={1998}
}

@misc{mnist_hf,
  title={MNIST Dataset},
  author={LeCun, Yann and Cortes, Corinna and Burges, Christopher J.C.},
  howpublished={Hugging Face Datasets},
  url={https://huggingface.co/datasets/ylecun/mnist}
}
```

## License

MIT License

## Acknowledgments

- Yann LeCun for creating MNIST
- Scikit-learn team for the ML library
- Hugging Face for dataset hosting

---

## Next Steps

For better performance, consider:

1. **More Complex Models**: SVM, Random Forest, Neural Networks
2. **Deep Learning**: CNNs with PyTorch/TensorFlow
3. **Data Augmentation**: Rotation, scaling, elastic deformations
4. **Feature Engineering**: HOG, SIFT features
5. **Ensemble Methods**: Combine multiple classifiers
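
As an illustration of the ensemble idea in step 5, scikit-learn's `VotingClassifier` can combine several of the models listed above. A minimal sketch on toy data (on MNIST, fit it on `X_train_flat` instead; SVC on 60,000 samples is slow):

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier, VotingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC

# Toy separable 2-class data standing in for the flattened digit vectors
rng = np.random.default_rng(42)
X = rng.normal(size=(60, 4)) + np.repeat([[0], [3]], 30, axis=0)
y = np.repeat([0, 1], 30)

ensemble = VotingClassifier(
    estimators=[
        ('lr', LogisticRegression(max_iter=200)),
        ('rf', RandomForestClassifier(n_estimators=50, random_state=0)),
        ('svm', SVC(probability=True, random_state=0)),
    ],
    voting='soft',  # average predicted class probabilities
)
ensemble.fit(X, y)
print(ensemble.score(X, y))
```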