Spaces:
Sleeping
Sleeping
| """ | |
| Shifted MNIST CNN Model Architectures | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
class CNNModel(nn.Module):
    """
    CNN Model for MNIST digit classification with shifted labels.

    Architecture: Conv-BN-ReLU-Pool x3 + FC-Dropout x2 + FC
    Trainable parameters: 817,354 (with the default arguments)

    Args:
        num_classes: Number of output logits produced by ``fc3``.
        dropout_rate: Dropout probability applied after ``fc1`` and ``fc2``.
        in_channels: Number of channels in the input images (1 for MNIST).
        input_size: Height/width of the square input images (28 for MNIST).
            The length of the flattened feature vector fed to ``fc1`` is
            derived from it, so it must match the images given to ``forward``.

    Raises:
        ValueError: If ``input_size`` is smaller than 8, i.e. the feature maps
            would shrink to zero after the three 2x2 max-pools.
    """

    def __init__(self, num_classes=10, dropout_rate=0.5, in_channels=1, input_size=28):
        super().__init__()
        # Each conv (3x3, padding=1) keeps H/W and each 2x2/stride-2 max-pool
        # floor-halves it; three blocks therefore give input_size // 8
        # (28 -> 14 -> 7 -> 3 for MNIST).
        pooled_size = input_size // 8
        if pooled_size < 1:
            raise ValueError(
                f"input_size must be at least 8 to survive three 2x2 pools, got {input_size}"
            )

        # First convolutional block
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Second convolutional block
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Third convolutional block
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)

        # 128 channels * pooled_size * pooled_size (= 128 * 3 * 3 for 28x28 input)
        self.flattened_size = 128 * pooled_size * pooled_size

        # Fully connected layers with dropout
        self.fc1 = nn.Linear(self.flattened_size, 512)
        self.dropout1 = nn.Dropout(dropout_rate)
        self.fc2 = nn.Linear(512, 256)
        self.dropout2 = nn.Dropout(dropout_rate)
        self.fc3 = nn.Linear(256, num_classes)

    def forward(self, x):
        """Forward pass through the network.

        Args:
            x: Image batch of shape (batch, in_channels, input_size, input_size).

        Returns:
            Unnormalised logits of shape (batch, num_classes).
        """
        # First conv block: e.g. (1, 28, 28) -> (32, 14, 14)
        x = self.pool1(F.relu(self.bn1(self.conv1(x))))
        # Second conv block: e.g. (32, 14, 14) -> (64, 7, 7)
        x = self.pool2(F.relu(self.bn2(self.conv2(x))))
        # Third conv block: e.g. (64, 7, 7) -> (128, 3, 3)
        x = self.pool3(F.relu(self.bn3(self.conv3(x))))

        # Flatten for FC layers. torch.flatten (unlike .view) also handles
        # non-contiguous layouts such as channels_last, while keeping the same
        # logical (C, H, W) ordering that fc1 was trained on.
        x = torch.flatten(x, 1)

        # Fully connected layers with dropout
        x = self.dropout1(F.relu(self.fc1(x)))
        x = self.dropout2(F.relu(self.fc2(x)))
        return self.fc3(x)
class TinyCNN(nn.Module):
    """
    Compact MNIST classifier: three Conv-BN-ReLU-MaxPool stages, then global
    average pooling and a single linear layer (no hidden FC layers).

    Trainable parameters: 94,410
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # Stage 1: 1 -> 32 feature maps, spatial size floor-halved by the pool
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(num_features=32)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Stage 2: 32 -> 64 feature maps
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(num_features=64)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Stage 3: 64 -> 128 feature maps
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(num_features=128)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Reduce each of the 128 feature maps to one value, then classify
        self.avgpool = nn.AdaptiveAvgPool2d(output_size=(1, 1))
        self.fc = nn.Linear(in_features=128, out_features=num_classes)

    def forward(self, x):
        """Map an image batch to logits of shape (batch, num_classes)."""
        stages = (
            (self.conv1, self.bn1, self.pool1),
            (self.conv2, self.bn2, self.pool2),
            (self.conv3, self.bn3, self.pool3),
        )
        feats = x
        for conv, norm, pool in stages:
            feats = pool(F.relu(norm(conv(feats))))

        # (batch, 128, h, w) -> (batch, 128, 1, 1) -> (batch, 128)
        pooled = self.avgpool(feats)
        pooled = pooled.view(pooled.size(0), -1)
        return self.fc(pooled)
class MiniCNN(nn.Module):
    """
    Minimal MNIST classifier: two Conv-BN-ReLU-MaxPool stages, then global
    average pooling and a single linear layer.

    Trainable parameters: 19,658
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # Stage 1: 1 -> 32 feature maps, spatial size halved by the pool
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(num_features=32)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Stage 2: 32 -> 64 feature maps
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(num_features=64)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Reduce each of the 64 feature maps to one value, then classify
        self.avgpool = nn.AdaptiveAvgPool2d(output_size=(1, 1))
        self.fc = nn.Linear(in_features=64, out_features=num_classes)

    def forward(self, x):
        """Map an image batch to logits of shape (batch, num_classes)."""
        # For 28x28 input: (batch, 1, 28, 28) -> (batch, 32, 14, 14) -> (batch, 64, 7, 7)
        feats = x
        for conv, norm, pool in ((self.conv1, self.bn1, self.pool1),
                                 (self.conv2, self.bn2, self.pool2)):
            feats = pool(F.relu(norm(conv(feats))))

        # (batch, 64, h, w) -> (batch, 64, 1, 1) -> (batch, 64)
        pooled = self.avgpool(feats)
        pooled = pooled.view(pooled.size(0), -1)
        return self.fc(pooled)