import torch import torch.nn as nn from torchvision import models from torchvision.models import ResNet18_Weights class SimpleCNN(nn.Module): """ A minimalist CNN model as a baseline. Consists of two convolutional layers followed by a fully connected layer. """ def __init__(self, num_classes=6): super(SimpleCNN, self).__init__() # First Convolutional Block: Takes 3 channels (RGB) as input and # outputs 16 self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1) self.relu1 = nn.ReLU() self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2) # 224 -> 112 # Second Convolutional Block self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1) self.relu2 = nn.ReLU() self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2) # 112 -> 56 # Adaptive Pooling ensures the output is always 7x7, regardless of # input size self.adaptive_pool = nn.AdaptiveAvgPool2d((7, 7)) # Classification Layer self.fc = nn.Linear(32 * 7 * 7, num_classes) def forward(self, x): """ Defines the forward pass of the data through the network. """ x = self.pool1(self.relu1(self.conv1(x))) x = self.pool2(self.relu2(self.conv2(x))) x = self.adaptive_pool(x) x = torch.flatten(x, 1) # Flatten for the linear layer x = self.fc(x) return x class DeepCNN(nn.Module): """ A deeper CNN model with Batch Normalization and Dropout for regularization. Better suited for more complex image features. """ def __init__(self, num_classes=6): super(DeepCNN, self).__init__() # Block 1 self.layer1 = nn.Sequential( nn.Conv2d(3, 32, kernel_size=3, padding=1), nn.BatchNorm2d(32), nn.ReLU(), nn.MaxPool2d(kernel_size=2, stride=2), # 112 ) # Block 2 self.layer2 = nn.Sequential( nn.Conv2d(32, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(), nn.MaxPool2d(kernel_size=2, stride=2), # 56 ) # Block 3 self.layer3 = nn.Sequential( nn.Conv2d(64, 128, kernel_size=3, padding=1), nn.BatchNorm2d(128), nn.ReLU(), nn.MaxPool2d(kernel_size=2, stride=2), # 28 ) self.adaptive_pool = nn.AdaptiveAvgPool2d((7, 7)) # Classifier with Dropout to prevent overfitting self.classifier = nn.Sequential( nn.Linear(128 * 7 * 7, 512), nn.ReLU(), nn.Dropout(0.5), nn.Linear(512, num_classes) ) def forward(self, x): """ Forward pass through the sequential layers. """ x = self.layer1(x) x = self.layer2(x) x = self.layer3(x) x = self.adaptive_pool(x) x = torch.flatten(x, 1) x = self.classifier(x) return x class ResNet18Transfer(nn.Module): """ Transfer Learning model based on ResNet18. Allows loading pretrained weights and freezing the backbone. """ def __init__(self, num_classes=6, pretrained=True, freeze_backbone=False): super(ResNet18Transfer, self).__init__() # Load the ResNet18 model weights = ResNet18_Weights.DEFAULT if pretrained else None self.backbone = models.resnet18(weights=weights) # Freeze the backbone if requested if freeze_backbone: for param in self.backbone.parameters(): param.requires_grad = False # Adjust the final fully connected layer (fc) # ResNet18 fc has 512 input features by default in_features = self.backbone.fc.in_features self.backbone.fc = nn.Linear(in_features, num_classes) def forward(self, x): """ Uses the ResNet backbone for feature extraction and classification. """ return self.backbone(x)