---
datasets:
- ylecun/mnist
---
# MNIST CNN Classifier

A simple convolutional neural network (CNN) for handwritten-digit classification, trained on the MNIST dataset.

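## Model Details

- Architecture: 2 convolutional layers (each followed by ReLU and 2×2 max-pooling) and 2 fully connected layers
- Accuracy: 99.4% on the MNIST test set
- Framework: PyTorch
- Input: 1×28×28 grayscale images; output: logits for the 10 digit classes
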
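## Usage

The checkpoint `mnist_cnn.pth` stores a `state_dict`, so the model class must be defined before the weights can be loaded:

```python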
import torch
from torch import nn


# Define the architecture
class CNN(nn.Module):
    """Small CNN for 10-class handwritten-digit classification.

    Two conv -> ReLU -> 2x2 max-pool stages followed by two fully connected
    layers. Expects single-channel 28x28 images: the 3x3 convolutions use
    padding=1 (size-preserving) and each max-pool halves the spatial size
    (28 -> 14 -> 7), which is where the ``64 * 7 * 7`` input size of ``fc1``
    comes from.
    """

    def __init__(self):
        super().__init__()
        # Attribute names define the state_dict keys, so they must match the
        # names used when mnist_cnn.pth was saved.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return raw (unnormalized) class logits of shape ``(batch, 10)``.

        Args:
            x: Image batch of shape ``(batch, 1, 28, 28)``.

        Raises:
            RuntimeError: If the input's spatial size does not reduce to 7x7,
                i.e. the images are not 28x28.
        """
        x = self.pool(self.relu(self.conv1(x)))
        x = self.pool(self.relu(self.conv2(x)))
        # Flatten per sample, keeping the batch dimension. The previous
        # ``view(-1, 64 * 7 * 7)`` silently re-batched wrongly sized inputs
        # (e.g. one 56x56 image became 4 "samples") and failed on
        # non-contiguous tensors; flatten avoids both.
        x = torch.flatten(x, 1)
        x = self.relu(self.fc1(x))
        return self.fc2(x)

# Load the trained weights.
# - map_location="cpu" lets a checkpoint that was saved from a GPU load on a
#   CPU-only machine; move the model with model.to(device) afterwards if needed.
# - weights_only=True restricts unpickling to tensors and plain containers, so
#   loading a downloaded checkpoint cannot execute arbitrary code.
model = CNN()
state_dict = torch.load("mnist_cnn.pth", map_location="cpu", weights_only=True)
model.load_state_dict(state_dict)
model.eval()  # switch to inference mode before predicting

# Example inference on a batch of 1x28x28 images. The tensor below is only a
# placeholder; real inputs should use the same preprocessing as training
# (not documented here — TODO confirm, e.g. any normalization applied).
images = torch.zeros(1, 1, 28, 28)
with torch.no_grad():
    logits = model(images)
    predictions = logits.argmax(dim=1)  # predicted digit per image
```

## Showcase

<!-- TODO: the showcase content is empty; add an image, e.g. sample digits with their predicted labels. -->