Spaces:
Running on Zero
Running on Zero
| import torch.nn as nn | |
| from torchvision import models | |
class SimpleCNN(nn.Module):
    """Image classifier built on a torchvision ResNet-18 backbone.

    Despite its name, this model is not a small hand-written CNN: it wraps
    ``torchvision.models.resnet18`` and replaces the backbone's final fully
    connected layer with a two-layer MLP head
    (Dropout -> Linear -> ReLU -> Dropout -> Linear).

    Args:
        num_classes: Number of output logits produced by the head.
        conv1_channels: Unused. Accepted only so existing call sites that
            pass it keep working; the ResNet-18 backbone fixes its own
            layer widths.
        conv2_channels: Unused; see ``conv1_channels``.
        kernel_size: Unused; see ``conv1_channels``.
        dropout: Dropout probability applied before each Linear layer of
            the classification head.
        fc_dim: Width of the hidden layer in the classification head.
        pretrained: Keyword-only. If True (the default, matching the
            previous behavior), initialize the backbone with
            ``ResNet18_Weights.DEFAULT``; torchvision may download these on
            first use. If False, the backbone is randomly initialized and no
            weights are fetched.
    """

    def __init__(
        self,
        num_classes: int,
        conv1_channels: int = 16,
        conv2_channels: int = 32,
        kernel_size: int = 3,
        dropout: float = 0.2,
        fc_dim: int = 128,
        *,
        pretrained: bool = True,
    ):
        super().__init__()
        # NOTE(review): conv1_channels / conv2_channels / kernel_size are kept
        # for interface compatibility only; they do not affect the architecture.
        weights = models.ResNet18_Weights.DEFAULT if pretrained else None
        self.backbone = models.resnet18(weights=weights)
        # Read the backbone's feature width instead of hard-coding it.
        in_features = self.backbone.fc.in_features
        # Swap the original classifier for a small MLP head ending in
        # `num_classes` logits.
        self.backbone.fc = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(in_features, fc_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(fc_dim, num_classes),
        )

    def forward(self, x):
        """Return class logits for a batch of images.

        Args:
            x: Input tensor passed unchanged to the ResNet-18 backbone
                (presumably shaped (batch, 3, H, W) — TODO confirm against
                callers).

        Returns:
            The backbone output, whose last layer is
            ``nn.Linear(fc_dim, num_classes)``, i.e. ``num_classes`` logits
            per sample.
        """
        return self.backbone(x)