import torch
import torch.nn as nn

from config import IMAGE_SIZE


class SimpleCNN(nn.Module):
    """Small two-stage convolutional image classifier.

    Architecture: two ``Conv2d -> ReLU -> MaxPool2d(2)`` stages, followed by a
    ``Flatten -> Linear -> ReLU -> Dropout -> Linear`` head. ``forward``
    returns raw class scores; no softmax is applied.

    Args:
        num_classes: Number of outputs of the final linear layer.
        conv1_channels: Output channels of the first convolution.
        conv2_channels: Output channels of the second convolution.
        kernel_size: Square kernel size used by both convolutions. Padding is
            ``kernel_size // 2``, which preserves spatial size for odd kernels
            and grows it by one pixel per convolution for even kernels.
        dropout: Dropout probability applied before the final linear layer.
        fc_dim: Width of the hidden fully-connected layer.
        in_channels: Number of channels of the input images (default 3).
        image_size: Height/width of the square input images the model is
            built for; defaults to ``config.IMAGE_SIZE``.

    Raises:
        ValueError: If the chosen ``image_size``/``kernel_size`` would shrink
            the feature map to zero spatial size.
    """

    def __init__(
        self,
        num_classes: int,
        conv1_channels: int = 16,
        conv2_channels: int = 32,
        kernel_size: int = 3,
        dropout: float = 0.2,
        fc_dim: int = 128,
        in_channels: int = 3,
        image_size: int = IMAGE_SIZE,
    ):
        super().__init__()
        padding = kernel_size // 2
        # NOTE: layer order is unchanged so existing state_dict keys
        # ("features.0.weight", "classifier.1.weight", ...) still load.
        self.features = nn.Sequential(
            nn.Conv2d(in_channels, conv1_channels, kernel_size=kernel_size, padding=padding),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(conv1_channels, conv2_channels, kernel_size=kernel_size, padding=padding),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )

        # Derive the flattened size from the layer arithmetic instead of
        # assuming ``image_size // 4``: that shortcut is wrong for even kernels,
        # where padding = k // 2 makes each conv output one pixel larger.
        spatial = image_size
        for _ in range(2):  # two conv + pool stages above
            spatial = self._conv_pool_out(spatial, kernel_size, padding)
        if spatial < 1:
            raise ValueError(
                f"image_size={image_size} with kernel_size={kernel_size} "
                "reduces the feature map to zero spatial size"
            )
        flattened_dim = conv2_channels * spatial * spatial

        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(flattened_dim, fc_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(fc_dim, num_classes),
        )

    @staticmethod
    def _conv_pool_out(size: int, kernel_size: int, padding: int) -> int:
        """Return the side length after one Conv2d(stride 1) + MaxPool2d(2) stage."""
        conv_out = size + 2 * padding - kernel_size + 1  # stride 1, dilation 1
        return conv_out // 2  # MaxPool2d(2): kernel 2, stride 2, floor mode

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return unnormalized class scores for a batch of images.

        Presumably ``x`` is ``(batch, in_channels, image_size, image_size)``;
        other spatial sizes will not match the first classifier Linear layer.
        """
        return self.classifier(self.features(x))