Spaces:
Running on Zero
Running on Zero
from typing import Optional

import torch.nn as nn

from config import IMAGE_SIZE
class SimpleCNN(nn.Module):
    """Small two-stage convolutional image classifier.

    Layout: ``[Conv2d -> ReLU -> MaxPool2d(2)] x 2`` as ``self.features``,
    then ``Flatten -> Linear -> ReLU -> Dropout -> Linear`` as
    ``self.classifier``, producing ``num_classes`` logits per sample.

    Args:
        num_classes: Number of output units of the final linear layer.
        conv1_channels: Output channels of the first convolution.
        conv2_channels: Output channels of the second convolution.
        kernel_size: Side length of both square convolution kernels.
        dropout: Dropout probability applied before the final linear layer.
        fc_dim: Width of the hidden fully connected layer.
        in_channels: Number of channels in the input images (default 3).
        image_size: Height and width of the square input images. ``None``
            (the default) uses ``config.IMAGE_SIZE``.

    Raises:
        ValueError: If ``image_size`` is too small to leave at least one
            pixel after both pooling stages.
    """

    def __init__(
        self,
        num_classes: int,
        conv1_channels: int = 16,
        conv2_channels: int = 32,
        kernel_size: int = 3,
        dropout: float = 0.2,
        fc_dim: int = 128,
        in_channels: int = 3,
        image_size: Optional[int] = None,
    ):
        super().__init__()
        if image_size is None:
            image_size = IMAGE_SIZE

        # Preserves spatial size for odd kernels; an even kernel makes each
        # convolution output one pixel larger than its input.
        padding = kernel_size // 2
        self.features = nn.Sequential(
            nn.Conv2d(in_channels, conv1_channels, kernel_size=kernel_size, padding=padding),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(conv1_channels, conv2_channels, kernel_size=kernel_size, padding=padding),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )

        # Track the spatial size through both conv+pool stages exactly rather
        # than assuming each stage halves it (wrong for even kernel sizes).
        spatial = image_size
        for _ in range(2):
            spatial = self._conv_pool_size(spatial, kernel_size, padding)
        if spatial < 1:
            raise ValueError(
                f"image_size={image_size} is too small for two conv+pool stages "
                f"with kernel_size={kernel_size}"
            )
        flattened_dim = conv2_channels * spatial * spatial

        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(flattened_dim, fc_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(fc_dim, num_classes),
        )

    @staticmethod
    def _conv_pool_size(size: int, kernel_size: int, padding: int) -> int:
        """Return the side length after a stride-1 Conv2d followed by MaxPool2d(2)."""
        conv_out = size + 2 * padding - kernel_size + 1
        return conv_out // 2

    def forward(self, x):
        """Compute class logits for a batch of images.

        Args:
            x: Image batch laid out as ``(N, in_channels, H, W)``, as required
                by ``nn.Conv2d``. ``H`` and ``W`` must equal ``image_size``,
                because the first linear layer's input width is derived from it.

        Returns:
            Logits of shape ``(N, num_classes)``.
        """
        return self.classifier(self.features(x))