"""
Shifted MNIST CNN Model Architectures
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


class CNNModel(nn.Module):
    """
    CNN Model for MNIST digit classification with shifted labels

    Architecture: Conv-BN-ReLU-Pool x3 + FC-Dropout x2 + FC
    Trainable parameters: 817,354 (with the default num_classes=10)

    The first fully connected layer is sized for a 128 x 3 x 3 feature map,
    which is what three 2x2 poolings produce from a 28 x 28 input; other
    spatial sizes will not match ``fc1``.

    Args:
        num_classes: Number of output logits.
        dropout_rate: Drop probability used by both dropout layers.
    """

    def __init__(self, num_classes: int = 10, dropout_rate: float = 0.5) -> None:
        super().__init__()

        # First convolutional block
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Second convolutional block
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Third convolutional block
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)

        # 128 channels x 3 x 3 spatial positions (28 -> 14 -> 7 -> 3)
        self.flattened_size = 128 * 3 * 3

        # Fully connected layers with dropout
        self.fc1 = nn.Linear(self.flattened_size, 512)
        self.dropout1 = nn.Dropout(dropout_rate)
        self.fc2 = nn.Linear(512, 256)
        self.dropout2 = nn.Dropout(dropout_rate)
        self.fc3 = nn.Linear(256, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass through the network

        Shape comments below assume a (batch, 1, 28, 28) input.

        Args:
            x: Input image batch.

        Returns:
            Raw (un-normalised) logits of shape (batch, num_classes).
        """
        # First conv block: (1, 28, 28) -> (32, 14, 14)
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.pool1(x)

        # Second conv block: (32, 14, 14) -> (64, 7, 7)
        x = F.relu(self.bn2(self.conv2(x)))
        x = self.pool2(x)

        # Third conv block: (64, 7, 7) -> (128, 3, 3)
        x = F.relu(self.bn3(self.conv3(x)))
        x = self.pool3(x)

        # Flatten for FC layers. torch.flatten (unlike .view) also works when
        # the activations are not contiguous, e.g. in channels_last format.
        x = torch.flatten(x, 1)

        # Fully connected layers with dropout
        x = F.relu(self.fc1(x))
        x = self.dropout1(x)
        x = F.relu(self.fc2(x))
        x = self.dropout2(x)
        x = self.fc3(x)

        return x


class TinyCNN(nn.Module):
    """
    Tiny CNN for MNIST using Global Avg Pooling

    Architecture: Conv-BN-ReLU-Pool x3 + Global Avg Pool + FC
    Trainable parameters: 94,410 (with the default num_classes=10)

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes: int = 10) -> None:
        super().__init__()

        # First conv block
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.pool1 = nn.MaxPool2d(2, 2)

        # Second conv block
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.pool2 = nn.MaxPool2d(2, 2)

        # Third conv block
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.pool3 = nn.MaxPool2d(2, 2)

        # Global average pooling
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        # Final FC (input = 128 channels after GAP)
        self.fc = nn.Linear(128, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass; returns logits of shape (batch, num_classes)."""
        x = self.pool1(F.relu(self.bn1(self.conv1(x))))
        x = self.pool2(F.relu(self.bn2(self.conv2(x))))
        x = self.pool3(F.relu(self.bn3(self.conv3(x))))

        x = self.avgpool(x)       # (batch, 128, 1, 1)
        x = torch.flatten(x, 1)   # (batch, 128)
        x = self.fc(x)            # (batch, num_classes)
        return x


class MiniCNN(nn.Module):
    """
    Mini CNN for MNIST using only 2 convolution layers + Global Avg Pooling

    Architecture: Conv-BN-ReLU-Pool x2 + Global Avg Pool + FC
    Trainable parameters: 19,658 (~19K, with the default num_classes=10)

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes: int = 10) -> None:
        super().__init__()

        # First conv block
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.pool1 = nn.MaxPool2d(2, 2)

        # Second conv block
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.pool2 = nn.MaxPool2d(2, 2)

        # Global Average Pooling
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        # Fully connected classifier
        self.fc = nn.Linear(64, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass; returns logits of shape (batch, num_classes).

        Shape comments below assume a (batch, 1, 28, 28) input.
        """
        x = self.pool1(F.relu(self.bn1(self.conv1(x))))  # (batch, 32, 14, 14)
        x = self.pool2(F.relu(self.bn2(self.conv2(x))))  # (batch, 64, 7, 7)

        x = self.avgpool(x)       # (batch, 64, 1, 1)
        x = torch.flatten(x, 1)   # (batch, 64)
        x = self.fc(x)            # (batch, num_classes)
        return x