""" Attack CNN Model Architectures """ import torch import torch.nn as nn import torch.nn.functional as F class StandardCNN(nn.Module): """ Standard CNN Model (Original) Architecture: 3 Conv blocks with BatchNorm + 3 FC layers Parameters: ~817K """ def __init__(self, num_classes=10, dropout_rate=0.5): super(StandardCNN, self).__init__() # First convolutional block self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, padding=1) self.bn1 = nn.BatchNorm2d(32) self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2) # Second convolutional block self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1) self.bn2 = nn.BatchNorm2d(64) self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2) # Third convolutional block self.conv3 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1) self.bn3 = nn.BatchNorm2d(128) self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2) # Calculate the flattened size after convolutions self.flattened_size = 128 * 3 * 3 # 28x28 -> 14x14 -> 7x7 -> 3x3 # Fully connected layers self.fc1 = nn.Linear(self.flattened_size, 512) self.dropout1 = nn.Dropout(dropout_rate) self.fc2 = nn.Linear(512, 256) self.dropout2 = nn.Dropout(dropout_rate) self.fc3 = nn.Linear(256, num_classes) def forward(self, x, return_logits=False): # Conv block 1 x = self.conv1(x) x = self.bn1(x) x = F.relu(x) x = self.pool1(x) # Conv block 2 x = self.conv2(x) x = self.bn2(x) x = F.relu(x) x = self.pool2(x) # Conv block 3 x = self.conv3(x) x = self.bn3(x) x = F.relu(x) x = self.pool3(x) # Flatten x = x.view(x.size(0), -1) # FC layers x = F.relu(self.fc1(x)) x = self.dropout1(x) x = F.relu(self.fc2(x)) x = self.dropout2(x) logits = self.fc3(x) if return_logits: return logits return F.softmax(logits, dim=1) class LighterCNN(nn.Module): """ Lighter CNN Model Architecture: 3 Conv blocks with fewer filters + Global Average Pooling Parameters: ~94K """ def __init__(self, num_classes=10, dropout_rate=0.5): super(LighterCNN, self).__init__() self.conv1 = nn.Conv2d(1, 32, 3, padding=1) self.bn1 = nn.BatchNorm2d(32) self.pool1 = nn.MaxPool2d(2,2) self.conv2 = nn.Conv2d(32, 64, 3, padding=1) self.bn2 = nn.BatchNorm2d(64) self.pool2 = nn.MaxPool2d(2,2) self.conv3 = nn.Conv2d(64, 128, 3, padding=1) self.bn3 = nn.BatchNorm2d(128) self.pool3 = nn.MaxPool2d(2,2) # 28->14->7->3 self.gap = nn.AdaptiveAvgPool2d(1) # (B,128,1,1) self.fc = nn.Linear(128, num_classes) def forward(self, x, return_logits=False): x = self.pool1(F.relu(self.bn1(self.conv1(x)))) x = self.pool2(F.relu(self.bn2(self.conv2(x)))) x = self.pool3(F.relu(self.bn3(self.conv3(x)))) x = self.gap(x).view(x.size(0), -1) # (B,128) logits = self.fc(x) return logits if return_logits else F.softmax(logits, dim=1) class DepthwiseSeparableConv(nn.Module): def __init__(self, in_ch, out_ch, stride=1): super(DepthwiseSeparableConv, self).__init__() self.dw = nn.Conv2d(in_ch, in_ch, 3, stride=stride, padding=1, groups=in_ch, bias=False) # depthwise self.pw = nn.Conv2d(in_ch, out_ch, 1, bias=False) # pointwise self.bn = nn.BatchNorm2d(out_ch) def forward(self, x): x = self.dw(x) x = self.pw(x) return F.relu(self.bn(x), inplace=True) class DepthwiseCNN(nn.Module): """ Depthwise Separable CNN Ultra-efficient using Depthwise Separable Convolutions Parameters: ~1.4K """ def __init__(self, num_classes=10, dropout_rate=0.5): super(DepthwiseCNN, self).__init__() # Stem: 1 -> 8, reduce size with stride=2 (28->14) self.stem = nn.Sequential( nn.Conv2d(1, 8, 3, stride=2, padding=1, bias=False), nn.BatchNorm2d(8), nn.ReLU(inplace=True), ) # DS blocks self.ds1 = DepthwiseSeparableConv(8, 16, stride=1) self.ds2 = DepthwiseSeparableConv(16, 32, stride=2) # 14->7 self.gap = nn.AdaptiveAvgPool2d(1) self.fc = nn.Linear(32, num_classes) def forward(self, x, return_logits=False): x = self.stem(x) # B, 8, 14, 14 x = self.ds1(x) # B,16,14,14 x = self.ds2(x) # B,32, 7, 7 x = self.gap(x).flatten(1) # B,32 logits = self.fc(x) # B,10 return logits if return_logits else F.softmax(logits, dim=1)