import torch.nn as nn
from torchvision import models


class SimpleCNN(nn.Module):
    """Image classifier built on a torchvision ResNet-18 backbone.

    The backbone's final ``fc`` layer is replaced by a small head:
    ``Dropout -> Linear(in_features, fc_dim) -> ReLU -> Dropout ->
    Linear(fc_dim, num_classes)``.

    Args:
        num_classes: Output size of the final linear layer.
        conv1_channels: Accepted for interface compatibility; not used by
            this implementation.
        conv2_channels: Accepted for interface compatibility; not used by
            this implementation.
        kernel_size: Accepted for interface compatibility; not used by
            this implementation.
        dropout: Probability passed to both ``nn.Dropout`` layers in the
            head.
        fc_dim: Width of the hidden linear layer in the head.
        pretrained: When ``True`` (the default, matching the previous
            behaviour) the backbone is created with
            ``ResNet18_Weights.DEFAULT``. When ``False`` the backbone is
            created with ``weights=None``, i.e. no pretrained weights are
            requested; useful when a checkpoint is loaded right afterwards
            or when no pretrained weights are wanted.
    """

    def __init__(
        self,
        num_classes: int,
        conv1_channels: int = 16,
        conv2_channels: int = 32,
        kernel_size: int = 3,
        dropout: float = 0.2,
        fc_dim: int = 128,
        *,
        pretrained: bool = True,
    ) -> None:
        super().__init__()

        # NOTE: conv1_channels, conv2_channels and kernel_size are kept in
        # the signature so existing callers keep working, but the ResNet-18
        # backbone below does not consume them.
        weights = models.ResNet18_Weights.DEFAULT if pretrained else None
        self.backbone = models.resnet18(weights=weights)

        # Swap the stock classification layer for a custom head. The head
        # stays a positional nn.Sequential so state_dict keys remain
        # "backbone.fc.<index>.*".
        in_features = self.backbone.fc.in_features
        self.backbone.fc = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(in_features, fc_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(fc_dim, num_classes),
        )

    def forward(self, x):
        """Run ``x`` through the backbone (including the custom head).

        Args:
            x: Input passed straight to the ResNet-18 backbone
                (presumably a batch of 3-channel images -- confirm against
                the data pipeline).

        Returns:
            Whatever the backbone returns for ``x``; the last layer applied
            is ``Linear(fc_dim, num_classes)``.
        """
        return self.backbone(x)