shwethd commited on
Commit
d43be06
·
verified ·
1 Parent(s): 7d2465b

Upload model.py

Browse files
Files changed (1) hide show
  1. model.py +119 -0
model.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
+ class BasicBlock(nn.Module):
7
+ expansion = 1
8
+
9
+ def __init__(self, in_planes, planes, stride=1):
10
+ super(BasicBlock, self).__init__()
11
+ self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
12
+ self.bn1 = nn.BatchNorm2d(planes)
13
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
14
+ self.bn2 = nn.BatchNorm2d(planes)
15
+
16
+ self.shortcut = nn.Sequential()
17
+ if stride != 1 or in_planes != self.expansion * planes:
18
+ self.shortcut = nn.Sequential(
19
+ nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
20
+ nn.BatchNorm2d(self.expansion * planes)
21
+ )
22
+
23
+ def forward(self, x):
24
+ out = F.relu(self.bn1(self.conv1(x)))
25
+ out = self.bn2(self.conv2(out))
26
+ out += self.shortcut(x)
27
+ out = F.relu(out)
28
+ return out
29
+
30
+
31
+ class Bottleneck(nn.Module):
32
+ expansion = 4
33
+
34
+ def __init__(self, in_planes, planes, stride=1):
35
+ super(Bottleneck, self).__init__()
36
+ self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
37
+ self.bn1 = nn.BatchNorm2d(planes)
38
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
39
+ self.bn2 = nn.BatchNorm2d(planes)
40
+ self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False)
41
+ self.bn3 = nn.BatchNorm2d(self.expansion * planes)
42
+
43
+ self.shortcut = nn.Sequential()
44
+ if stride != 1 or in_planes != self.expansion * planes:
45
+ self.shortcut = nn.Sequential(
46
+ nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
47
+ nn.BatchNorm2d(self.expansion * planes)
48
+ )
49
+
50
+ def forward(self, x):
51
+ out = F.relu(self.bn1(self.conv1(x)))
52
+ out = F.relu(self.bn2(self.conv2(out)))
53
+ out = self.bn3(self.conv3(out))
54
+ out += self.shortcut(x)
55
+ out = F.relu(out)
56
+ return out
57
+
58
+
59
+ class ResNet(nn.Module):
60
+ def __init__(self, block, num_blocks, num_classes=100):
61
+ super(ResNet, self).__init__()
62
+ self.in_planes = 64
63
+
64
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
65
+ self.bn1 = nn.BatchNorm2d(64)
66
+ self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
67
+ self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
68
+ self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
69
+ self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
70
+ self.dropout = nn.Dropout(0.5)
71
+ self.linear = nn.Linear(512 * block.expansion, num_classes)
72
+
73
+ def _make_layer(self, block, planes, num_blocks, stride):
74
+ strides = [stride] + [1] * (num_blocks - 1)
75
+ layers = []
76
+ for stride in strides:
77
+ layers.append(block(self.in_planes, planes, stride))
78
+ self.in_planes = planes * block.expansion
79
+ return nn.Sequential(*layers)
80
+
81
+ def forward(self, x):
82
+ out = F.relu(self.bn1(self.conv1(x)))
83
+ out = self.layer1(out)
84
+ out = self.layer2(out)
85
+ out = self.layer3(out)
86
+ out = self.layer4(out)
87
+ out = F.avg_pool2d(out, 4)
88
+ out = out.view(out.size(0), -1)
89
+ out = self.dropout(out)
90
+ out = self.linear(out)
91
+ return out
92
+
93
+
94
+ def ResNet18():
95
+ return ResNet(BasicBlock, [2, 2, 2, 2])
96
+
97
+
98
+ def ResNet34():
99
+ return ResNet(BasicBlock, [3, 4, 6, 3])
100
+
101
+
102
+ def ResNet50():
103
+ return ResNet(Bottleneck, [3, 4, 6, 3])
104
+
105
+
106
+ def ResNet101():
107
+ return ResNet(Bottleneck, [3, 4, 23, 3])
108
+
109
+
110
+ def ResNet152():
111
+ return ResNet(Bottleneck, [3, 8, 36, 3])
112
+
113
+
114
+ def test():
115
+ net = ResNet18()
116
+ y = net(torch.randn(1, 3, 32, 32))
117
+ print(y.size())
118
+
119
+ # test()