---
datasets:
- ylecun/mnist
language:
- en
metrics:
- accuracy
---

# 🧠 TinyCNN for MNIST (94K params)

This repository contains a lightweight Convolutional Neural Network (CNN) for the MNIST handwritten digit classification task.
The model was trained on a shifted version of the MNIST handwritten digit dataset.
It is designed to be **small, fast, and easy to deploy**, which makes it suitable for both research and educational purposes.

---

## Model Summary

| Attribute | Value |
|---------|--------|
| Model Name | **TinyCNN** |
| Dataset | MNIST (28×28 grayscale digits) |
| Total Parameters | 94,410 |
| Architecture | (Conv-BN-ReLU-MaxPool) ×3 → Global Avg Pool → FC |
| Input Shape | (1, 28, 28) |
| Output Classes | 10 |
| Framework | PyTorch |

---

## Architecture Overview

**Input**: 1×28×28

**Conv Block 1:** Conv(1→32, 3×3) → BatchNorm → ReLU → MaxPool(2×2)

**Conv Block 2:** Conv(32→64, 3×3) → BatchNorm → ReLU → MaxPool(2×2)

**Conv Block 3:** Conv(64→128, 3×3) → BatchNorm → ReLU → MaxPool(2×2)

**Global Average Pooling**

**Fully Connected Layer** → 10 output classes

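The layer listing above can be checked with a little arithmetic. The following is a minimal sketch in plain Python (no PyTorch needed) that traces the spatial size through the three blocks and adds up the learnable parameters. It assumes `padding=1` on every convolution and floor-rounding 2×2 max pooling, as in the model definition given in this README; the helper names are illustrative and not part of the repository.

```python
def conv2d_params(c_in, c_out, k):
    """Weights plus biases of a k x k convolution."""
    return c_in * c_out * k * k + c_out


def batchnorm_params(channels):
    """Learnable scale and shift, one pair per channel."""
    return 2 * channels


def linear_params(n_in, n_out):
    """Weights plus biases of a fully connected layer."""
    return n_in * n_out + n_out


channels = [1, 32, 64, 128]  # input channels, then the output of each conv block
size = 28                    # MNIST images are 28 x 28
sizes = [size]
total = 0

for c_in, c_out in zip(channels, channels[1:]):
    total += conv2d_params(c_in, c_out, 3) + batchnorm_params(c_out)
    # A 3x3 conv with padding=1 keeps the size; 2x2 max pooling halves it (floor).
    size //= 2
    sizes.append(size)

# Global average pooling has no parameters; the head maps 128 features to 10 classes.
total += linear_params(channels[-1], 10)

print(sizes)  # [28, 14, 7, 3]
print(total)  # 94410
```

The result matches the 94,410 trainable parameters reported for the model, and shows that the feature map is 3×3 before global average pooling reduces it to a single 128-dimensional vector.
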
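This architecture emphasizes parameter efficiency while maintaining strong representation capability: global average pooling keeps the classifier head at 1,290 parameters (128×10 weights plus 10 biases), so most of the capacity sits in the convolutional layers.

---
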
## Installation

```bash
pip install torch torchvision
```

## Load Model From Hub

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyCNN(nn.Module):
    """
    Tiny CNN for MNIST using Global Avg Pooling.

    Trainable parameters: 94,410
    """

    def __init__(self, num_classes=10):
        super().__init__()

        # First conv block
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.pool1 = nn.MaxPool2d(2, 2)

        # Second conv block
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.pool2 = nn.MaxPool2d(2, 2)

        # Third conv block
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.pool3 = nn.MaxPool2d(2, 2)

        # Global average pooling
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        # Final FC (input = 128 channels after GAP)
        self.fc = nn.Linear(128, num_classes)

    def forward(self, x):
        x = self.pool1(F.relu(self.bn1(self.conv1(x))))  # (batch, 32, 14, 14)
        x = self.pool2(F.relu(self.bn2(self.conv2(x))))  # (batch, 64, 7, 7)
        x = self.pool3(F.relu(self.bn3(self.conv3(x))))  # (batch, 128, 3, 3)
        x = self.avgpool(x)                              # (batch, 128, 1, 1)
        x = x.view(x.size(0), -1)                        # (batch, 128)
        x = self.fc(x)                                   # (batch, num_classes)
        return x


model = TinyCNN(num_classes=10)

# Download the trained weights; map_location="cpu" lets this run without a GPU.
# If the download fails, check the file URL on the model page: direct downloads
# from the Hub normally use a /resolve/<revision>/ path.
state_dict = torch.hub.load_state_dict_from_url(
    "https://huggingface.co/FinOS-Internship/ShiftedTinyCNN/TinyCNN_model_acc_98.97.pth",
    map_location="cpu",
)
model.load_state_dict(state_dict)
model.eval()  # use BatchNorm running statistics at inference time
```

## Example Inference

```python
import torch
from torchvision import transforms
from PIL import Image

# Convert any input image to a single-channel 28x28 tensor with values in [0, 1].
transform = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize((28, 28)),
    transforms.ToTensor()
])

img = Image.open("digit.png")
x = transform(img).unsqueeze(0)  # shape: (1, 1, 28, 28)

# `model` is the TinyCNN instance loaded in the previous section.
with torch.no_grad():
    logits = model(x)
    pred = logits.argmax(dim=1).item()

print("Predicted digit:", pred)
```

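
The inference example above takes the argmax of the raw logits. If a confidence score is also wanted, a softmax turns the logits into probabilities. The sketch below shows the computation in plain Python on a hypothetical logits vector (the values are made up for illustration and are not model output); with a real tensor, `torch.softmax(logits, dim=1)` should give the same result.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)  # subtract the max so exp() cannot overflow
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical logits for the 10 digit classes (illustrative values only).
example_logits = [-1.2, 0.3, 4.1, 0.0, -0.5, 1.7, -2.0, 0.9, 0.2, -0.8]

probs = softmax(example_logits)
pred = max(range(len(probs)), key=probs.__getitem__)  # index of the largest probability

print("Predicted digit:", pred)
print("Confidence:", round(probs[pred], 3))
```

Softmax preserves the ordering of the logits, so the predicted digit is the same as with a plain argmax; only the scale of the scores changes.
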