---
language: en
tags:
- image-classification
- mnist
- emnist
- digit-recognition
- pytorch
- resnet
license: mit
datasets:
- mnist
- emnist
pipeline_tag: image-classification
---
# Handwritten Digit Classifier
A PyTorch image classification model that recognizes handwritten digits (0–9), built on a **pretrained ResNet-18** backbone (ImageNet weights) fine-tuned on a combined **MNIST + EMNIST** dataset with aggressive data augmentation. Achieves **99.46% accuracy** on the combined test set.
---
## Model Details
| Property | Value |
|-----------------------|-------------------------------------------------|
| **Architecture** | ResNet-18 (pretrained on ImageNet) |
| **Framework** | PyTorch |
| **Task** | Image Classification (10 classes, digits 0–9) |
| **Input Size** | 32 Γ— 32 (grayscale, converted to 3-channel) |
| **Output** | Softmax probabilities over digits 0–9 |
| **Test Accuracy** | **99.46%** |
| **Training Device** | CUDA (GPU) |
| **Epochs** | 7 |
| **Batch Size** | 256 |
| **Optimizer** | Adam (differential learning rates) |
| **Loss Function** | CrossEntropyLoss |
| **LR Scheduler** | StepLR (step=2, gamma=0.5) |
---
## Architecture
The model uses a **ResNet-18** backbone pretrained on ImageNet, with the default classification head replaced by a custom fully-connected head:
```
ResNet-18 Backbone (pretrained on ImageNet1K)
↓
Linear(512 β†’ 128)
↓
ReLU()
↓
Dropout(0.3)
↓
Linear(128 β†’ 10)
↓
Softmax (at inference)
```
**Differential learning rates** were used to preserve pretrained features while allowing the new head to learn faster:
- Pretrained backbone layers: `lr = 0.0001`
- New classification head (last 4 param groups): `lr = 0.001`
The dropout layer (p=0.3) reduces overfitting given the simplicity of digit images relative to the model's capacity.
---
## Dataset
The model was trained on a **combined MNIST + EMNIST (digits split)** dataset for greater diversity and robustness.
### MNIST
| Property | Value |
|------------------|----------------------------|
| **Classes** | 10 (digits 0–9) |
| **Training set** | 60,000 grayscale images |
| **Test set** | 10,000 grayscale images |
| **Image size** | 28 Γ— 28 pixels |
| **Source** | [yann.lecun.com/exdb/mnist](http://yann.lecun.com/exdb/mnist/) |
### EMNIST (digits split)
| Property | Value |
|------------------|----------------------------|
| **Classes** | 10 (digits 0–9) |
| **Training set** | 240,000 grayscale images |
| **Test set** | 40,000 grayscale images |
| **Image size** | 28 Γ— 28 pixels |
| **Source** | [NIST Special Database 19](https://www.nist.gov/itl/products-and-services/emnist-dataset) |
**Combined total:** 300,000 training images and 50,000 test images.
---
## Training
The model was trained for **7 epochs** on CUDA with a StepLR scheduler (halving LR every 2 epochs). Loss decreased consistently across all epochs.
| Epoch | Loss |
|-------|--------|
| 1 | 0.1732 |
| 2 | 0.0635 |
| 3 | 0.0446 |
| 4 | 0.0409 |
| 5 | 0.0340 |
| 6 | 0.0307 |
| 7 | 0.0279 |
**Final Test Accuracy: 99.46%**
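The 7-epoch schedule with `StepLR(step=2, gamma=0.5)` can be sketched as follows. This is a minimal illustration run on one synthetic batch so it executes quickly; the real run iterated over the combined MNIST+EMNIST loader with the ResNet-18 model:

```python
import torch
import torch.nn as nn

# Stand-in model and synthetic data, assumed only for this sketch.
model = nn.Linear(32 * 32, 10)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.5)
criterion = nn.CrossEntropyLoss()

x = torch.randn(256, 32 * 32)          # one synthetic batch of 256
y = torch.randint(0, 10, (256,))

lrs = []
for epoch in range(7):
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    optimizer.step()
    scheduler.step()                   # LR halves every 2 epochs
    lrs.append(optimizer.param_groups[0]["lr"])
```

With this schedule the learning rate decays 1e-3 → 5e-4 → 2.5e-4 → 1.25e-4 over the 7 epochs.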
---
## Data Augmentation
Aggressive augmentation was applied during training to improve generalization to real-world handwriting styles:
| Augmentation | Parameters |
|-------------------------|-----------------------------------------|
| Random Rotation | Β±15Β° |
| Random Affine (translate) | Β±15% horizontal and vertical |
| Random Affine (shear) | 10Β° |
| Random Perspective | distortion scale 0.3, p=0.3 |
| Color Jitter | brightness Β±0.3, contrast Β±0.3 |
| Normalization | mean (0.5, 0.5, 0.5), std (0.5, 0.5, 0.5) |
No augmentation was applied to the test set (only resize + normalize).
---
## Preprocessing
At inference, input images go through the following pipeline:
1. Convert to **grayscale**
2. **Invert** colors (white background β†’ black background to match MNIST format)
3. **Resize** to 32 Γ— 32
4. Convert to **3-channel** (grayscale replicated across RGB channels for ResNet compatibility)
5. **Normalize** with mean `(0.5, 0.5, 0.5)` and std `(0.5, 0.5, 0.5)`
---
## Usage
```python
import torch
import torch.nn as nn
from torchvision import transforms, models
from huggingface_hub import hf_hub_download
from PIL import Image
import numpy as np
# Load model
model = models.resnet18(weights=None)
model.fc = nn.Sequential(
nn.Linear(512, 128), nn.ReLU(), nn.Dropout(0.3), nn.Linear(128, 10)
)
weights_path = hf_hub_download(
repo_id="AdityaManojShinde/handwritten_digit_classifier",
filename="mnist_model.pth"
)
model.load_state_dict(torch.load(weights_path, map_location="cpu"))
model.eval()
# Preprocessing pipeline
transform = transforms.Compose([
transforms.Grayscale(num_output_channels=3),
transforms.Resize((32, 32)),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
# Inference
image = Image.open("your_digit.png").convert("L")
img_array = 255 - np.array(image) # invert: white bg β†’ black bg
image = Image.fromarray(img_array)
img_tensor = transform(image).unsqueeze(0)
with torch.no_grad():
output = model(img_tensor)
probs = torch.nn.functional.softmax(output, dim=1)[0]
predicted = probs.argmax().item()
print(f"Predicted digit: {predicted} ({probs[predicted]*100:.1f}% confidence)")
```
---
## Limitations
- Works best with **centered, clearly written** single digits on a plain background.
- Not suitable for multi-digit recognition or digit detection in natural scenes.
- May struggle with highly stylized or non-standard digit handwriting not represented in MNIST/EMNIST.
---
## License
This model is released under the [MIT License](https://opensource.org/licenses/MIT).