"""Model definition for SimpleCNN, a small convolutional image classifier."""

import torch
import torch.nn as nn
from transformers import PreTrainedModel

from .configuration_simplecnn import SimpleCNNConfig


class SimpleCNN(PreTrainedModel):
    """Two-stage convolutional classifier wrapped as a Hugging Face model.

    Architecture: two blocks of (3x3 conv with padding 1 -> ReLU -> 2x2
    max-pool), each halving the spatial size. These are followed by a fully
    connected head: flatten -> 64 hidden units -> ReLU -> ``num_classes``
    outputs.

    Config attributes read:
        input_channels: number of channels in the input images.
        num_classes: number of output logits.
        image_size (optional): input height/width, either an int (square
            images) or a ``(height, width)`` pair. If missing or ``None``,
            defaults to 28. That reproduces the original hard-coded head
            input size of ``32 * 7 * 7``.
    """

    config_class = SimpleCNNConfig

    # Input size assumed when the config does not provide ``image_size``.
    # 28 -> 14 -> 7 after the two pooling stages, matching the original head.
    _DEFAULT_IMAGE_SIZE = 28

    def __init__(self, config):
        super().__init__(config)
        # Submodule order and indices are unchanged, so existing checkpoints
        # still load. Their state_dict keys look like "conv_layers.0.weight"
        # and "fc_layers.1.weight".
        self.conv_layers = nn.Sequential(
            nn.Conv2d(config.input_channels, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        pooled_height, pooled_width = self._pooled_size(config)
        self.fc_layers = nn.Sequential(
            nn.Flatten(),
            # 32 channels leave the last conv; the spatial size comes from
            # the configured input size after two halvings.
            nn.Linear(32 * pooled_height * pooled_width, 64),
            nn.ReLU(),
            nn.Linear(64, config.num_classes),
        )

    @classmethod
    def _pooled_size(cls, config):
        """Return the (height, width) of the feature map after both pools.

        Raises:
            ValueError: if the configured input is smaller than 4 pixels in
                either dimension. Both pools would then reduce it to zero,
                giving a zero-width linear layer.
        """
        size = getattr(config, "image_size", None)
        if size is None:
            size = cls._DEFAULT_IMAGE_SIZE
        if isinstance(size, int):
            height = width = size
        else:
            height, width = size
        # The 3x3 convs with padding=1 preserve spatial size. Each
        # MaxPool2d(2) floor-halves it, so two pools give floor(size / 4).
        pooled = (height // 4, width // 4)
        if min(pooled) < 1:
            raise ValueError(
                f"image_size {size!r} is too small: SimpleCNN needs inputs "
                "of at least 4x4 pixels to survive two 2x2 pooling stages"
            )
        return pooled

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute class logits for a batch of images.

        Args:
            x: image batch, presumably shaped
                (batch, input_channels, height, width) with height/width
                matching the configured image size. TODO: confirm with callers.

        Returns:
            Raw (unnormalized) logits from the final linear layer, shaped
            (batch, num_classes); no softmax is applied.
        """
        features = self.conv_layers(x)
        return self.fc_layers(features)