---
datasets:
- ylecun/mnist
language:
- en
metrics:
- accuracy
---
# TinyCNN for MNIST (94K params)
This repository contains a lightweight Convolutional Neural Network (CNN) designed for the MNIST handwritten digit classification task.
It was trained on a shifted variant of the MNIST handwritten digit dataset.
The model is optimized to be **small, fast, and easy to deploy**, suitable for both research and educational purposes.
---
## Model Summary
| Attribute | Value |
|-----------|-------|
| Model Name | **TinyCNN** |
| Dataset | MNIST (28×28 grayscale digits) |
| Total Parameters | ~94,410 |
| Architecture | Conv-BN-ReLU ×3 → Global Avg Pool → FC |
| Input Shape | (1, 28, 28) |
| Output Classes | 10 |
| Framework | PyTorch |
---
## Architecture Overview
**Input**: 1×28×28
**Conv Block 1:** Conv(1→32, 3×3) → BatchNorm → ReLU → MaxPool(2×2)
**Conv Block 2:** Conv(32→64, 3×3) → BatchNorm → ReLU → MaxPool(2×2)
**Conv Block 3:** Conv(64→128, 3×3) → BatchNorm → ReLU → MaxPool(2×2)
**Global Average Pooling**
**Fully Connected Layer** → 10 output classes
This architecture emphasizes parameter efficiency while maintaining strong representation capability.
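The ~94,410 figure in the summary table can be verified by hand from the layer shapes above (each 3×3 convolution uses a bias, and each BatchNorm layer contributes a learnable scale and shift per channel):

```python
# Verify TinyCNN's trainable parameter count from the layer shapes.
def conv_params(c_in, c_out, k=3):
    # 3x3 conv with bias: weights (c_in * c_out * k * k) + biases (c_out)
    return c_in * c_out * k * k + c_out

def bn_params(c):
    # BatchNorm2d: gamma + beta per channel (running stats are not trainable)
    return 2 * c

total = (
    conv_params(1, 32) + bn_params(32)       # Conv Block 1
    + conv_params(32, 64) + bn_params(64)    # Conv Block 2
    + conv_params(64, 128) + bn_params(128)  # Conv Block 3
    + 128 * 10 + 10                          # FC: 128 -> 10, with bias
)
print(total)  # 94410
```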
---
## Installation
```bash
pip install torch torchvision
```
## Load Model From Hub
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyCNN(nn.Module):
    """
    Tiny CNN for MNIST using Global Avg Pooling.
    Trainable parameters: 94,410
    """
    def __init__(self, num_classes=10):
        super(TinyCNN, self).__init__()
        # First conv block
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.pool1 = nn.MaxPool2d(2, 2)
        # Second conv block
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.pool2 = nn.MaxPool2d(2, 2)
        # Third conv block
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.pool3 = nn.MaxPool2d(2, 2)
        # Global average pooling
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # Final FC (input = 128 channels after GAP)
        self.fc = nn.Linear(128, num_classes)

    def forward(self, x):
        x = self.pool1(F.relu(self.bn1(self.conv1(x))))
        x = self.pool2(F.relu(self.bn2(self.conv2(x))))
        x = self.pool3(F.relu(self.bn3(self.conv3(x))))
        x = self.avgpool(x)        # (batch, 128, 1, 1)
        x = x.view(x.size(0), -1)  # (batch, 128)
        x = self.fc(x)             # (batch, num_classes)
        return x


model = TinyCNN(num_classes=10)
state_dict = torch.hub.load_state_dict_from_url(
    "https://huggingface.co/FinOS-Internship/ShiftedTinyCNN/TinyCNN_model_acc_98.97.pth",
    map_location="cpu",
)
model.load_state_dict(state_dict)
model.eval()
```
## Example Inference
```python
import torch
from torchvision import transforms
from PIL import Image

transform = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize((28, 28)),
    transforms.ToTensor(),
])

img = Image.open("digit.png")
x = transform(img).unsqueeze(0)  # shape: (1, 1, 28, 28)

with torch.no_grad():
    logits = model(x)
    pred = logits.argmax(dim=1).item()

print("Predicted digit:", pred)
```
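If class probabilities are needed rather than a hard prediction, the logits can be passed through a softmax (in PyTorch, `torch.softmax(logits, dim=1)`). For illustration, here is a minimal pure-Python softmax over a hypothetical logit vector (placeholder values, not real model output):

```python
import math

def softmax(logits):
    # Subtract the max for numerical stability before exponentiating.
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical logits for digits 0-9 (illustrative only).
logits = [0.1, -1.2, 0.3, 5.8, 0.0, -0.7, 1.1, 0.2, 2.4, -0.5]
probs = softmax(logits)
pred = max(range(10), key=lambda i: probs[i])
print("Predicted digit:", pred)  # Predicted digit: 3
```

The probabilities sum to 1, and the argmax of the probabilities always matches the argmax of the raw logits, so the hard prediction is unchanged.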