File size: 1,081 Bytes
01ce719
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
from typing import Optional

import torch.nn as nn

from config import IMAGE_SIZE


class SimpleCNN(nn.Module):
    """Small two-stage convolutional image classifier.

    Architecture: two ``Conv2d -> ReLU -> MaxPool2d(2)`` stages, followed by
    ``Flatten -> Linear -> ReLU -> Dropout -> Linear``. The last layer has no
    activation, so ``forward`` returns raw (un-normalized) class scores.

    Args:
        num_classes: Number of output scores per sample.
        conv1_channels: Output channels of the first convolution.
        conv2_channels: Output channels of the second convolution.
        kernel_size: Square kernel size used by both convolutions.
        dropout: Dropout probability applied before the final ``Linear``.
        fc_dim: Width of the hidden fully-connected layer.
        in_channels: Channels of the input images (3 by default).
        image_size: Side length of the square input images. When ``None``,
            ``config.IMAGE_SIZE`` is used.

    Raises:
        ValueError: If ``image_size`` is too small for both conv/pool stages
            to yield a non-empty feature map.
    """

    def __init__(
        self,
        num_classes: int,
        conv1_channels: int = 16,
        conv2_channels: int = 32,
        kernel_size: int = 3,
        dropout: float = 0.2,
        fc_dim: int = 128,
        in_channels: int = 3,
        image_size: Optional[int] = None,
    ):
        super().__init__()

        if image_size is None:
            image_size = IMAGE_SIZE

        # "Same"-style padding: keeps the spatial size for odd kernels only.
        # An even kernel makes each conv output one pixel larger than its
        # input, which the flattened-size computation below accounts for.
        padding = kernel_size // 2

        self.features = nn.Sequential(
            nn.Conv2d(in_channels, conv1_channels, kernel_size=kernel_size, padding=padding),
            nn.ReLU(),
            nn.MaxPool2d(2),

            nn.Conv2d(conv1_channels, conv2_channels, kernel_size=kernel_size, padding=padding),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )

        # Follow the real conv/pool size arithmetic for both stages instead of
        # assuming image_size // 4, which is wrong for even kernel sizes.
        side = image_size
        for _ in range(2):
            side = self._conv_pool_output_size(side, kernel_size, padding)
        flattened_dim = conv2_channels * side * side

        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(flattened_dim, fc_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(fc_dim, num_classes),
        )

    @staticmethod
    def _conv_pool_output_size(size: int, kernel_size: int, padding: int) -> int:
        """Return the side length after one stride-1 Conv2d and a MaxPool2d(2)."""
        conv_out = size + 2 * padding - kernel_size + 1
        if conv_out < 2:
            raise ValueError(
                f"input size {size} is too small for kernel_size={kernel_size}: "
                "the feature map would vanish before pooling"
            )
        # MaxPool2d(2) uses stride 2 and no padding, so odd sizes are floored.
        return conv_out // 2

    def forward(self, x):
        """Return class scores for the batch ``x``.

        ``x`` is expected in ``(N, in_channels, image_size, image_size)``
        layout; other spatial sizes will not match the first ``Linear`` layer.
        """
        return self.classifier(self.features(x))