| from collections import OrderedDict |
| import torch.nn as nn |
|
|
|
|
class SmallCNN(nn.Module):
    """Small convolutional classifier: four 3x3 convolutions and three linear layers.

    The classifier head takes 64 * 4 * 4 features per sample, which is what
    the feature extractor produces for single-channel inputs of e.g. 28x28
    pixels (unpadded 3x3 convolutions and two 2x2 max-pools).
    """

    def __init__(self, drop=0.5, num_classes=10):
        """Build the layers and initialise their weights.

        Args:
            drop: Dropout probability used after the first linear layer.
            num_classes: Number of logits produced by the final linear layer.
        """
        super(SmallCNN, self).__init__()

        self.num_channels = 1
        self.num_labels = num_classes

        # One in-place ReLU module is registered under several names; it has
        # no parameters, so sharing the instance is safe.
        activ = nn.ReLU(True)

        self.feature_extractor = nn.Sequential(OrderedDict([
            ('conv1', nn.Conv2d(self.num_channels, 32, 3)),
            ('relu1', activ),
            ('conv2', nn.Conv2d(32, 32, 3)),
            ('relu2', activ),
            ('maxpool1', nn.MaxPool2d(2, 2)),
            ('conv3', nn.Conv2d(32, 64, 3)),
            ('relu3', activ),
            ('conv4', nn.Conv2d(64, 64, 3)),
            ('relu4', activ),
            ('maxpool2', nn.MaxPool2d(2, 2)),
        ]))

        self.classifier = nn.Sequential(OrderedDict([
            ('fc1', nn.Linear(64 * 4 * 4, 200)),
            ('relu1', activ),
            ('drop', nn.Dropout(drop)),
            ('fc2', nn.Linear(200, 200)),
            ('relu2', activ),
            ('fc3', nn.Linear(200, self.num_labels)),
        ]))

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                # No BatchNorm2d layer is created above; this branch only
                # takes effect if one is added to the model later.
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # The output layer starts at exactly zero, so a freshly built model
        # emits all-zero logits.
        nn.init.constant_(self.classifier.fc3.weight, 0)
        nn.init.constant_(self.classifier.fc3.bias, 0)

    def forward(self, input):
        """Return the logits for a batch of images.

        Args:
            input: Batch tensor fed to the first convolution, which expects
                ``self.num_channels`` input channels.

        Returns:
            Tensor with one row of ``self.num_labels`` logits per sample.
        """
        features = self.feature_extractor(input)
        # Flatten per sample so the batch dimension is never altered; an input
        # whose feature map is not 64 * 4 * 4 then fails at fc1 with a shape
        # error instead of being silently regrouped into a different batch.
        logits = self.classifier(features.view(features.size(0), -1))
        return logits
|
|
def small_cnn(num_classes=10):
    """Construct a SmallCNN, forwarding ``num_classes`` as a keyword argument.

    Args:
        num_classes: Value passed to the ``SmallCNN`` constructor under the
            keyword ``num_classes``.

    Returns:
        The newly constructed ``SmallCNN`` instance.
    """
    # NOTE(review): relies on SmallCNN.__init__ accepting a ``num_classes``
    # keyword argument.
    model = SmallCNN(num_classes=num_classes)
    return model