
🧠 TinyCNN for MNIST (94K params)

This repository contains a lightweight Convolutional Neural Network (CNN) designed for the MNIST handwritten digit classification task.
The model is designed to be small, fast, and easy to deploy, which makes it suitable for both research and educational use.


📌 Model Summary

| Attribute | Value |
|---|---|
| Model Name | TinyCNN |
| Dataset | MNIST (28×28 grayscale digits) |
| Total Parameters | ~94,410 |
| Architecture | (Conv-BN-ReLU-MaxPool) ×3 → Global Avg Pool → FC |
| Input Shape | (1, 28, 28) |
| Output Classes | 10 |
| Framework | PyTorch |
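As a sanity check, you can reproduce the ~94,410 figure in the table by hand from the layer list in the architecture overview below. The sketch assumes that every convolution has a bias term and that each BatchNorm layer contributes only its learnable scale and shift. Running mean and variance are buffers, not parameters. The helper names are illustrative only.

```python
# Reproduce the TinyCNN parameter count from its layer list.
# Assumption: every 3x3 Conv2d has a bias (PyTorch's default).

def conv2d_params(c_in, c_out, k=3, bias=True):
    # c_out filters of shape (c_in, k, k), plus one bias per output channel
    return c_out * c_in * k * k + (c_out if bias else 0)

def batchnorm_params(c):
    return 2 * c  # learnable weight (gamma) and bias (beta) per channel

def linear_params(n_in, n_out):
    return n_in * n_out + n_out

channels = [1, 32, 64, 128]
total = 0
for c_in, c_out in zip(channels, channels[1:]):
    total += conv2d_params(c_in, c_out) + batchnorm_params(c_out)

# Global average pooling leaves 128 features, so the FC head is 128 -> 10
total += linear_params(channels[-1], 10)

print(total)  # 94410
```

The third convolution alone accounts for 73,856 parameters. The three convolutions together hold 92,672, about 98% of the total.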

πŸ— Architecture Overview

Input: 1×28×28

Conv Block 1: Conv(1→32, 3×3) → BatchNorm → ReLU → MaxPool(2×2)

Conv Block 2: Conv(32→64, 3×3) → BatchNorm → ReLU → MaxPool(2×2)

Conv Block 3: Conv(64→128, 3×3) → BatchNorm → ReLU → MaxPool(2×2)

Global Average Pooling

Fully Connected Layer → 10 output classes

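Global average pooling replaces a flattened fully connected head, so the classifier needs only 1,290 parameters (128 × 10 weights + 10 biases). Nearly all of the ~94K parameters sit in the three convolutional layers.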
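This card does not reproduce the repository's `model.py`. The hypothetical `TinyCNNSketch` below shows one way to express the overview in PyTorch. It assumes `padding=1` on every 3×3 convolution and PyTorch's default conv biases. Its layer names are guesses, so it is for illustration only and is not guaranteed to load `tinycnn_mnist.pth`.

```python
import torch
import torch.nn as nn


class TinyCNNSketch(nn.Module):
    """Hypothetical re-implementation of the architecture overview.

    Assumptions (not confirmed by this repo): padding=1 on each 3x3 conv,
    default bias=True on each conv, and a single Linear classifier head.
    """

    def __init__(self, num_classes: int = 10):
        super().__init__()

        def block(c_in: int, c_out: int) -> nn.Sequential:
            # Conv -> BatchNorm -> ReLU -> 2x2 max-pool, as in the overview
            return nn.Sequential(
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(2),
            )

        self.features = nn.Sequential(
            block(1, 32),    # 28x28 -> 14x14
            block(32, 64),   # 14x14 -> 7x7
            block(64, 128),  # 7x7 -> 3x3 (max-pool floors odd sizes)
        )
        self.pool = nn.AdaptiveAvgPool2d(1)  # global average pooling
        self.fc = nn.Linear(128, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.features(x)          # (N, 128, 3, 3)
        x = self.pool(x).flatten(1)   # (N, 128)
        return self.fc(x)             # (N, num_classes) raw logits


model = TinyCNNSketch()
n_params = sum(p.numel() for p in model.parameters())
out = model(torch.randn(2, 1, 28, 28))
print(n_params, tuple(out.shape))  # 94410 (2, 10)
```

Under these assumptions the count is exactly 94,410, matching the summary table. Padding does not change the parameter count. It only changes the intermediate feature-map sizes, and global average pooling absorbs that difference before the FC layer.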


βš™οΈ Installation

pip install torch torchvision

🚀 Load Model From Hub

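```python
import torch
from model import TinyCNN  # Ensure model.py (defining TinyCNN) is included in your repo

model = TinyCNN(num_classes=10)

# map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only machine
state_dict = torch.hub.load_state_dict_from_url(
    "https://huggingface.co/<your-username>/<your-model-repo>/resolve/main/tinycnn_mnist.pth",
    map_location="cpu",
)
model.load_state_dict(state_dict)
model.eval()  # BatchNorm uses its running statistics at inference time
```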

🖼 Example Inference

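```python
import torch
from torchvision import transforms
from PIL import Image

# Preprocess to the model's expected input: 1x28x28, values in [0, 1].
# If training used extra steps (e.g. transforms.Normalize), apply them here too.
# MNIST digits are light strokes on a dark background, so dark-on-light
# images may need to be inverted first.
transform = transforms.Compose([
    transforms.Grayscale(),        # force a single channel
    transforms.Resize((28, 28)),
    transforms.ToTensor(),         # PIL image -> float tensor in [0, 1]
])

img = Image.open("digit.png")
x = transform(img).unsqueeze(0)  # shape: (1, 1, 28, 28)

# `model` is the TinyCNN loaded in the previous section (already in eval mode)
with torch.no_grad():
    logits = model(x)
    pred = logits.argmax(dim=1).item()

print("Predicted digit:", pred)
```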
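The example above reports only the argmax. To see how confident the model is, apply a softmax to the logits and inspect the top few classes. The helper below is a minimal sketch. `predict_topk` is a hypothetical name and is not part of this repository. It works with any model that maps `(N, 1, 28, 28)` inputs to `(N, 10)` logits.

```python
import torch
import torch.nn.functional as F


@torch.no_grad()
def predict_topk(model: torch.nn.Module, x: torch.Tensor, k: int = 3):
    """Return the top-k (digit, probability) pairs for each image in x.

    x: batch of shape (N, 1, 28, 28), preprocessed like the training data.
    Note: switches the model to eval mode as a side effect.
    """
    model.eval()
    probs = F.softmax(model(x), dim=1)       # logits -> probabilities per row
    top_p, top_idx = probs.topk(k, dim=1)    # both (N, k), sorted descending
    return [
        list(zip(idx.tolist(), p.tolist()))
        for idx, p in zip(top_idx, top_p)
    ]

# Usage with `model` and `x` from the example above:
# for digit, p in predict_topk(model, x)[0]:
#     print(f"{digit}: {p:.3f}")
```

If the top probability is low, or the top two classes are close, the input likely differs from the training distribution. For example, it might be dark ink on a white background.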