"""
CNN Model Architectures for MNIST Classification

This module provides CNN models for digit recognition:
- BaselineCNN: Simple CNN with 2 conv layers + 2 FC layers (target: 98-99% accuracy)
- ImprovedCNN: Deeper architecture with batch normalization
- Model utilities: parameter counting, architecture summary, smoke test

Usage:
    from scripts.models import BaselineCNN
    model = BaselineCNN()
    output = model(images)  # (batch, 10) logits
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple


class BaselineCNN(nn.Module):
    """
    Baseline CNN for MNIST classification.

    Architecture:
        Input: (batch, 1, 28, 28)
        Conv1: 1 -> 32 channels, 3x3 kernel, padding=1
        ReLU + MaxPool(2x2) -> (batch, 32, 14, 14)
        Conv2: 32 -> 64 channels, 3x3 kernel, padding=1
        ReLU + MaxPool(2x2) -> (batch, 64, 7, 7)
        Flatten -> (batch, 3136)
        FC1: 3136 -> 128, ReLU, Dropout(0.5)
        FC2: 128 -> 10 (output logits)

    Design Rationale:
        - 2 conv layers: Balance between simplicity and capacity
        - 32->64 filters: Standard progression, proven effective
        - Dropout 0.5: Prevent overfitting
        - No batch norm: Keep baseline simple

    Expected Performance:
        - Parameters: 421,642 (~422k; about 95% of them are in FC1)
        - Test accuracy: 98-99%
        - Training time: ~5-10 min on GPU
    """

    def __init__(self, dropout_rate: float = 0.5):
        """
        Initialize baseline CNN.

        Args:
            dropout_rate: Dropout probability applied after FC1 (default 0.5)
        """
        super().__init__()

        # Convolutional layers (padding=1 keeps the spatial size before pooling)
        self.conv1 = nn.Conv2d(
            in_channels=1,
            out_channels=32,
            kernel_size=3,
            padding=1
        )
        self.conv2 = nn.Conv2d(
            in_channels=32,
            out_channels=64,
            kernel_size=3,
            padding=1
        )

        # Pooling layer (shared by both conv blocks; it has no parameters)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Fully connected layers
        # After two pooling layers: 28->14->7, so 64*7*7 = 3136
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

        # Dropout for regularization (only active in training mode)
        self.dropout = nn.Dropout(p=dropout_rate)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass.

        Args:
            x: Input tensor of shape (batch, 1, 28, 28)

        Returns:
            Output logits of shape (batch, 10)

        Raises:
            RuntimeError: If the input's spatial size does not pool down to
                7x7 (e.g. 56x56), FC1 rejects the flattened features.
        """
        # Conv block 1: Conv -> ReLU -> Pool
        x = self.conv1(x)  # (batch, 32, 28, 28)
        x = F.relu(x)
        x = self.pool(x)   # (batch, 32, 14, 14)

        # Conv block 2: Conv -> ReLU -> Pool
        x = self.conv2(x)  # (batch, 64, 14, 14)
        x = F.relu(x)
        x = self.pool(x)   # (batch, 64, 7, 7)

        # Flatten everything except the batch dimension. Unlike
        # view(-1, 3136), this never folds a wrong-sized input into a
        # different batch size; a mismatch surfaces as an error in fc1.
        x = torch.flatten(x, start_dim=1)  # (batch, 3136)

        # Fully connected layers
        x = self.fc1(x)  # (batch, 128)
        x = F.relu(x)
        x = self.dropout(x)
        x = self.fc2(x)  # (batch, 10)

        return x


