import torch
import torch.nn as nn


class SimplifiedAlexNet(nn.Module):
    """A small AlexNet-style CNN: four conv layers followed by a 3-layer MLP head.

    The feature extractor uses 3x3 convolutions with padding 1 (spatial size
    preserved) and three 2x2 / stride-2 max-pools (each halves H and W, rounding
    down). The result is adaptively average-pooled to a fixed 4x4 grid before
    the classifier. For 32x32 inputs the feature map is already 4x4, so the
    pooling is an exact identity and the network behaves as a plain
    conv -> flatten -> MLP model. Other input resolutions also work, as long as
    H and W are at least 8 so the last max-pool has a non-empty output.

    Args:
        num_classes: Number of output logits.
        in_channels: Number of channels in the input images (3 for RGB).
        dropout: Drop probability for both dropout layers in the classifier.
    """

    def __init__(
        self,
        num_classes: int = 10,
        *,
        in_channels: int = 3,
        dropout: float = 0.5,
    ) -> None:
        super().__init__()
        # NOTE: the layer order/indices inside each nn.Sequential are kept
        # unchanged so existing state_dict checkpoints still load.
        self.features = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        # Fix the spatial size fed to the classifier at 4x4 so that
        # nn.Linear(128 * 4 * 4, ...) is valid for any input resolution, not
        # only 32x32. It has no parameters, so the state_dict is unaffected.
        self.avgpool = nn.AdaptiveAvgPool2d((4, 4))
        self.classifier = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(128 * 4 * 4, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(512, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute class logits.

        Args:
            x: Batched images of shape (N, in_channels, H, W). The batch
                dimension is required because flattening starts at dim 1.

        Returns:
            Logits of shape (N, num_classes).
        """
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        return self.classifier(x)