File size: 4,081 Bytes
0b86da8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import torch
import torch.nn as nn
from torchvision import models
from torchvision.models import ResNet18_Weights


class SimpleCNN(nn.Module):
    """
    A minimalist CNN model as a baseline.

    Two conv(3x3) -> ReLU -> 2x2 max-pool stages, an adaptive average pool
    to a fixed 7x7 grid, and a single fully connected classification layer.

    Args:
        num_classes: Number of output logits.
        in_channels: Number of channels in the input images (3 for RGB,
            1 for grayscale, ...). Defaults to 3, matching the original model.
    """

    def __init__(self, num_classes=6, in_channels=3):
        super().__init__()
        # First convolutional block: in_channels -> 16 feature maps.
        # A 3x3 kernel with padding=1 keeps H and W; the pool halves them.
        self.conv1 = nn.Conv2d(in_channels, 16, kernel_size=3, padding=1)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)  # e.g. 224 -> 112

        # Second convolutional block: 16 -> 32 feature maps, halves H and W again.
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)  # e.g. 112 -> 56

        # Adaptive pooling fixes the spatial size at 7x7 whatever the input
        # resolution, so the linear layer always sees the same number of
        # features. (Inputs must still be at least 4x4 so that both stride-2
        # pools leave a non-empty feature map.)
        self.adaptive_pool = nn.AdaptiveAvgPool2d((7, 7))

        # Classification layer: 32 channels * 7 * 7 spatial positions.
        self.fc = nn.Linear(32 * 7 * 7, num_classes)

    def forward(self, x):
        """
        Run the forward pass.

        Args:
            x: Batched image tensor; expected shape (batch, in_channels, H, W).
               The channel count must match ``in_channels``.

        Returns:
            Logits of shape (batch, num_classes).
        """
        x = self.pool1(self.relu1(self.conv1(x)))
        x = self.pool2(self.relu2(self.conv2(x)))
        x = self.adaptive_pool(x)
        x = torch.flatten(x, 1)  # keep the batch dim, flatten C*7*7 for the linear layer
        x = self.fc(x)
        return x


class DeepCNN(nn.Module):
    """
    A deeper CNN model with Batch Normalization and Dropout for regularization.

    Three conv(3x3) -> BatchNorm -> ReLU -> 2x2 max-pool stages with 32, 64
    and 128 channels, an adaptive average pool to 7x7, and an MLP classifier
    with dropout. Better suited for more complex image features.

    Args:
        num_classes: Number of output logits.
        in_channels: Number of channels in the input images. Defaults to 3
            (RGB), matching the original model.
    """

    def __init__(self, num_classes=6, in_channels=3):
        super().__init__()

        # Each block halves H and W (e.g. 224 -> 112 -> 56 -> 28), so inputs
        # must be at least 8x8 to leave a non-empty feature map.
        self.layer1 = self._conv_block(in_channels, 32)
        self.layer2 = self._conv_block(32, 64)
        self.layer3 = self._conv_block(64, 128)

        # Fix the spatial size at 7x7 so the classifier input size is constant.
        self.adaptive_pool = nn.AdaptiveAvgPool2d((7, 7))

        # Classifier with Dropout to prevent overfitting. The Sequential
        # indices (0 and 3 are the Linear layers) appear in state_dict keys,
        # so keep this layout stable for checkpoint compatibility.
        self.classifier = nn.Sequential(
            nn.Linear(128 * 7 * 7, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes),
        )

    @staticmethod
    def _conv_block(in_ch, out_ch):
        """Build one conv(3x3, padding=1) -> BatchNorm -> ReLU -> 2x2 max-pool stage."""
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

    def forward(self, x):
        """
        Forward pass through the three conv blocks and the classifier.

        Args:
            x: Batched image tensor; expected shape (batch, in_channels, H, W).

        Returns:
            Logits of shape (batch, num_classes).
        """
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.adaptive_pool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x


class ResNet18Transfer(nn.Module):
    """
    Transfer-learning classifier built on torchvision's ResNet18.

    A stock ResNet18 is created, optionally with its default pretrained
    weights, and optionally frozen. Its final ``fc`` layer is then replaced
    by a fresh linear layer producing ``num_classes`` logits.

    Args:
        num_classes: Number of output logits of the new head.
        pretrained: If True, load ``ResNet18_Weights.DEFAULT``; otherwise the
            backbone is randomly initialised.
        freeze_backbone: If True, set ``requires_grad=False`` on every
            parameter of the original ResNet18. The replacement ``fc`` head
            is created afterwards, so it always stays trainable.
    """

    def __init__(self, num_classes=6, pretrained=True, freeze_backbone=False):
        super().__init__()

        chosen_weights = ResNet18_Weights.DEFAULT if pretrained else None
        net = models.resnet18(weights=chosen_weights)

        if freeze_backbone:
            # Module.requires_grad_ flips the flag on every registered parameter.
            net.requires_grad_(False)
            # NOTE(review): this only disables gradients; BatchNorm layers still
            # update their running statistics while the model is in train()
            # mode. Confirm whether that is intended.

        # Swap in a new classification head (ResNet18's fc takes 512 features).
        net.fc = nn.Linear(net.fc.in_features, num_classes)
        self.backbone = net

    def forward(self, x):
        """Return the class logits produced by the adapted ResNet18 for batch ``x``."""
        logits = self.backbone(x)
        return logits