class ImprovedCNN(nn.Module):
    """
    Enhanced CNN with batch normalization and deeper architecture.

    Architecture:
        Input: (batch, 1, 28, 28)
        Conv1: 1 -> 32, BatchNorm, ReLU, MaxPool -> (batch, 32, 14, 14)
        Conv2: 32 -> 64, BatchNorm, ReLU, MaxPool -> (batch, 64, 7, 7)
        Conv3: 64 -> 128, BatchNorm, ReLU, MaxPool -> (batch, 128, 3, 3)
        Flatten -> (batch, 1152)
        FC1: 128*3*3 -> 256, BatchNorm, ReLU, Dropout(0.5)
        FC2: 256 -> 10

    Note:
        In training mode, the BatchNorm1d after FC1 needs more than one
        sample per batch; PyTorch raises an error for a batch of 1.

    Expected Performance:
        - Parameters: 391,370 (~391k)
        - Test accuracy: 99%+
        - Converges faster than baseline
    """

    def __init__(self, dropout_rate: float = 0.5):
        """
        Initialize improved CNN.

        Args:
            dropout_rate: Dropout probability applied after FC1 (default 0.5)
        """
        super().__init__()

        # Convolutional layers with batch normalization
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)

        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)

        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)

        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Fully connected layers
        # After three pooling layers: 28->14->7->3 (the odd 7 is floored,
        # dropping the last row/column), so 128*3*3 = 1152
        self.fc1 = nn.Linear(128 * 3 * 3, 256)
        self.bn_fc = nn.BatchNorm1d(256)
        self.fc2 = nn.Linear(256, 10)

        self.dropout = nn.Dropout(p=dropout_rate)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass.

        Args:
            x: Input tensor of shape (batch, 1, 28, 28)

        Returns:
            Output logits of shape (batch, 10)

        Raises:
            RuntimeError: If the input's spatial size does not pool down to
                3x3, FC1 rejects the flattened features.
        """
        # Conv block 1
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.pool(x)  # (batch, 32, 14, 14)

        # Conv block 2
        x = self.conv2(x)
        x = self.bn2(x)
        x = F.relu(x)
        x = self.pool(x)  # (batch, 64, 7, 7)

        # Conv block 3
        x = self.conv3(x)
        x = self.bn3(x)
        x = F.relu(x)
        x = self.pool(x)  # (batch, 128, 3, 3)

        # Flatten everything except the batch dimension (see BaselineCNN:
        # avoids silently reshaping wrong-sized inputs into a new batch size)
        x = torch.flatten(x, start_dim=1)  # (batch, 1152)

        # Fully connected layers
        x = self.fc1(x)
        x = self.bn_fc(x)
        x = F.relu(x)
        x = self.dropout(x)
        x = self.fc2(x)

        return x


def count_parameters(model: nn.Module) -> Tuple[int, int]:
    """
    Count total and trainable parameters in model.

    Buffers (e.g. BatchNorm running statistics) are not parameters and are
    not counted.

    Args:
        model: PyTorch model

    Returns:
        Tuple of (total_params, trainable_params), where trainable means
        ``requires_grad`` is True.
    """
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return total_params, trainable_params


def get_model_summary(
    model: nn.Module,
    input_size: Tuple[int, ...] = (1, 1, 28, 28)
) -> str:
    """
    Generate a short text summary of a model.

    Args:
        model: PyTorch model
        input_size: Input tensor size (batch, channels, height, width). It is
            only reported in the summary; no forward pass is run.

    Returns:
        Formatted multi-line string with the class name, input size,
        parameter counts and parameter memory in MB.
    """
    total_params, trainable_params = count_parameters(model)

    # Use each parameter's real element size (4 bytes for float32, 2 for
    # float16, ...) rather than assuming float32. Buffers are not included.
    param_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
    model_size_mb = param_bytes / (1024 ** 2)

    separator = "=" * 60
    summary = [
        separator,
        f"Model: {model.__class__.__name__}",
        separator,
        f"Input size: {input_size}",
        f"Total parameters: {total_params:,}",
        f"Trainable parameters: {trainable_params:,}",
        f"Model size (MB): {model_size_mb:.2f}",
        separator,
    ]
    return "\n".join(summary)


def test_model(model: nn.Module, device: str = 'cpu') -> bool:
    """
    Smoke-test a model with a random (4, 1, 28, 28) input.

    Side effects: the model is moved to ``device`` in place (``nn.Module.to``
    mutates the module). Its train/eval mode is restored before returning.
    Results are printed to stdout.

    Args:
        model: PyTorch model
        device: Device to run on ('cpu' or 'cuda')

    Returns:
        True if the forward pass produced a finite (4, 10) output,
        False otherwise (the failure reason is printed).
    """
    was_training = model.training
    try:
        model = model.to(device)
        model.eval()

        # Create dummy input directly on the target device
        dummy_input = torch.randn(4, 1, 28, 28, device=device)

        # Forward pass
        with torch.no_grad():
            output = model(dummy_input)

        # Explicit checks instead of `assert`, which is stripped under `python -O`
        if output.shape != (4, 10):
            raise ValueError(f"Expected shape (4, 10), got {output.shape}")
        if not torch.isfinite(output).all():
            raise ValueError("Output contains NaN or Inf")

        print("✓ Model test passed")
        print(f"  Input shape: {dummy_input.shape}")
        print(f"  Output shape: {output.shape}")
        print(f"  Output range: [{output.min():.4f}, {output.max():.4f}]")

        return True

    except Exception as e:  # deliberate: any failure is reported, not raised
        print(f"✗ Model test failed: {e}")
        return False

    finally:
        # Don't leave the caller's model stuck in eval mode
        model.train(was_training)


# This is a smoke-test helper, not a pytest test: stop pytest from collecting it
# (and failing on the missing `model` fixture) when a test module imports it.
test_model.__test__ = False


if __name__ == "__main__":
    # Test model instantiation and forward pass for both architectures.
    print("Testing BaselineCNN:")
    print()

    # Create model
    model = BaselineCNN()
    print(get_model_summary(model))
    print()

    # Test forward pass
    test_model(model)
    print()

    # Test improved model
    print("=" * 60)
    print("Testing ImprovedCNN:")
    print()

    model_improved = ImprovedCNN()
    print(get_model_summary(model_improved))
    print()

    test_model(model_improved)