saneshashank committed on
Commit
c734473
·
verified ·
1 Parent(s): 2b36aa6

Create model.py

Browse files
Files changed (1) hide show
  1. model.py +231 -0
model.py ADDED
@@ -0,0 +1,231 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import antialiased_cnns
5
+
6
+
7
def drop_path(x, drop_prob=0.0, training=False):
    """Drop paths (Stochastic Depth) per sample.

    When active, each sample in the batch is independently zeroed with
    probability ``drop_prob``; surviving samples are rescaled by
    ``1 / (1 - drop_prob)`` so the expected activation is unchanged.

    Args:
        x: input tensor; dimension 0 is the batch dimension.
        drop_prob: probability of dropping a sample's residual path.
        training: dropping is applied only when True; identity otherwise.

    Returns:
        Tensor with the same shape as ``x``.
    """
    if not training or drop_prob == 0.0:
        return x
    keep_prob = 1.0 - drop_prob
    # One Bernoulli draw per sample, broadcast over all remaining dims:
    # keep_prob + U[0,1) floored gives 1 with probability keep_prob.
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = torch.rand(mask_shape, dtype=x.dtype, device=x.device)
    mask = mask.add_(keep_prob).floor_()
    return x.div(keep_prob) * mask
17
+
18
+
19
class BasicBlock(nn.Module):
    """ResNet basic block (two 3x3 convs) with stochastic depth.

    When ``use_blurpool`` is set and the block downsamples (stride 2),
    the strided first conv is replaced by a stride-1 conv followed by an
    anti-aliased BlurPool (Zhang, "Making Convolutional Networks
    Shift-Invariant Again").
    """

    expansion = 1  # output channels = out_channels * expansion

    def __init__(self, in_channels, out_channels, stride=1, downsample=None,
                 drop_prob=0.0, use_blurpool=False):
        super().__init__()
        self.use_blurpool = use_blurpool
        self.stride = stride

        blur_downsampling = use_blurpool and stride == 2
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                               stride=1 if blur_downsampling else stride,
                               padding=1, bias=False)
        if blur_downsampling:
            # Anti-aliased downsampling replaces the strided conv.
            self.blurpool = antialiased_cnns.BlurPool(out_channels, stride=2)

        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample  # projection for the identity branch
        self.drop_prob = drop_prob    # stochastic-depth rate for this block

    def forward(self, x):
        shortcut = self.downsample(x) if self.downsample is not None else x

        out = F.relu(self.bn1(self.conv1(x)), inplace=True)
        if self.use_blurpool and self.stride == 2:
            out = self.blurpool(out)
        out = self.bn2(self.conv2(out))

        # Randomly skip the residual branch (per sample) during training.
        out = drop_path(out, self.drop_prob, self.training)
        out += shortcut
        return F.relu(out, inplace=True)
66
+
67
+
68
class BottleneckBlock(nn.Module):
    """ResNet bottleneck block (1x1 reduce, 3x3, 1x1 expand) with
    stochastic depth and optional anti-aliased downsampling.

    When ``use_blurpool`` is set and the block downsamples (stride 2),
    the strided 3x3 conv is replaced by a stride-1 conv followed by an
    anti-aliased BlurPool.
    """

    expansion = 4  # final 1x1 conv widens channels by this factor

    def __init__(self, in_channels, out_channels, stride=1, downsample=None,
                 drop_prob=0.0, use_blurpool=False):
        super().__init__()
        self.use_blurpool = use_blurpool
        self.stride = stride

        # 1x1 channel reduction.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)

        # 3x3 spatial conv; strided conv swapped for conv + BlurPool when
        # anti-aliased downsampling is requested.
        blur_downsampling = use_blurpool and stride == 2
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               stride=1 if blur_downsampling else stride,
                               padding=1, bias=False)
        if blur_downsampling:
            self.blurpool = antialiased_cnns.BlurPool(out_channels, stride=2)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # 1x1 channel expansion.
        self.conv3 = nn.Conv2d(out_channels, out_channels * self.expansion,
                               kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channels * self.expansion)

        self.downsample = downsample  # projection for the identity branch
        self.drop_prob = drop_prob    # stochastic-depth rate for this block

    def forward(self, x):
        shortcut = self.downsample(x) if self.downsample is not None else x

        out = F.relu(self.bn1(self.conv1(x)), inplace=True)
        out = F.relu(self.bn2(self.conv2(out)), inplace=True)
        if self.use_blurpool and self.stride == 2:
            out = self.blurpool(out)
        out = self.bn3(self.conv3(out))

        # Randomly skip the residual branch (per sample) during training.
        out = drop_path(out, self.drop_prob, self.training)
        out += shortcut
        return F.relu(out, inplace=True)
