Handwritten Digit Classifier

A PyTorch image classification model that recognizes handwritten digits (0–9), built on a pretrained ResNet-18 backbone (ImageNet weights) fine-tuned on a combined MNIST + EMNIST dataset with aggressive data augmentation. Achieves 99.46% accuracy on the combined test set.


Model Details

| Property | Value |
|---|---|
| Architecture | ResNet-18 (pretrained on ImageNet) |
| Framework | PyTorch |
| Task | Image classification (10 classes, digits 0–9) |
| Input Size | 32 × 32 (grayscale, converted to 3-channel) |
| Output | Softmax probabilities over digits 0–9 |
| Test Accuracy | 99.46% |
| Training Device | CUDA (GPU) |
| Epochs | 7 |
| Batch Size | 256 |
| Optimizer | Adam (differential learning rates) |
| Loss Function | CrossEntropyLoss |
| LR Scheduler | StepLR (step=2, gamma=0.5) |

Architecture

The model uses a ResNet-18 backbone pretrained on ImageNet, with the default classification head replaced by a custom fully-connected head:

ResNet-18 Backbone (pretrained on ImageNet-1K)
        ↓
  Linear(512 → 128)
        ↓
      ReLU()
        ↓
    Dropout(0.3)
        ↓
  Linear(128 → 10)
        ↓
  Softmax (at inference)

Differential learning rates were used to preserve pretrained features while allowing the new head to learn faster:

  • Pretrained backbone layers: lr = 0.0001
  • New classification head (last 4 parameter tensors): lr = 0.001

The dropout layer (p=0.3) reduces overfitting given the simplicity of digit images relative to the model's capacity.


Dataset

The model was trained on a combined MNIST + EMNIST (digits split) dataset for greater diversity and robustness.

MNIST

| Property | Value |
|---|---|
| Classes | 10 (digits 0–9) |
| Training set | 60,000 grayscale images |
| Test set | 10,000 grayscale images |
| Image size | 28 × 28 pixels |
| Source | yann.lecun.com/exdb/mnist |

EMNIST (digits split)

| Property | Value |
|---|---|
| Classes | 10 (digits 0–9) |
| Training set | 240,000 grayscale images |
| Test set | 40,000 grayscale images |
| Image size | 28 × 28 pixels |
| Source | NIST Special Database 19 |

Combined total: 300,000 training images and 50,000 test images.


Training

The model was trained for 7 epochs on CUDA with a StepLR scheduler (halving LR every 2 epochs). Loss decreased consistently across all epochs.

| Epoch | Loss |
|---|---|
| 1 | 0.1732 |
| 2 | 0.0635 |
| 3 | 0.0446 |
| 4 | 0.0409 |
| 5 | 0.0340 |
| 6 | 0.0307 |
| 7 | 0.0279 |

Final Test Accuracy: 99.46%
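The training loop with CrossEntropyLoss, Adam, and the StepLR schedule (step=2, gamma=0.5) can be sketched roughly as below. The loader and device handling are assumptions, and the per-group learning rates are collapsed to a single rate for brevity:

```python
import torch.nn as nn
from torch import optim
from torch.optim.lr_scheduler import StepLR

def train(model, train_loader, device="cuda", epochs=7):
    model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    # Halve the learning rate every 2 epochs
    scheduler = StepLR(optimizer, step_size=2, gamma=0.5)

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            loss = criterion(model(images), labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item() * images.size(0)
        scheduler.step()
        print(f"Epoch {epoch + 1}: loss {running_loss / len(train_loader.dataset):.4f}")
    return model
```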


Data Augmentation

Aggressive augmentation was applied during training to improve generalization to real-world handwriting styles:

| Augmentation | Parameters |
|---|---|
| Random Rotation | ±15° |
| Random Affine (translate) | ±15% horizontal and vertical |
| Random Affine (shear) | 10° |
| Random Perspective | distortion scale 0.3, p=0.3 |
| Color Jitter | brightness ±0.3, contrast ±0.3 |
| Normalization | mean (0.5, 0.5, 0.5), std (0.5, 0.5, 0.5) |

No augmentation was applied to the test set (only resize + normalize).


Preprocessing

At inference, input images go through the following pipeline:

  1. Convert to grayscale
  2. Invert colors (white background → black background to match MNIST format)
  3. Resize to 32 × 32
  4. Convert to 3-channel (grayscale replicated across RGB channels for ResNet compatibility)
  5. Normalize with mean (0.5, 0.5, 0.5) and std (0.5, 0.5, 0.5)

Usage

```python
import torch
import torch.nn as nn
from torchvision import transforms, models
from huggingface_hub import hf_hub_download
from PIL import Image
import numpy as np

# Load model: ResNet-18 backbone with the custom classification head
model = models.resnet18(weights=None)
model.fc = nn.Sequential(
    nn.Linear(512, 128), nn.ReLU(), nn.Dropout(0.3), nn.Linear(128, 10)
)

weights_path = hf_hub_download(
    repo_id="AdityaManojShinde/handwritten_digit_classifier",
    filename="mnist_model.pth"
)
model.load_state_dict(torch.load(weights_path, map_location="cpu"))
model.eval()

# Preprocessing pipeline
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Inference
image = Image.open("your_digit.png").convert("L")
img_array = 255 - np.array(image)   # invert: white bg → black bg
image = Image.fromarray(img_array)
img_tensor = transform(image).unsqueeze(0)

with torch.no_grad():
    output = model(img_tensor)
    probs = torch.nn.functional.softmax(output, dim=1)[0]
    predicted = probs.argmax().item()

print(f"Predicted digit: {predicted} ({probs[predicted]*100:.1f}% confidence)")
```

Limitations

  • Works best with centered, clearly written single digits on a plain background.
  • Not suitable for multi-digit recognition or digit detection in natural scenes.
  • May struggle with highly stylized or non-standard digit handwriting not represented in MNIST/EMNIST.

License

This model is released under the MIT License.
