Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| from torchvision import models | |
| from torchvision.models import ResNet18_Weights | |
class SimpleCNN(nn.Module):
    """Baseline two-stage convolutional classifier.

    Each stage is conv (3x3, padding 1) -> ReLU -> 2x2 max-pool, which keeps
    the spatial size through the conv and then halves it in the pool (a
    224x224 input becomes 112x112, then 56x56). An adaptive average pool then
    fixes the feature map at 7x7, so the linear head accepts any input
    resolution.

    Args:
        num_classes: Number of logits produced by ``forward``.
    """

    def __init__(self, num_classes=6):
        super().__init__()
        in_channels, hidden_channels, out_channels = 3, 16, 32
        grid = 7

        # Stage 1: 3 input channels (RGB) -> 16 feature maps, spatial size halved.
        self.conv1 = nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Stage 2: 16 -> 32 feature maps, spatial size halved again.
        self.conv2 = nn.Conv2d(hidden_channels, out_channels, kernel_size=3, padding=1)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Pin the feature map to grid x grid regardless of the input resolution.
        self.adaptive_pool = nn.AdaptiveAvgPool2d((grid, grid))

        # Linear head over the flattened 32 * 7 * 7 features.
        self.fc = nn.Linear(out_channels * grid * grid, num_classes)

    def forward(self, x):
        """Map a batch of images to class logits.

        Args:
            x: Batched image tensor; presumably (batch, 3, H, W) given conv1's
                three input channels and the flatten from dim 1.

        Returns:
            Logits of shape (batch, num_classes).
        """
        for conv, act, pool in (
            (self.conv1, self.relu1, self.pool1),
            (self.conv2, self.relu2, self.pool2),
        ):
            x = pool(act(conv(x)))
        pooled = self.adaptive_pool(x)
        return self.fc(pooled.flatten(start_dim=1))
class DeepCNN(nn.Module):
    """Three-stage CNN regularised with Batch Normalization and Dropout.

    Every stage is conv (3x3, padding 1) -> BatchNorm -> ReLU -> 2x2 max-pool,
    widening channels 3 -> 32 -> 64 -> 128 while halving the spatial size
    (224 -> 112 -> 56 -> 28 for a 224x224 input). Features are then pooled to
    7x7 and classified by a two-layer MLP with Dropout(0.5) between the layers.

    Args:
        num_classes: Number of logits produced by ``forward``.
    """

    def __init__(self, num_classes=6):
        super().__init__()
        self.layer1 = self._conv_stage(3, 32)  # 224 -> 112
        self.layer2 = self._conv_stage(32, 64)  # 112 -> 56
        self.layer3 = self._conv_stage(64, 128)  # 56 -> 28
        self.adaptive_pool = nn.AdaptiveAvgPool2d((7, 7))

        # Dropout between the two linear layers to curb overfitting.
        hidden = 512
        self.classifier = nn.Sequential(
            nn.Linear(128 * 7 * 7, hidden),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(hidden, num_classes),
        )

    @staticmethod
    def _conv_stage(in_channels, out_channels):
        """Build one conv -> BN -> ReLU -> max-pool stage (downsamples H, W by 2)."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

    def forward(self, x):
        """Run the three conv stages, pool to 7x7, and classify.

        Args:
            x: Batched image tensor; presumably (batch, 3, H, W) given the
                first conv's three input channels.

        Returns:
            Logits of shape (batch, num_classes).
        """
        for stage in (self.layer1, self.layer2, self.layer3):
            x = stage(x)
        flat = torch.flatten(self.adaptive_pool(x), start_dim=1)
        return self.classifier(flat)
class ResNet18Transfer(nn.Module):
    """ResNet18 classifier for transfer learning.

    Wraps ``torchvision.models.resnet18`` and swaps its final ``fc`` layer for
    a fresh ``nn.Linear`` sized to ``num_classes``.

    Args:
        num_classes: Number of logits produced by the new head.
        pretrained: If True, load ``ResNet18_Weights.DEFAULT``; otherwise the
            backbone is randomly initialised.
        freeze_backbone: If True, disable gradients for every backbone
            parameter. The replacement ``fc`` head is created after the
            freeze, so it remains trainable.

    Note:
        Freezing only sets ``requires_grad=False``. It does not switch the
        backbone's BatchNorm layers to eval mode, so their running statistics
        still update whenever the module is in training mode.
    """

    def __init__(self, num_classes=6, pretrained=True, freeze_backbone=False):
        super().__init__()
        chosen_weights = ResNet18_Weights.DEFAULT if pretrained else None
        self.backbone = models.resnet18(weights=chosen_weights)

        if freeze_backbone:
            # Module.requires_grad_ sets requires_grad on every parameter in place.
            self.backbone.requires_grad_(False)

        # Replace the stock head; its input width is read from the existing
        # layer (512 for ResNet18) rather than hard-coded.
        head_in = self.backbone.fc.in_features
        self.backbone.fc = nn.Linear(head_in, num_classes)

    def forward(self, x):
        """Run the full ResNet18 (features plus the new head) and return its logits."""
        logits = self.backbone(x)
        return logits