Spaces:
Sleeping
Sleeping
| import torch | |
| from torch import nn | |
class ALexNet(nn.Module):
    """AlexNet-style CNN (with BatchNorm) whose classifier head is sized
    automatically from the input image resolution via a dummy forward pass.

    Args:
        input_shape: number of input channels (e.g. 3 for RGB).
        hidden_units: currently unused -- kept for interface compatibility.
            NOTE(review): consider wiring it into the conv widths or
            removing it at the call sites.
        output_shape: number of output classes.
        image_size: height/width of the (square) input images. Defaults to
            32 (e.g. CIFAR-10), matching the original hard-coded dummy input.
    """

    def __init__(self, input_shape: int, hidden_units: int,
                 output_shape: int, image_size: int = 32) -> None:
        super().__init__()
        # Five convolutional stages following the AlexNet layout, with
        # BatchNorm inserted after each conv. Pooling after stages 1, 2, 5
        # halves the spatial size three times overall.
        self.block1 = nn.Sequential(
            nn.Conv2d(input_shape, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )
        self.block2 = nn.Sequential(
            nn.Conv2d(64, 192, kernel_size=3, padding=1),
            nn.BatchNorm2d(192),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )
        self.block3 = nn.Sequential(
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.BatchNorm2d(384),
            nn.ReLU(),
        )
        self.block4 = nn.Sequential(
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
        )
        self.block5 = nn.Sequential(
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )
        # Run a dummy input through the conv stack to discover the flattened
        # feature count, so the first Linear layer adapts to any image_size
        # (the original hard-coded 32x32 despite a comment suggesting 224).
        with torch.no_grad():
            x = torch.zeros(1, input_shape, image_size, image_size)
            for block in (self.block1, self.block2, self.block3,
                          self.block4, self.block5):
                x = block(x)
            self.flattened_size = x.view(1, -1).shape[1]
        self.flatten = nn.Flatten()
        self.fc1 = nn.Sequential(
            nn.Linear(in_features=self.flattened_size, out_features=1024),
            nn.ReLU(),
            nn.Dropout(0.5),
        )
        self.fc2 = nn.Sequential(
            nn.Linear(1024, 1024),
            nn.ReLU(),
            nn.Dropout(0.5),
        )
        self.classifier = nn.Sequential(
            nn.Linear(1024, output_shape),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits of shape (batch, output_shape)."""
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        x = self.block5(x)
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.classifier(x)
        return x