MNIST CNN Classifier

A custom Convolutional Neural Network for MNIST digit classification, built with PyTorch and compatible with Hugging Face Transformers.

Model Description

This model implements a CNN architecture specifically designed for MNIST handwritten digit recognition. The model achieves over 98% accuracy on the MNIST test set and is fully compatible with the Hugging Face Transformers ecosystem.

Model Architecture

  • Input: 1x28x28 grayscale images (MNIST digits)
  • Architecture:
    • 2 convolutional blocks (each with 2 conv layers + batch norm + ReLU + max pool + dropout)
    • 2 fully connected layers (with batch norm and dropout)
    • Output layer: 10 classes (digits 0-9)
  • Parameters: ~1.68M trainable parameters
  • Activation: ReLU
  • Normalization: Batch normalization
  • Regularization: Dropout (0.25 for conv layers, 0.5 for fc layers)
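
The architecture described above can be sketched in PyTorch. The exact channel and hidden-layer widths below are assumptions (not taken from the released code), chosen so the parameter count lands near the stated ~1.68M:

```python
import torch
import torch.nn as nn

class MnistCnnSketch(nn.Module):
    """Illustrative sketch of the described architecture; layer widths are assumptions."""
    def __init__(self, num_classes: int = 10):
        super().__init__()
        # Block 1: two conv layers, each with batch norm + ReLU, then max pool + dropout
        self.block1 = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding=1), nn.BatchNorm2d(32), nn.ReLU(),
            nn.Conv2d(32, 32, 3, padding=1), nn.BatchNorm2d(32), nn.ReLU(),
            nn.MaxPool2d(2), nn.Dropout(0.25),
        )
        # Block 2: same pattern, wider
        self.block2 = nn.Sequential(
            nn.Conv2d(32, 64, 3, padding=1), nn.BatchNorm2d(64), nn.ReLU(),
            nn.Conv2d(64, 64, 3, padding=1), nn.BatchNorm2d(64), nn.ReLU(),
            nn.MaxPool2d(2), nn.Dropout(0.25),
        )
        # Two pools: 28 -> 14 -> 7, so the flattened size is 64 * 7 * 7 = 3136
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * 7 * 7, 512), nn.BatchNorm1d(512), nn.ReLU(), nn.Dropout(0.5),
            nn.Linear(512, num_classes),
        )

    def forward(self, x):
        return self.classifier(self.block2(self.block1(x)))

model = MnistCnnSketch()
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"{n_params:,} trainable parameters")  # roughly 1.68M with these widths
```

With 3x3 convolutions at widths 32/32/64/64 and a 512-unit hidden layer, the bulk of the parameters (~1.6M) sit in the first fully connected layer.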

Training Details

  • Dataset: MNIST (60,000 training, 10,000 test samples)
  • Optimizer: Adam (lr=0.001)
  • Loss Function: Cross-Entropy Loss
  • Batch Size: 64
  • Epochs: 10
  • Learning Rate Scheduling: ReduceLROnPlateau
  • Data Augmentation: None (basic MNIST preprocessing only)
  • Normalization: MNIST standard (mean=0.1307, std=0.3081)
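
The hyperparameters above can be wired together as in the following sketch. This is not the original training script; it is a minimal loop matching the listed settings (Adam at lr=0.001, cross-entropy loss, batch size 64, 10 epochs, ReduceLROnPlateau driven by validation loss):

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

def train(model, train_ds, val_ds, epochs=10, batch_size=64, lr=1e-3, device="cpu"):
    """Training loop matching the card's hyperparameters (a sketch, not the original script)."""
    model.to(device)
    loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=batch_size)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    # Reduce the learning rate when validation loss stops improving
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", patience=2)
    for epoch in range(epochs):
        model.train()
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            loss = criterion(model(images), labels)
            loss.backward()
            optimizer.step()
        # Validation pass drives the scheduler
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for images, labels in val_loader:
                val_loss += criterion(model(images.to(device)), labels.to(device)).item()
        scheduler.step(val_loss / max(len(val_loader), 1))
    return model
```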

Performance

  • Test Accuracy: >98%
  • Training Time: ~5 minutes on a single GPU
  • Model Size: ~6.7MB (saved weights)
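
Test accuracy here means top-1 accuracy over the 10,000 test images. A typical way to compute it (a sketch, not the original evaluation code):

```python
import torch

def accuracy(model, loader, device="cpu"):
    """Fraction of samples whose argmax prediction matches the label."""
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for images, labels in loader:
            preds = model(images.to(device)).argmax(dim=1)
            correct += (preds == labels.to(device)).sum().item()
            total += labels.size(0)
    return correct / total
```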

Usage

Using Hugging Face Transformers

from transformers import AutoModel, AutoImageProcessor
import torch
from PIL import Image
import numpy as np

# Load model and processor (trust_remote_code is required because the
# architecture is custom code hosted alongside the weights)
model = AutoModel.from_pretrained("your-username/mnist-cnn-classifier", trust_remote_code=True)
processor = AutoImageProcessor.from_pretrained("your-username/mnist-cnn-classifier", trust_remote_code=True)

# Prepare image
image = Image.open("path/to/mnist_digit.png").convert("L")  # Convert to grayscale
inputs = processor(images=image, return_tensors="pt")

# Forward pass
with torch.no_grad():
    outputs = model(**inputs)
    predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
    predicted_class = torch.argmax(predictions, dim=-1).item()
    confidence = predictions[0][predicted_class].item()

print(f"Predicted digit: {predicted_class} (confidence: {confidence:.4f})")

Using PyTorch Directly

import torch
from modeling_mnist_cnn import MnistCNN
from configuration_mnist_cnn import MnistCnnConfig
from torchvision import transforms

# Load configuration and model
config = MnistCnnConfig()
model = MnistCNN(config)
model.load_state_dict(torch.load("pytorch_model.bin", map_location="cpu"))
model.eval()

# Define transforms
transform = transforms.Compose([
    transforms.Resize((28, 28)),  # the model expects 28x28 inputs
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# Load and preprocess image
from PIL import Image
image = Image.open("digit.png").convert("L")
input_tensor = transform(image).unsqueeze(0)

# Predict
with torch.no_grad():
    output = model(input_tensor)
    prediction = torch.argmax(output, dim=1).item()
    
print(f"Predicted digit: {prediction}")

Intended Use

This model is designed for:

  • Educational purposes and learning computer vision
  • Benchmarking and comparison with other MNIST models
  • Testing deployment pipelines
  • Demonstrating custom model integration with Hugging Face

Limitations

  • Trained only on MNIST dataset (handwritten digits 0-9)
  • Not suitable for general character recognition
  • Performance may vary on different writing styles not represented in MNIST
  • Input must be 28x28 grayscale images

Ethical Considerations

This model was trained on a standard academic dataset and poses no significant ethical concerns. It should be used responsibly for educational and research purposes.

Training Data

The model was trained on the MNIST dataset, which is freely available for academic and research use. The dataset consists of:

  • 60,000 training images
  • 10,000 test images
  • 28x28 pixel grayscale handwritten digits (0-9)

Technical Details

  • Framework: PyTorch
  • Transformers Compatibility: Yes
  • AutoClass Support: Yes
  • Supported Tasks: Image Classification
  • Input Format: Images (PIL.Image.Image)
  • Output Format: Class labels (0-9)

Model Files

  • pytorch_model.bin: Trained model weights
  • config.json: Model configuration
  • preprocessor_config.json: Image preprocessing configuration
  • modeling_mnist_cnn.py: Model architecture definition
  • configuration_mnist_cnn.py: Configuration class

Citation

If you use this model in your research, please cite:

@misc{mnist-cnn-classifier,
  title={MNIST CNN Classifier},
  author={Your Name},
  year={2024},
  url={https://huggingface.co/your-username/mnist-cnn-classifier}
}

License

This model is released under the MIT License.
