| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
class MNIST_CNN(nn.Module):
    """
    Enhanced CNN for MNIST classification.

    Matches the saved model architecture: conv1=16 filters, conv2=32 filters.
    Expects input of shape (batch, 1, 28, 28); two stride-2 max-pools reduce
    the 28x28 spatial map to 7x7 before the classifier head.
    """

    def __init__(self, num_classes=10, dropout_rate=0.2):
        """
        Build the network layers.

        Args:
            num_classes: Number of output classes (default 10 for MNIST digits).
            dropout_rate: Dropout probability applied before the final linear
                layer (default 0.2).
        """
        super().__init__()

        # Conv block 1: 1 -> 16 channels, 'same' padding keeps 28x28.
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(16)

        # Conv block 2: 16 -> 32 channels, 'same' padding keeps 14x14.
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(32)

        # Classifier head: 32 channels * 7 * 7 spatial after two 2x pools.
        self.fc1 = nn.Linear(32 * 7 * 7, 128)
        self.dropout = nn.Dropout(dropout_rate)
        self.fc2 = nn.Linear(128, num_classes)

        self._initialize_weights()

    def _initialize_weights(self):
        """Initialize weights using Xavier initialization.

        Conv and linear weights use Xavier-uniform; all biases are zeroed.
        BatchNorm scale/shift are reset to the identity (weight=1, bias=0).
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, x):
        """
        Compute class logits for a batch of images.

        Args:
            x: Float tensor of shape (batch, 1, 28, 28).

        Returns:
            Logits tensor of shape (batch, num_classes) — no softmax applied,
            suitable for use with nn.CrossEntropyLoss.
        """
        # Block 1: conv -> BN -> ReLU, then 2x2 max-pool (28x28 -> 14x14).
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.max_pool2d(x, 2)

        # Block 2: conv -> BN -> ReLU, then 2x2 max-pool (14x14 -> 7x7).
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.max_pool2d(x, 2)

        # Flatten per-sample feature maps for the fully connected head.
        x = x.view(x.size(0), -1)

        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)

        return x
|
|
| |
def create_mnist_cnn(num_classes=10, dropout_rate=0.2):
    """Factory function to create an MNIST CNN.

    Args:
        num_classes: Number of output classes (default 10).
        dropout_rate: Dropout probability for the classifier head (default 0.2).

    Returns:
        A freshly initialized MNIST_CNN instance.
    """
    return MNIST_CNN(num_classes=num_classes, dropout_rate=dropout_rate)
|
|
|
|
|
|