# MNIST CNN Classifier
A custom Convolutional Neural Network for MNIST digit classification, built with PyTorch and compatible with Hugging Face Transformers.
## Model Description
This model implements a CNN architecture specifically designed for MNIST handwritten digit recognition. The model achieves over 98% accuracy on the MNIST test set and is fully compatible with the Hugging Face Transformers ecosystem.
## Model Architecture

- Input: 1x28x28 grayscale images (MNIST digits)
- Architecture:
  - 2 convolutional blocks (each with 2 conv layers + batch norm + ReLU + max pool + dropout)
  - 2 fully connected layers (with batch norm and dropout)
  - Output layer: 10 classes (digits 0-9)
- Parameters: ~1.68M trainable parameters
- Activation: ReLU
- Normalization: Batch normalization
- Regularization: Dropout (0.25 for conv layers, 0.5 for fc layers)
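The exact layer widths are not listed above; a minimal sketch consistent with the description might look like the following. The 32/64 conv channels and the 512-unit hidden layer are assumptions, chosen because they land close to the stated ~1.68M parameters:

```python
import torch
import torch.nn as nn

class MnistCnnSketch(nn.Module):
    """Illustrative architecture matching the card's description.

    Channel widths (32/64) and the 512-unit hidden layer are assumptions;
    with them, the parameter count comes out to roughly 1.68M.
    """
    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.features = nn.Sequential(
            # Block 1: two conv layers + batch norm + ReLU + max pool + dropout
            nn.Conv2d(1, 32, 3, padding=1), nn.ReLU(),
            nn.Conv2d(32, 32, 3, padding=1), nn.BatchNorm2d(32), nn.ReLU(),
            nn.MaxPool2d(2), nn.Dropout(0.25),          # 28x28 -> 14x14
            # Block 2
            nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(),
            nn.Conv2d(64, 64, 3, padding=1), nn.BatchNorm2d(64), nn.ReLU(),
            nn.MaxPool2d(2), nn.Dropout(0.25),          # 14x14 -> 7x7
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * 7 * 7, 512), nn.BatchNorm1d(512), nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes),                # output layer: 10 classes
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.classifier(self.features(x))
```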
## Training Details
- Dataset: MNIST (60,000 training, 10,000 test samples)
- Optimizer: Adam (lr=0.001)
- Loss Function: Cross-Entropy Loss
- Batch Size: 64
- Epochs: 10
- Learning Rate Scheduling: ReduceLROnPlateau
- Data Augmentation: None (basic MNIST preprocessing only)
- Normalization: MNIST standard (mean=0.1307, std=0.3081)
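The recipe above can be sketched as a single training step. The model here is a stand-in (any `nn.Module` with 10 outputs would do), and the dummy batch replaces the real MNIST data loader:

```python
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Stand-in classifier; substitute the actual MnistCNN in practice.
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))

criterion = nn.CrossEntropyLoss()                        # cross-entropy loss
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
scheduler = ReduceLROnPlateau(optimizer, mode="min")     # reduces lr when loss plateaus

# One training step on a dummy batch of 64 normalized 28x28 images
images = (torch.rand(64, 1, 28, 28) - 0.1307) / 0.3081   # MNIST normalization
labels = torch.randint(0, 10, (64,))

optimizer.zero_grad()
loss = criterion(model(images), labels)
loss.backward()
optimizer.step()

# After each epoch, step the scheduler on the validation loss
scheduler.step(loss.item())
```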
## Performance

- Test Accuracy: >98%
- Training Time: ~5 minutes on a single GPU
- Model Size: ~6.7MB (saved weights)
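The reported size is consistent with the parameter count: ~1.68M float32 weights at 4 bytes each come to roughly 6.7 MB.

```python
num_params = 1.68e6               # trainable parameters (from the card)
size_mb = num_params * 4 / 1e6    # float32 = 4 bytes per weight
print(f"{size_mb:.1f} MB")        # ~6.7 MB
```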
## Usage

### Using Hugging Face Transformers
```python
import torch
from PIL import Image
from transformers import AutoModel, AutoImageProcessor

# Load model and processor (custom architecture, so remote code must be trusted)
model = AutoModel.from_pretrained("your-username/mnist-cnn-classifier", trust_remote_code=True)
processor = AutoImageProcessor.from_pretrained("your-username/mnist-cnn-classifier", trust_remote_code=True)

# Prepare image
image = Image.open("path/to/mnist_digit.png").convert("L")  # convert to grayscale
inputs = processor(images=image, return_tensors="pt")

# Forward pass
with torch.no_grad():
    outputs = model(**inputs)

predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
predicted_class = torch.argmax(predictions, dim=-1).item()
confidence = predictions[0][predicted_class].item()
print(f"Predicted digit: {predicted_class} (confidence: {confidence:.4f})")
```
### Using PyTorch Directly
```python
import torch
from PIL import Image
from torchvision import transforms

from modeling_mnist_cnn import MnistCNN
from configuration_mnist_cnn import MnistCnnConfig

# Load configuration and model
config = MnistCnnConfig()
model = MnistCNN(config)
model.load_state_dict(torch.load("best_model.pth", map_location="cpu"))
model.eval()

# Define transforms
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# Load and preprocess image
image = Image.open("digit.png").convert("L")
input_tensor = transform(image).unsqueeze(0)  # add batch dimension

# Predict
with torch.no_grad():
    output = model(input_tensor)

prediction = torch.argmax(output, dim=1).item()
print(f"Predicted digit: {prediction}")
```
## Intended Use
This model is designed for:
- Educational purposes and learning computer vision
- Benchmarking and comparison with other MNIST models
- Testing deployment pipelines
- Demonstrating custom model integration with Hugging Face
## Limitations
- Trained only on MNIST dataset (handwritten digits 0-9)
- Not suitable for general character recognition
- Performance may degrade on writing styles not represented in MNIST
- Input must be 28x28 grayscale images
## Ethical Considerations
This model was trained on a standard academic dataset and poses no significant ethical concerns. It should be used responsibly for educational and research purposes.
## Training Data
The model was trained on the MNIST dataset, which is freely available for academic and research use. The dataset consists of:
- 60,000 training images
- 10,000 test images
- 28x28 pixel grayscale handwritten digits (0-9)
## Technical Details
- Framework: PyTorch
- Transformers Compatibility: Yes
- AutoClass Support: Yes
- Supported Tasks: Image Classification
- Input Format: Images (PIL.Image.Image)
- Output Format: Class labels (0-9)
## Model Files

- `pytorch_model.bin`: Trained model weights
- `config.json`: Model configuration
- `preprocessor_config.json`: Image preprocessing configuration
- `modeling_mnist_cnn.py`: Model architecture definition
- `configuration_mnist_cnn.py`: Configuration class
## Citation
If you use this model in your research, please cite:
```bibtex
@misc{mnist-cnn-classifier,
  title={MNIST CNN Classifier},
  author={Your Name},
  year={2024},
  url={https://huggingface.co/your-username/mnist-cnn-classifier}
}
```
## License
This model is released under the MIT License.