| import torch.nn as nn | |
class AlexNet(nn.Module):
    """AlexNet-style classifier built from 3-D convolutions.

    The network takes single-channel volumes (the first convolution has one
    input channel) and produces ``num_classes`` raw scores per sample.

    The first fully connected layer expects ``256 * 6 * 6 * 6`` features, so
    the input volume must be sized such that ``features`` emits a
    ``256 x 6 x 6 x 6`` map per sample (by the layer arithmetic, for example
    224 voxels per side).

    Args:
        num_classes: Number of output scores produced by the final layer.
    """

    def __init__(self, num_classes=3):
        super(AlexNet, self).__init__()
        # Convolutional feature extractor: five conv layers, three max-pools.
        self.features = nn.Sequential(
            nn.Conv3d(1, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool3d(kernel_size=3, stride=2),
            nn.Conv3d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool3d(kernel_size=3, stride=2),
            nn.Conv3d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv3d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv3d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool3d(kernel_size=3, stride=2),
        )
        # Fully connected head operating on the flattened feature map.
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(256 * 6 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes),
        )
        self.reset_parameters()

    def reset_parameters(self):
        """Re-draw every parameter (weights and biases) from U(-0.1, 0.1).

        This replaces the default initialisation of all registered
        parameters of the module.
        """
        for param in self.parameters():
            # nn.init runs under no_grad, so there is no need to reach
            # through ``.data`` to keep autograd out of the in-place fill.
            nn.init.uniform_(param, -0.1, 0.1)

    def forward(self, x):
        """Compute class scores for a batch of volumes.

        Args:
            x: Input batch fed to ``self.features``; its first dimension is
                treated as the batch dimension.

        Returns:
            The output of ``self.classifier`` applied to the flattened
            per-sample features.
        """
        x = self.features(x)
        # Collapse everything except the batch dimension. Unlike ``view``,
        # ``flatten`` also accepts non-contiguous feature maps (it copies
        # only when required); the first Linear layer still enforces the
        # expected 256 * 6 * 6 * 6 feature count.
        x = x.flatten(1)
        x = self.classifier(x)
        return x