---
datasets:
- ylecun/mnist
language:
- en
metrics:
- accuracy
---
# 🧠 TinyCNN for MNIST (94K params)
This repository contains a lightweight Convolutional Neural Network (CNN) designed for the MNIST handwritten digit classification task.
It was trained on a shifted variant of the MNIST handwritten digit dataset.
The model is optimized to be **small, fast, and easy to deploy**, suitable for both research and educational purposes.
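
This README does not describe how the digits were shifted. As a hedged illustration of what "shifted" could mean, the sketch below translates an image by whole pixels and fills the vacated border with 0, the MNIST background value. `shift_image` and its parameters are hypothetical names, not part of this repository. For random translations during training, torchvision provides `transforms.RandomAffine(degrees=0, translate=(0.1, 0.1))`; the 0.1 fractions there are illustrative, not the values used for this model.

```python
# Hypothetical helper: this README does not specify how digits were shifted.
# Translates a 2D image (list of rows) by whole pixels, filling the vacated
# area with `fill` (0 = black background in MNIST).


def shift_image(img, dx, dy, fill=0):
    """Return a copy of `img` moved dx pixels right and dy pixels down.

    Negative values move left/up. Pixels pushed past the border are discarded.
    """
    h = len(img)
    w = len(img[0]) if h else 0
    out = [[fill] * w for _ in range(h)]
    for r in range(h):
        src_r = r - dy  # row in the original image that lands on row r
        if not 0 <= src_r < h:
            continue
        for c in range(w):
            src_c = c - dx  # column in the original image that lands on column c
            if 0 <= src_c < w:
                out[r][c] = img[src_r][src_c]
    return out


if __name__ == "__main__":
    # 5x5 toy image with one bright pixel at row 2, column 2.
    toy = [[0] * 5 for _ in range(5)]
    toy[2][2] = 255
    moved = shift_image(toy, dx=1, dy=-2)
    # One column right, two rows up: the pixel should now be at row 0, column 3.
    print(moved[0][3])
```

Applying a few such shifts to held-out digits is a quick way to probe how well the model handles off-center inputs.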
---
## 📌 Model Summary
| Attribute | Value |
|---------|--------|
| Model Name | **TinyCNN** |
| Dataset | MNIST, shifted variant (28×28 grayscale digits) |
| Total Parameters | 94,410 (trainable) |
| Architecture | (Conv-BN-ReLU-MaxPool) ×3 → Global Avg Pool → FC |
| Input Shape | (1, 28, 28) |
| Output Classes | 10 |
| Framework | PyTorch |
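
The total of 94,410 trainable parameters can be checked by hand. The sketch below applies the standard per-layer formulas to the layer sizes listed in this README (the helper names are illustrative). BatchNorm contributes only its learnable scale and shift, because its running statistics are buffers rather than parameters. With a loaded `model`, `sum(p.numel() for p in model.parameters() if p.requires_grad)` should report the same number.

```python
# Parameter arithmetic for TinyCNN, assuming the layer sizes in this README:
# 3x3 convs with bias, affine BatchNorm, and a Linear 128 -> 10 head.


def conv2d_params(c_in, c_out, k=3, bias=True):
    # One k x k kernel per (input, output) channel pair, plus one bias per output channel.
    return c_in * c_out * k * k + (c_out if bias else 0)


def batchnorm_params(c):
    # Learnable gamma and beta per channel; running mean/var are not parameters.
    return 2 * c


def linear_params(n_in, n_out):
    # Weight matrix plus one bias per output unit.
    return n_in * n_out + n_out


layers = {
    "conv1": conv2d_params(1, 32),
    "bn1": batchnorm_params(32),
    "conv2": conv2d_params(32, 64),
    "bn2": batchnorm_params(64),
    "conv3": conv2d_params(64, 128),
    "bn3": batchnorm_params(128),
    "fc": linear_params(128, 10),
}

for name, n in layers.items():
    print(f"{name:>5}: {n:>6,}")

total = sum(layers.values())
print(f"total: {total:,}")
```

The breakdown makes the cost structure visible: the third convolution alone accounts for 73,856 of the parameters, while the classifier head needs only 1,290.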
---
## 🏗 Architecture Overview
- **Input:** 1×28×28
- **Conv Block 1:** Conv(1→32, 3×3) → BatchNorm → ReLU → MaxPool(2×2)
- **Conv Block 2:** Conv(32→64, 3×3) → BatchNorm → ReLU → MaxPool(2×2)
- **Conv Block 3:** Conv(64→128, 3×3) → BatchNorm → ReLU → MaxPool(2×2)
- **Global Average Pooling**
- **Fully Connected Layer** → 10 output classes

Global average pooling keeps the classifier head small: it maps a 128-dimensional vector to 10 classes with only 1,290 parameters, whereas flattening the final 128×3×3 feature map into the same layer would need 11,530. Roughly 78% of the model's parameters sit in the third convolution.

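
Because every 3×3 convolution in the model code uses `padding=1`, only the max-pooling layers change the spatial size. The sketch below traces a 28×28 input through the three blocks with the standard output-size formulas; it is plain arithmetic under those assumptions, not a call into PyTorch.

```python
# Trace the feature-map size of a 28x28 input through TinyCNN, assuming
# 3x3 convs with stride 1 and padding 1, and 2x2 max pooling with stride 2.


def conv_out(size, k=3, stride=1, pad=1):
    # Convolution output size: floor((n + 2p - k) / s) + 1
    return (size + 2 * pad - k) // stride + 1


def pool_out(size, k=2, stride=2):
    # MaxPool2d with padding=0 and ceil_mode=False (the defaults) floors the result.
    return (size - k) // stride + 1


size, channels = 28, 1
trace = [(channels, size)]
for c_out in (32, 64, 128):
    size = pool_out(conv_out(size))  # conv preserves size, pool halves it (floored)
    channels = c_out
    trace.append((channels, size))

for c, s in trace:
    print(f"{c:>3} x {s} x {s}")
```

The last pool receives a 7×7 map, so `MaxPool2d(2, 2)` drops its seventh row and column; global average pooling then averages each of the 128 channels over the remaining 3×3 positions.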
---
## ⚙️ Installation
```bash
pip install torch torchvision
```
## 🚀 Load Model from the Hub
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyCNN(nn.Module):
    """
    Tiny CNN for MNIST using Global Avg Pooling.
    Trainable parameters: 94,410
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # First conv block: 1x28x28 -> 32x14x14
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.pool1 = nn.MaxPool2d(2, 2)
        # Second conv block: 32x14x14 -> 64x7x7
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.pool2 = nn.MaxPool2d(2, 2)
        # Third conv block: 64x7x7 -> 128x3x3
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.pool3 = nn.MaxPool2d(2, 2)
        # Global average pooling
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # Final FC (input = 128 channels after GAP)
        self.fc = nn.Linear(128, num_classes)

    def forward(self, x):
        x = self.pool1(F.relu(self.bn1(self.conv1(x))))
        x = self.pool2(F.relu(self.bn2(self.conv2(x))))
        x = self.pool3(F.relu(self.bn3(self.conv3(x))))
        x = self.avgpool(x)        # (batch, 128, 1, 1)
        x = x.view(x.size(0), -1)  # (batch, 128)
        x = self.fc(x)             # (batch, num_classes)
        return x


model = TinyCNN(num_classes=10)
# Raw files on the Hugging Face Hub are served under /resolve/<revision>/<filename>.
state_dict = torch.hub.load_state_dict_from_url(
    "https://huggingface.co/FinOS-Internship/ShiftedTinyCNN/resolve/main/TinyCNN_model_acc_98.97.pth",
    map_location="cpu",  # load onto the CPU even if the checkpoint was saved from a GPU
)
model.load_state_dict(state_dict)
model.eval()  # inference mode: BatchNorm uses its running statistics
```
## 🖼 Example Inference
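```python
import torch
from torchvision import transforms
from PIL import Image

# Assumes `model` was loaded as in the previous section.
# MNIST digits are light strokes on a dark background; if your image is dark ink
# on a light background, invert it (e.g. x = 1.0 - x) before inference.
# Whether normalization was applied during training is not specified; if it was,
# add the same transforms.Normalize(...) step here.
transform = transforms.Compose([
    transforms.Grayscale(),       # force a single channel
    transforms.Resize((28, 28)),  # match the 28x28 input size
    transforms.ToTensor(),        # [0, 255] pixels -> [0.0, 1.0] floats, shape (1, 28, 28)
])

img = Image.open("digit.png")
x = transform(img).unsqueeze(0)  # shape: (1, 1, 28, 28)

with torch.no_grad():
    logits = model(x)
    pred = logits.argmax(dim=1).item()

print("Predicted digit:", pred)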
```
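
`argmax` alone hides how confident the model is. Applying softmax to the logits turns them into class probabilities; in PyTorch that is `torch.softmax(logits, dim=1)`. The dependency-free sketch below spells out the same arithmetic on a plain list, such as `logits[0].tolist()`, including the usual max-subtraction for numerical stability. The helper names and the logit values in the demo are made up for illustration and are not real model output.

```python
import math


def softmax(logits):
    # Subtract the largest logit before exponentiating to avoid overflow;
    # softmax is shift-invariant, so the probabilities are unchanged.
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k(logits, k=3):
    # (class_index, probability) pairs, highest probability first.
    probs = softmax(logits)
    ranked = sorted(enumerate(probs), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]


if __name__ == "__main__":
    # Made-up logits for illustration only.
    fake_logits = [0.1, -1.2, 0.3, 4.0, -0.5, 1.1, 0.0, -2.0, 0.7, 0.2]
    for digit, p in top_k(fake_logits):
        print(f"digit {digit}: {p:.3f}")
```

A low top-1 probability is a useful signal that the input may be off-distribution, for example inverted colors or a digit that was cropped badly.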