# Source page residue (kept as a comment so the module parses):
# author: poppingout1325 — "End of training" — commit 5f58fbe (verified)
from collections import OrderedDict
import torch.nn as nn
class SmallCNN(nn.Module):
    """Small CNN classifier: four 3x3 conv layers followed by three linear layers.

    The convolutional stack has no padding, so a 28x28 input shrinks as
    28 -> 26 -> 24 -> (pool) 12 -> 10 -> 8 -> (pool) 4, producing the
    64 x 4 x 4 feature map that ``fc1`` expects. Other spatial sizes will be
    rejected by ``fc1`` with a shape error.

    Args:
        drop: Dropout probability applied after ``fc1``.
        num_classes: Number of output logits (size of ``fc3``). Defaults to 10.
        num_channels: Number of input channels for ``conv1``. Defaults to 1.
    """

    def __init__(self, drop=0.5, num_classes=10, num_channels=1):
        super(SmallCNN, self).__init__()
        self.num_channels = num_channels
        self.num_labels = num_classes
        # ReLU holds no parameters or state, so one in-place instance is
        # safely reused at every activation position.
        activ = nn.ReLU(inplace=True)

        # Module names are kept stable so existing state_dict checkpoints
        # (e.g. "feature_extractor.conv1.weight") still load.
        self.feature_extractor = nn.Sequential(OrderedDict([
            ('conv1', nn.Conv2d(self.num_channels, 32, 3)),
            ('relu1', activ),
            ('conv2', nn.Conv2d(32, 32, 3)),
            ('relu2', activ),
            ('maxpool1', nn.MaxPool2d(2, 2)),
            ('conv3', nn.Conv2d(32, 64, 3)),
            ('relu3', activ),
            ('conv4', nn.Conv2d(64, 64, 3)),
            ('relu4', activ),
            ('maxpool2', nn.MaxPool2d(2, 2)),
        ]))
        self.classifier = nn.Sequential(OrderedDict([
            ('fc1', nn.Linear(64 * 4 * 4, 200)),
            ('relu1', activ),
            ('drop', nn.Dropout(drop)),
            ('fc2', nn.Linear(200, 200)),
            ('relu2', activ),
            ('fc3', nn.Linear(200, self.num_labels)),
        ]))

        # Kaiming-normal init for conv weights, zero conv biases. Linear
        # layers other than fc3 keep PyTorch's default initialization.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                # No BatchNorm2d layers are defined above; this branch only
                # takes effect if one is added later.
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
        # Zero-initialize the output layer so every logit starts at 0.
        nn.init.constant_(self.classifier.fc3.weight, 0)
        nn.init.constant_(self.classifier.fc3.bias, 0)

    def forward(self, input):
        """Return class logits of shape (batch, num_classes) for ``input``.

        ``input`` is expected to be (batch, num_channels, 28, 28); see the
        class docstring for why that spatial size is required.
        """
        features = self.feature_extractor(input)
        # flatten(1) preserves the batch dimension. The previous
        # view(-1, 64 * 4 * 4) could silently regroup a wrongly-sized input
        # into a different batch size instead of raising.
        logits = self.classifier(features.flatten(1))
        return logits
def small_cnn(num_classes=10):
    """Construct and return a ``SmallCNN`` model.

    Args:
        num_classes: Passed through to ``SmallCNN`` as the ``num_classes``
            keyword argument.

    NOTE(review): this requires ``SmallCNN.__init__`` to accept a
    ``num_classes`` keyword; otherwise the call raises ``TypeError``.
    """
    model = SmallCNN(num_classes=num_classes)
    return model