Base dataset: `ylecun/mnist` (Hugging Face Hub)
This repository contains a lightweight convolutional neural network (CNN) for MNIST handwritten digit classification. The model was trained on a shifted variant of the MNIST dataset and is optimized to be small, fast, and easy to deploy, which makes it suitable for both research and educational purposes.
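The card does not say how the digits were shifted. As a purely hypothetical illustration, one common way to build a "shifted" MNIST variant is to translate each 28×28 image by a small random offset and fill the uncovered border with zeros, which is the MNIST background value. The sketch assumes NumPy and a maximum shift of 4 pixels. The helper names `shift_image` and `random_shift` are illustrative and are not part of this repository.

```python
import numpy as np


def shift_image(img: np.ndarray, dy: int, dx: int) -> np.ndarray:
    """Shift a 2-D image down by dy and right by dx (negative = up/left), zero-filling.

    Hypothetical helper: the shift scheme actually used for training is not documented.
    """
    h, w = img.shape
    out = np.zeros_like(img)
    if abs(dy) >= h or abs(dx) >= w:
        return out  # the whole image moved out of frame
    # Source and destination ranges that remain inside the 28x28 frame.
    src_y = slice(max(0, -dy), h - max(0, dy))
    dst_y = slice(max(0, dy), h - max(0, -dy))
    src_x = slice(max(0, -dx), w - max(0, dx))
    dst_x = slice(max(0, dx), w - max(0, -dx))
    out[dst_y, dst_x] = img[src_y, src_x]
    return out


def random_shift(img: np.ndarray, max_shift: int = 4, rng=None) -> np.ndarray:
    """Apply a random translation of up to `max_shift` pixels along each axis (assumed value)."""
    rng = np.random.default_rng() if rng is None else rng
    dy, dx = rng.integers(-max_shift, max_shift + 1, size=2)
    return shift_image(img, int(dy), int(dx))
```

Zero-filling matters here. `np.roll` would wrap pixels around to the opposite edge, which does not match how a digit moves within a frame.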
| Attribute | Value |
|---|---|
| Model Name | TinyCNN |
| Dataset | MNIST (28×28 grayscale digits) |
| Trainable Parameters | 94,410 |
| Architecture | (Conv → BN → ReLU → MaxPool) ×3 → Global Avg Pool → FC |
| Input Shape | (1, 28, 28) |
| Output Classes | 10 |
| Framework | PyTorch |
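You can reproduce the parameter count in the table by hand. A convolution with `c_in` input channels, `c_out` output channels, and a k×k kernel has `c_in·c_out·k²` weights plus `c_out` biases. Each `BatchNorm2d` layer adds a learnable scale and shift per channel; its running mean and variance are buffers, not trainable parameters. The final linear layer has 128·10 weights plus 10 biases. The helper names below are illustrative only.

```python
def conv_params(c_in: int, c_out: int, k: int = 3) -> int:
    # k x k weights for every (input, output) channel pair, plus one bias per output channel
    return c_in * c_out * k * k + c_out


def bn_params(channels: int) -> int:
    # learnable gamma (scale) and beta (shift) per channel
    return 2 * channels


def linear_params(n_in: int, n_out: int) -> int:
    return n_in * n_out + n_out


breakdown = {
    "conv1": conv_params(1, 32),    # 320
    "bn1": bn_params(32),           # 64
    "conv2": conv_params(32, 64),   # 18,496
    "bn2": bn_params(64),           # 128
    "conv3": conv_params(64, 128),  # 73,856
    "bn3": bn_params(128),          # 256
    "fc": linear_params(128, 10),   # 1,290
}
total = sum(breakdown.values())
print(f"{total:,}")  # 94,410
```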
- Input: 1×28×28
- Conv Block 1: Conv(1→32, 3×3) → BatchNorm → ReLU → MaxPool(2×2)
- Conv Block 2: Conv(32→64, 3×3) → BatchNorm → ReLU → MaxPool(2×2)
- Conv Block 3: Conv(64→128, 3×3) → BatchNorm → ReLU → MaxPool(2×2)
- Global Average Pooling
- Fully Connected Layer → 10 output classes
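The model code uses padding 1 for each 3×3 convolution, so each convolution preserves height and width. Each 2×2 max pool with stride 2 halves the spatial size, rounding down. The sketch below traces the feature-map shape through the blocks using the standard output-size formula out = floor((n + 2p - k) / s) + 1. The function names are illustrative.

```python
def out_size(n: int, k: int, s: int = 1, p: int = 0) -> int:
    """Spatial output size of a conv or pooling layer (no dilation)."""
    return (n + 2 * p - k) // s + 1


def trace_shapes(size: int = 28, channels=(32, 64, 128), num_classes: int = 10):
    shapes = [(1, size, size)]
    for c in channels:
        size = out_size(size, k=3, s=1, p=1)  # Conv 3x3, padding 1: size unchanged
        size = out_size(size, k=2, s=2)       # MaxPool 2x2, stride 2: halved, floored
        shapes.append((c, size, size))
    shapes.append((channels[-1], 1, 1))       # global average pooling
    shapes.append((num_classes,))             # fully connected layer
    return shapes


for shape in trace_shapes():
    print(shape)
```

The third pooling stage turns a 7×7 map into 3×3 because of the floor. Global average pooling then averages over those 3×3 positions.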
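The design favors parameter efficiency. Global average pooling collapses each of the 128 final feature maps to a single value, so the classifier head is one 128→10 linear layer. Nearly all of the parameters therefore sit in the convolutional layers.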
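A back-of-the-envelope comparison makes the efficiency claim concrete. After the third pooling stage the feature map is 128×3×3. Flattening it straight into a linear layer would need 1,152·10 + 10 weights and biases. Averaging it down to 128 values first needs far fewer. This is illustrative arithmetic, not code from this repository.

```python
def linear_head_params(features: int, classes: int = 10) -> int:
    # weights + biases of a single fully connected layer
    return features * classes + classes


channels, height, width = 128, 3, 3  # feature map after the third MaxPool
flatten_head = linear_head_params(channels * height * width)  # 11,530
gap_head = linear_head_params(channels)                       # 1,290
print(flatten_head, gap_head, round(flatten_head / gap_head, 1))
```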
```bash
pip install torch torchvision
```
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyCNN(nn.Module):
    """
    Tiny CNN for MNIST using Global Avg Pooling.
    Trainable parameters: 94,410
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # First conv block: 1 -> 32 channels, 28x28 -> 14x14 after pooling
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.pool1 = nn.MaxPool2d(2, 2)
        # Second conv block: 32 -> 64 channels, 14x14 -> 7x7
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.pool2 = nn.MaxPool2d(2, 2)
        # Third conv block: 64 -> 128 channels, 7x7 -> 3x3
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.pool3 = nn.MaxPool2d(2, 2)
        # Global average pooling: (128, 3, 3) -> (128, 1, 1)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # Final FC (input = 128 channels after GAP)
        self.fc = nn.Linear(128, num_classes)

    def forward(self, x):
        x = self.pool1(F.relu(self.bn1(self.conv1(x))))
        x = self.pool2(F.relu(self.bn2(self.conv2(x))))
        x = self.pool3(F.relu(self.bn3(self.conv3(x))))
        x = self.avgpool(x)        # (batch, 128, 1, 1)
        x = x.view(x.size(0), -1)  # (batch, 128)
        x = self.fc(x)             # (batch, num_classes)
        return x


model = TinyCNN(num_classes=10)
# map_location="cpu" lets the checkpoint load on machines without a GPU
state_dict = torch.hub.load_state_dict_from_url(
    "https://huggingface.co/FinOS-Internship/ShiftedTinyCNN/TinyCNN_model_acc_98.97.pth",
    map_location="cpu",
)
model.load_state_dict(state_dict)
model.eval()  # inference mode: BatchNorm uses its running statistics
```
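```python
import torch
from torchvision import transforms
from PIL import Image

# Assumes `model` has been created and loaded as shown above.
transform = transforms.Compose([
    transforms.Grayscale(),       # ensure a single channel
    transforms.Resize((28, 28)),  # match the MNIST input size
    transforms.ToTensor(),        # uint8 [0, 255] -> float [0.0, 1.0], shape (1, 28, 28)
])

# MNIST digits are light strokes on a dark background; a dark-on-light image
# may need inverting first (e.g. with PIL.ImageOps.invert).
img = Image.open("digit.png")
x = transform(img).unsqueeze(0)  # shape: (1, 1, 28, 28)

with torch.no_grad():
    logits = model(x)
pred = logits.argmax(dim=1).item()

print("Predicted digit:", pred)
```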
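The following NumPy sketch shows what the preprocessing above does to a 28×28 grayscale image. It scales 8-bit pixels to [0, 1] and adds channel and batch axes, as `ToTensor()` plus `unsqueeze(0)` do. It then turns a row of logits into class probabilities with a softmax. The optional inversion step is an assumption: it only applies if your image is dark-on-light while MNIST digits are light-on-dark. The function names and the example logits are hypothetical.

```python
import numpy as np


def preprocess(gray_uint8: np.ndarray, invert: bool = False) -> np.ndarray:
    """(28, 28) uint8 image -> (1, 1, 28, 28) float32 array in [0, 1]."""
    x = gray_uint8.astype(np.float32) / 255.0  # same scaling as ToTensor()
    if invert:
        x = 1.0 - x                             # dark-on-light -> light-on-dark
    return x[np.newaxis, np.newaxis, :, :]      # add channel and batch axes


def softmax(logits: np.ndarray) -> np.ndarray:
    """Numerically stable softmax over the last axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)


# Made-up logits for one image over 10 classes.
logits = np.array([[0.1, 0.2, 4.0, 0.0, -1.0, 0.3, 0.0, 0.5, 0.2, -0.5]])
probs = softmax(logits)
pred = int(probs.argmax(axis=1)[0])
print(pred, round(float(probs[0, pred]), 3))
```

Inside the PyTorch snippet itself, the equivalent probability step is `torch.softmax(logits, dim=1)`. The argmax is unchanged by the softmax, so the predicted digit is the same either way.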