126
+
127
+
128
class ResNet(nn.Module):
    """ResNet with a linear stochastic-depth schedule and optional
    anti-aliased (BlurPool) downsampling.

    Args:
        block: residual block class (BasicBlock or BottleneckBlock); must
            expose an ``expansion`` class attribute.
        layers: number of residual blocks in each of the four stages.
        num_classes: number of output classes.
        drop_path_rate: maximum stochastic-depth drop probability; per-block
            rates grow linearly from 0 to this value with depth.
        use_blurpool: when True, every stride-2 downsampling op is replaced
            by its stride-1 counterpart followed by a BlurPool, following
            Zhang, "Making Convolutional Networks Shift-Invariant Again".
    """

    def __init__(self, block, layers, num_classes=1000, drop_path_rate=0.2, use_blurpool=False):
        super().__init__()
        self.in_channels = 64
        self.use_blurpool = use_blurpool

        # Stem: 7x7 conv that halves the spatial size (via BlurPool when
        # anti-aliasing is enabled, via a strided conv otherwise).
        if self.use_blurpool:
            self.conv1 = nn.Sequential(
                nn.Conv2d(3, 64, kernel_size=7, stride=1, padding=3, bias=False),
                nn.BatchNorm2d(64),
                nn.ReLU(inplace=True),
                antialiased_cnns.BlurPool(64, stride=2)
            )
        else:
            self.conv1 = nn.Sequential(
                nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False),
                nn.BatchNorm2d(64),
                nn.ReLU(inplace=True)
            )

        # Stem pooling: plain strided MaxPool, or MaxBlurPool (dense max
        # followed by an anti-aliased stride-2 blur) when use_blurpool.
        if self.use_blurpool:
            # BUGFIX: padding=1 was missing here, so this path shrank the
            # feature map by 2 pixels relative to the padded
            # MaxPool2d(3, stride=2, padding=1) used in the else-branch.
            self.maxpool_or_blurpool = nn.Sequential(
                nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
                antialiased_cnns.BlurPool(64, stride=2)
            )
        else:
            self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Linear drop-path schedule: one rate per residual block, growing
        # from 0 (shallowest) to drop_path_rate (deepest).
        total_blocks = sum(layers)
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, total_blocks)]

        # Track the current block index to slice each stage's rates.
        block_idx = 0

        self.layer1 = self._make_layer(block, 64, layers[0], stride=1,
                                       drop_probs=dpr[block_idx:block_idx + layers[0]],
                                       use_blurpool=use_blurpool)
        block_idx += layers[0]

        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       drop_probs=dpr[block_idx:block_idx + layers[1]],
                                       use_blurpool=use_blurpool)
        block_idx += layers[1]

        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       drop_probs=dpr[block_idx:block_idx + layers[2]],
                                       use_blurpool=use_blurpool)
        block_idx += layers[2]

        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                       drop_probs=dpr[block_idx:block_idx + layers[3]],
                                       use_blurpool=use_blurpool)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # Classification head as a 1x1 conv over the 1x1 pooled map;
        # equivalent to a Linear layer once the output is flattened.
        self.fc = nn.Conv2d(512 * block.expansion, num_classes, kernel_size=1)

    def _make_layer(self, block, out_channels, blocks, stride, drop_probs, use_blurpool):
        """Build one ResNet stage of ``blocks`` residual blocks.

        The first block applies ``stride`` (with a projection shortcut when
        the resolution or channel count changes); the rest use stride 1.
        """
        downsample = None
        if stride != 1 or self.in_channels != out_channels * block.expansion:
            if use_blurpool and stride == 2:
                # Anti-aliased projection: 1x1 conv (stride 1) + BlurPool.
                downsample = nn.Sequential(
                    nn.Conv2d(self.in_channels, out_channels * block.expansion,
                              kernel_size=1, stride=1, bias=False),
                    nn.BatchNorm2d(out_channels * block.expansion),
                    antialiased_cnns.BlurPool(out_channels * block.expansion, stride=2)
                )
            else:
                downsample = nn.Sequential(
                    nn.Conv2d(self.in_channels, out_channels * block.expansion,
                              kernel_size=1, stride=stride, bias=False),
                    nn.BatchNorm2d(out_channels * block.expansion)
                )

        layers = []
        # First block in the stage handles downsampling / channel change.
        layers.append(block(self.in_channels, out_channels, stride, downsample,
                            drop_probs[0], use_blurpool=use_blurpool))
        self.in_channels = out_channels * block.expansion

        # Remaining blocks keep resolution and channel count.
        for i in range(1, blocks):
            layers.append(block(self.in_channels, out_channels, stride=1,
                                drop_prob=drop_probs[i], use_blurpool=use_blurpool))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes)."""
        x = self.conv1(x)
        # The stem pooling mirrors the original ResNet maxpool; it is a
        # MaxBlurPool when anti-aliasing is enabled.
        if self.use_blurpool:
            x = self.maxpool_or_blurpool(x)
        else:
            x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = self.fc(x)
        return torch.flatten(x, 1)