# HF-Demo / models_shifted.py
# Author: felix2703
# Commit da4f171: Fix model architectures to match trained checkpoints
"""
Shifted MNIST CNN Model Architectures
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
class CNNModel(nn.Module):
    """CNN for MNIST digit classification with shifted labels.

    Architecture: three Conv-BN-ReLU-MaxPool blocks (32, 64 and 128
    channels) followed by FC-Dropout x2 and a final FC layer that emits
    raw logits (no softmax is applied).

    Trainable parameters: 817,354 (with the default ``num_classes=10``).

    The first fully connected layer is sized for a 128 x 3 x 3 feature
    map, which is what the three pooling stages produce from a
    1 x 28 x 28 input.

    Args:
        num_classes: Number of output logits.
        dropout_rate: Drop probability shared by both dropout layers.
    """

    def __init__(self, num_classes: int = 10, dropout_rate: float = 0.5) -> None:
        super().__init__()

        # NOTE: attribute names become the state_dict keys, so they must
        # stay as they are for existing checkpoints to keep loading.

        # First convolutional block
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Second convolutional block
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Third convolutional block
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)

        # 128 channels on a 3 x 3 map after the third pooling stage.
        self.flattened_size = 128 * 3 * 3

        # Fully connected layers with dropout
        self.fc1 = nn.Linear(self.flattened_size, 512)
        self.dropout1 = nn.Dropout(dropout_rate)
        self.fc2 = nn.Linear(512, 256)
        self.dropout2 = nn.Dropout(dropout_rate)
        self.fc3 = nn.Linear(256, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the forward pass.

        Args:
            x: Image batch of shape ``(batch, 1, 28, 28)``.

        Returns:
            Logits of shape ``(batch, num_classes)``.
        """
        # First conv block: (1, 28, 28) -> (32, 14, 14)
        x = self.pool1(F.relu(self.bn1(self.conv1(x))))

        # Second conv block: (32, 14, 14) -> (64, 7, 7)
        x = self.pool2(F.relu(self.bn2(self.conv2(x))))

        # Third conv block: (64, 7, 7) -> (128, 3, 3)
        x = self.pool3(F.relu(self.bn3(self.conv3(x))))

        # Flatten everything but the batch dimension. ``flatten`` is used
        # rather than ``view(batch, -1)`` because the latter raises on an
        # empty batch (ambiguous -1) and on non-contiguous activations.
        x = torch.flatten(x, 1)

        # Fully connected layers with dropout
        x = self.dropout1(F.relu(self.fc1(x)))
        x = self.dropout2(F.relu(self.fc2(x)))
        return self.fc3(x)
class TinyCNN(nn.Module):
    """Tiny CNN for MNIST using global average pooling.

    Three Conv-BN-ReLU-MaxPool blocks (32, 64 and 128 channels), a global
    average pool that collapses each channel to a single value, and one
    linear classifier that emits raw logits.

    Trainable parameters: 94,410 (with the default ``num_classes=10``).

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes: int = 10) -> None:
        super().__init__()

        # NOTE: attribute names become the state_dict keys, so they must
        # stay as they are for existing checkpoints to keep loading.

        # First conv block
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.pool1 = nn.MaxPool2d(2, 2)

        # Second conv block
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.pool2 = nn.MaxPool2d(2, 2)

        # Third conv block
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.pool3 = nn.MaxPool2d(2, 2)

        # Global average pooling
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        # Final FC (input = 128 channels after GAP)
        self.fc = nn.Linear(128, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the forward pass.

        Args:
            x: Single-channel image batch of shape ``(batch, 1, H, W)``,
                e.g. 28 x 28 MNIST digits. The global average pool makes
                the classifier independent of the spatial size.

        Returns:
            Logits of shape ``(batch, num_classes)``.
        """
        x = self.pool1(F.relu(self.bn1(self.conv1(x))))
        x = self.pool2(F.relu(self.bn2(self.conv2(x))))
        x = self.pool3(F.relu(self.bn3(self.conv3(x))))

        x = self.avgpool(x)  # (batch, 128, 1, 1)

        # ``flatten`` instead of ``view(batch, -1)``: the latter raises on
        # an empty batch because the -1 dimension is ambiguous.
        x = torch.flatten(x, 1)  # (batch, 128)

        return self.fc(x)  # (batch, num_classes)
class MiniCNN(nn.Module):
    """Mini CNN for MNIST: two convolution blocks + global average pooling.

    Two Conv-BN-ReLU-MaxPool blocks (32 and 64 channels), a global average
    pool, and one linear classifier that emits raw logits.

    Trainable parameters: 19,658 (with the default ``num_classes=10``).

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes: int = 10) -> None:
        super().__init__()

        # NOTE: attribute names become the state_dict keys, so they must
        # stay as they are for existing checkpoints to keep loading.

        # First conv block
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.pool1 = nn.MaxPool2d(2, 2)

        # Second conv block
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.pool2 = nn.MaxPool2d(2, 2)

        # Global average pooling
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        # Fully connected classifier
        self.fc = nn.Linear(64, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the forward pass.

        Args:
            x: Single-channel image batch of shape ``(batch, 1, H, W)``;
                the shape comments below assume 28 x 28 inputs.

        Returns:
            Logits of shape ``(batch, num_classes)``.
        """
        x = self.pool1(F.relu(self.bn1(self.conv1(x))))  # (batch, 32, 14, 14)
        x = self.pool2(F.relu(self.bn2(self.conv2(x))))  # (batch, 64, 7, 7)

        x = self.avgpool(x)  # (batch, 64, 1, 1)

        # ``flatten`` instead of ``view(batch, -1)``: the latter raises on
        # an empty batch because the -1 dimension is ambiguous.
        x = torch.flatten(x, 1)  # (batch, 64)

        return self.fc(x)  # (batch, num_classes)