# models_attack.py — part of the "HF-Demo" Hugging Face Space (author: felix2703).
# Last change da4f171: "Fix model architectures to match trained checkpoints".
"""
Attack CNN Model Architectures
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
class StandardCNN(nn.Module):
    """
    Standard CNN Model (Original)
    Architecture: 3 Conv blocks with BatchNorm + 3 FC layers
    Parameters: ~817K
    """

    def __init__(self, num_classes=10, dropout_rate=0.5):
        super().__init__()
        # Conv stack shrinks (1, 28, 28) -> (32, 14, 14) -> (64, 7, 7) -> (128, 3, 3).
        # NOTE: attribute names are preserved exactly so state_dict keys still
        # match the trained checkpoints this architecture was fixed against.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        # Three floor-dividing 2x2 pools leave a 3x3 map with 128 channels.
        self.flattened_size = 128 * 3 * 3
        # Classifier head: 1152 -> 512 -> 256 -> num_classes, dropout between.
        self.fc1 = nn.Linear(self.flattened_size, 512)
        self.dropout1 = nn.Dropout(dropout_rate)
        self.fc2 = nn.Linear(512, 256)
        self.dropout2 = nn.Dropout(dropout_rate)
        self.fc3 = nn.Linear(256, num_classes)

    def forward(self, x, return_logits=False):
        """Run the network on a (B, 1, 28, 28) batch.

        Returns raw logits when `return_logits` is True, otherwise
        softmax probabilities over `num_classes`.
        """
        # Apply the three conv blocks: conv -> BN -> ReLU -> maxpool.
        blocks = (
            (self.conv1, self.bn1, self.pool1),
            (self.conv2, self.bn2, self.pool2),
            (self.conv3, self.bn3, self.pool3),
        )
        for conv, bn, pool in blocks:
            x = pool(F.relu(bn(conv(x))))
        # Flatten everything but the batch dimension for the FC head.
        x = torch.flatten(x, 1)
        x = self.dropout1(F.relu(self.fc1(x)))
        x = self.dropout2(F.relu(self.fc2(x)))
        logits = self.fc3(x)
        return logits if return_logits else F.softmax(logits, dim=1)
class LighterCNN(nn.Module):
    """
    Lighter CNN Model
    Architecture: 3 Conv blocks with fewer filters + Global Average Pooling
    Parameters: ~94K
    """

    def __init__(self, num_classes=10, dropout_rate=0.5):
        super().__init__()
        # `dropout_rate` is accepted for signature parity with the other
        # models but unused here: this variant has no dropout layers.
        # Spatial path: 28 -> 14 -> 7 -> 3 via three 2x2 max pools.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        # Global average pooling replaces a large flatten+FC stack.
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(128, num_classes)

    def forward(self, x, return_logits=False):
        """Return logits or softmax probabilities for a (B, 1, 28, 28) batch."""
        # Block 1: conv -> BN -> ReLU -> pool.
        x = self.conv1(x)
        x = self.pool1(F.relu(self.bn1(x)))
        # Block 2.
        x = self.conv2(x)
        x = self.pool2(F.relu(self.bn2(x)))
        # Block 3.
        x = self.conv3(x)
        x = self.pool3(F.relu(self.bn3(x)))
        # Collapse the 3x3 map to one value per channel -> (B, 128).
        pooled = self.gap(x).view(x.size(0), -1)
        logits = self.fc(pooled)
        if return_logits:
            return logits
        return F.softmax(logits, dim=1)
class DepthwiseSeparableConv(nn.Module):
    """Depthwise 3x3 conv -> pointwise 1x1 conv -> BatchNorm -> ReLU.

    Factorizes a standard conv into per-channel spatial filtering plus
    channel mixing, cutting the parameter count sharply.
    """

    def __init__(self, in_ch, out_ch, stride=1):
        super().__init__()
        # Depthwise: groups == in_ch gives one 3x3 filter per input channel.
        self.dw = nn.Conv2d(in_ch, in_ch, kernel_size=3, stride=stride,
                            padding=1, groups=in_ch, bias=False)
        # Pointwise 1x1 mixes channels; bias omitted since BN follows.
        self.pw = nn.Conv2d(in_ch, out_ch, kernel_size=1, bias=False)
        self.bn = nn.BatchNorm2d(out_ch)

    def forward(self, x):
        # Single fused expression: dw -> pw -> BN -> in-place ReLU.
        return F.relu(self.bn(self.pw(self.dw(x))), inplace=True)
class DepthwiseCNN(nn.Module):
    """
    Depthwise Separable CNN
    Ultra-efficient using Depthwise Separable Convolutions
    Parameters: ~1.4K
    """

    def __init__(self, num_classes=10, dropout_rate=0.5):
        super().__init__()
        # `dropout_rate` kept for signature parity with the other models;
        # this variant defines no dropout layers.
        # Stem: strided 3x3 conv halves 28x28 -> 14x14 and lifts 1 -> 8 channels.
        self.stem = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(8),
            nn.ReLU(inplace=True),
        )
        # Two depthwise-separable blocks; the second halves 14x14 -> 7x7.
        self.ds1 = DepthwiseSeparableConv(8, 16)
        self.ds2 = DepthwiseSeparableConv(16, 32, stride=2)
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(32, num_classes)

    def forward(self, x, return_logits=False):
        """Return logits or softmax probabilities for a (B, 1, 28, 28) batch."""
        # (B,1,28,28) -> stem (B,8,14,14) -> ds1 (B,16,14,14) -> ds2 (B,32,7,7)
        features = self.ds2(self.ds1(self.stem(x)))
        # Global average pool to (B, 32), then classify.
        pooled = self.gap(features).view(features.size(0), -1)
        logits = self.fc(pooled)
        if return_logits:
            return logits
        return F.softmax(logits, dim=1)