import torch
from torch import nn


class ALexNet(nn.Module):
    """AlexNet-style convolutional classifier with batch normalisation.

    The network consists of five 3x3 convolutional blocks (each followed by
    ``BatchNorm2d`` and ``ReLU``; blocks 1, 2 and 5 additionally end with a
    2x2 max-pool), followed by two 1024-unit fully connected layers with
    dropout and a final linear classifier.

    Args:
        input_shape: Number of channels of the input images.
        hidden_units: Accepted for interface compatibility only; it is not
            used by this architecture (the layer widths are fixed).
        output_shape: Number of output classes (size of the logits vector).
        input_size: Height and width (square) of the input images. It is used
            to size the first fully connected layer. Defaults to 32.

    Raises:
        ValueError: If ``input_size`` is smaller than 8, which would leave no
            spatial extent after the three 2x2 max-pool layers.
    """

    def __init__(
        self,
        input_shape: int,
        hidden_units: int,
        output_shape,
        input_size: int = 32,
    ):
        super().__init__()
        if input_size < 8:
            raise ValueError(
                f"input_size must be at least 8 to survive three 2x2 "
                f"poolings, got {input_size}"
            )

        self.block1 = nn.Sequential(
            nn.Conv2d(input_shape, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )
        self.block2 = nn.Sequential(
            nn.Conv2d(64, 192, kernel_size=3, padding=1),
            nn.BatchNorm2d(192),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )
        self.block3 = nn.Sequential(
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.BatchNorm2d(384),
            nn.ReLU(),
        )
        self.block4 = nn.Sequential(
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
        )
        self.block5 = nn.Sequential(
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )

        self.flattened_size = self._infer_flattened_size(input_shape, input_size)

        self.flatten = nn.Flatten()
        self.fc1 = nn.Sequential(
            nn.Linear(in_features=self.flattened_size, out_features=1024),
            nn.ReLU(),
            nn.Dropout(0.5),
        )
        self.fc2 = nn.Sequential(
            nn.Linear(1024, 1024),
            nn.ReLU(),
            nn.Dropout(0.5),
        )
        self.classifier = nn.Sequential(
            nn.Linear(1024, output_shape),
        )

    def _features(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the five convolutional blocks in order."""
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        x = self.block5(x)
        return x

    def _infer_flattened_size(self, in_channels: int, input_size: int) -> int:
        """Return the per-sample feature count produced by the conv blocks.

        The probe runs in eval mode: in training mode BatchNorm updates its
        running statistics on every forward pass (``torch.no_grad`` does not
        prevent that), which would leave a freshly built model with buffers
        polluted by the all-zero dummy input.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                dummy = torch.zeros(1, in_channels, input_size, input_size)
                # Batch size is 1, so the element count is the flattened size.
                return self._features(dummy).numel()
        finally:
            self.train(was_training)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute class logits for a batch of images.

        Args:
            x: Batch of images whose channel count and spatial size match the
                ``input_shape`` and ``input_size`` given at construction.

        Returns:
            Tensor of raw logits with ``output_shape`` values per sample.
        """
        x = self._features(x)
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.classifier(x)
        return x