---
datasets:
- ylecun/mnist
language:
- en
metrics:
- accuracy
---

# 🧠 TinyCNN for MNIST (94K params)

This repository contains a lightweight Convolutional Neural Network (CNN) for the MNIST handwritten digit classification task. The model was trained on a shifted version of the MNIST handwritten digit dataset.

It is designed to be **small, fast, and easy to deploy**, which makes it suitable for both research and educational use.

---

## 📌 Model Summary

| Attribute | Value |
|---------|--------|
| Model Name | **TinyCNN** |
| Dataset | MNIST (28×28 grayscale digits) |
| Total Parameters | 94,410 (trainable) |
| Architecture | (Conv-BN-ReLU-MaxPool) ×3 → Global Avg Pool → FC |
| Input Shape | (1, 28, 28) |
| Output Classes | 10 |
| Framework | PyTorch |

---

## 🏗 Architecture Overview

**Input:** 1×28×28

**Conv Block 1:** Conv(1→32, 3×3) → BatchNorm → ReLU → MaxPool(2×2)

**Conv Block 2:** Conv(32→64, 3×3) → BatchNorm → ReLU → MaxPool(2×2)

**Conv Block 3:** Conv(64→128, 3×3) → BatchNorm → ReLU → MaxPool(2×2)

**Global Average Pooling**

**Fully Connected Layer** → 10 output classes

Global average pooling collapses each of the 128 feature maps to a single value, so the final layer needs only 128×10 + 10 = 1,290 parameters and nearly all of the capacity stays in the convolutional blocks.

---

## ⚙️ Installation

```bash
pip install torch torchvision
```

## 🚀 Load Model From Hub

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyCNN(nn.Module):
    """
    Tiny CNN for MNIST using Global Avg Pooling.
    Trainable parameters: 94,410
    """

    def __init__(self, num_classes=10):
        super().__init__()

        # First conv block
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.pool1 = nn.MaxPool2d(2, 2)

        # Second conv block
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.pool2 = nn.MaxPool2d(2, 2)

        # Third conv block
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.pool3 = nn.MaxPool2d(2, 2)

        # Global average pooling
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        # Final FC (input = 128 channels after GAP)
        self.fc = nn.Linear(128, num_classes)

    def forward(self, x):
        x = self.pool1(F.relu(self.bn1(self.conv1(x))))  # (batch, 32, 14, 14)
        x = self.pool2(F.relu(self.bn2(self.conv2(x))))  # (batch, 64, 7, 7)
        x = self.pool3(F.relu(self.bn3(self.conv3(x))))  # (batch, 128, 3, 3)
        x = self.avgpool(x)        # (batch, 128, 1, 1)
        x = x.view(x.size(0), -1)  # (batch, 128)
        x = self.fc(x)             # (batch, num_classes)
        return x


model = TinyCNN(num_classes=10)

# Note: direct file downloads on the Hugging Face Hub usually use a
# .../resolve/main/<file> path; adjust the URL if this one does not download.
state_dict = torch.hub.load_state_dict_from_url(
    "https://huggingface.co/FinOS-Internship/ShiftedTinyCNN/TinyCNN_model_acc_98.97.pth",
    map_location="cpu",  # load on CPU even if the checkpoint was saved from a GPU
)
model.load_state_dict(state_dict)
model.eval()  # use BatchNorm running statistics at inference time
```

## 🖼 Example Inference

```python
import torch
from torchvision import transforms
from PIL import Image

# `model` is the TinyCNN instance loaded in the previous section.
transform = transforms.Compose([
    transforms.Grayscale(),        # force a single channel
    transforms.Resize((28, 28)),   # match the MNIST input size
    transforms.ToTensor(),         # scale pixels to [0, 1], shape (1, 28, 28)
])

# MNIST digits are light strokes on a dark background; invert the image
# first if yours is dark-on-light.
img = Image.open("digit.png")
x = transform(img).unsqueeze(0)  # shape: (1, 1, 28, 28)

with torch.no_grad():
    logits = model(x)
    pred = logits.argmax(dim=1).item()

print("Predicted digit:", pred)
```
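
## 📊 Prediction Confidence

The inference example reports only the arg-max class. If a confidence score is also useful, one option is to pass the logits through a softmax first. The sketch below uses a hand-written `logits` tensor as a stand-in for `model(x)` so that it runs on its own; the values are made up for illustration and are not real model output.

```python
import torch

# Stand-in for `logits = model(x)`: one sample, 10 class scores (made-up values).
logits = torch.tensor([[0.1, 0.3, 0.2, 4.0, 0.0, 0.5, 0.1, 0.2, 0.3, 0.1]])

# Softmax over the class dimension turns logits into probabilities that sum to 1.
probs = torch.softmax(logits, dim=1)

# Highest probability and its class index for each sample in the batch.
confidence, pred = probs.max(dim=1)

print(f"Predicted digit: {pred.item()} (confidence {confidence.item():.2%})")
```

Softmax probabilities are not calibrated confidence estimates, so treat the score as a rough ranking signal rather than a true likelihood.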
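
## 🔢 Parameter Count Check

The parameter total quoted in the Model Summary can be reproduced by hand from the layer definitions. The sketch below is plain Python arithmetic, not part of the model code; the helper names (`conv2d_params`, `batchnorm_params`, `linear_params`) are introduced here for illustration only. It also traces the feature-map size through the three blocks.

```python
def conv2d_params(c_in, c_out, k):
    """Weights (c_in * c_out * k * k) plus one bias per output channel."""
    return c_in * c_out * k * k + c_out


def batchnorm_params(channels):
    """One learnable scale and one learnable shift per channel.

    Running mean/variance are buffers, not trainable parameters.
    """
    return 2 * channels


def linear_params(n_in, n_out):
    """Weight matrix plus one bias per output unit."""
    return n_in * n_out + n_out


channels = [1, 32, 64, 128]  # input channels, then each block's output channels
total = 0
size = 28  # input height and width

for c_in, c_out in zip(channels, channels[1:]):
    block = conv2d_params(c_in, c_out, 3) + batchnorm_params(c_out)
    total += block
    # A 3x3 conv with padding=1 keeps the size; the 2x2 max pool halves it (floor).
    size //= 2
    print(f"Conv({c_in}->{c_out}): {block:>6,} params, output {c_out}x{size}x{size}")

fc = linear_params(128, 10)
total += fc
print(f"FC(128->10): {fc:>6,} params")
print(f"Total: {total:,} params")
```

The total comes to 94,410, matching the figure in the class docstring, and the spatial size shrinks 28 → 14 → 7 → 3 before global average pooling.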