PawanratRung commited on
Commit
3be7bf8
·
verified ·
1 Parent(s): 8bfc38f

Create AugmentCE2P.py

Browse files
Files changed (1) hide show
  1. 3rdparty/SCHP/networks/AugmentCE2P.py +481 -0
3rdparty/SCHP/networks/AugmentCE2P.py ADDED
@@ -0,0 +1,481 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- encoding: utf-8 -*-
3
+
4
+ """
5
+ @Author : Peike Li
6
+ @Contact : peike.li@yahoo.com
7
+ @File : AugmentCE2P.py
8
+ @Time : 8/4/19 3:35 PM
9
+ @Desc :
10
+ @License : This source code is licensed under the license found in the
11
+ LICENSE file in the root directory of this source tree.
12
+ """
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+
17
+ from torch.nn import BatchNorm2d, functional as F, LeakyReLU
18
+
19
# Whether the downsample-path BatchNorm layers are built with learnable affine
# parameters (passed as `affine=affine_par` in ResNet._make_layer).
affine_par = True
# Preprocessing metadata attached to the model by initialize_pretrained_model.
# NOTE(review): input_space is "BGR" and mean/std appear to be listed in BGR
# order (0.406 first) — presumably matching a Caffe-converted checkpoint;
# confirm against the actual weight file used.
pretrained_settings = {
    "resnet101": {
        "imagenet": {
            "input_space": "BGR",
            "input_size": [3, 224, 224],
            "input_range": [0, 1],
            "mean": [0.406, 0.456, 0.485],
            "std": [0.225, 0.224, 0.229],
            "num_classes": 1000,
        }
    },
}
32
+
33
+
34
def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 convolution with padding 1 ("same" at stride 1)."""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
39
+
40
+
41
class Bottleneck(nn.Module):
    """ResNet bottleneck block: 1x1 reduce -> 3x3 (dilated) -> 1x1 expand + skip.

    The effective dilation/padding of the 3x3 conv is ``dilation * multi_grid``.
    ``fist_dilation`` (sic) is accepted but unused; it is kept only so existing
    callers that pass it keep working.
    """

    expansion = 4

    def __init__(
        self,
        inplanes,
        planes,
        stride=1,
        dilation=1,
        downsample=None,
        fist_dilation=1,
        multi_grid=1,
    ):
        super(Bottleneck, self).__init__()
        eff_dilation = dilation * multi_grid
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes,
            planes,
            kernel_size=3,
            stride=stride,
            padding=eff_dilation,
            dilation=eff_dilation,
            bias=False,
        )
        self.bn2 = BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(
            planes, planes * self.expansion, kernel_size=1, bias=False
        )
        self.bn3 = BatchNorm2d(planes * self.expansion)
        # Out-of-place ReLU inside the trunk; in-place only after the addition.
        self.relu = nn.ReLU(inplace=False)
        self.relu_inplace = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.dilation = dilation
        self.stride = stride

    def forward(self, x):
        # Identity (or projected) shortcut.
        shortcut = x if self.downsample is None else self.downsample(x)

        y = self.relu(self.bn1(self.conv1(x)))
        y = self.relu(self.bn2(self.conv2(y)))
        y = self.bn3(self.conv3(y))

        return self.relu_inplace(y + shortcut)
96
+
97
+
98
class PSPModule(nn.Module):
    """Pyramid Scene Parsing context module.

    Pools the input to several grid sizes, projects each pooled map to
    ``out_features`` channels, upsamples back, concatenates with the input,
    and fuses with a 3x3 conv.

    Reference:
        Zhao, Hengshuang, et al. *"Pyramid scene parsing network."*
    """

    def __init__(self, features, out_features=512, sizes=(1, 2, 3, 6)):
        super(PSPModule, self).__init__()

        self.stages = nn.ModuleList(
            self._make_stage(features, out_features, s) for s in sizes
        )
        self.bottleneck = nn.Sequential(
            nn.Conv2d(
                features + len(sizes) * out_features,
                out_features,
                kernel_size=3,
                padding=1,
                dilation=1,
                bias=False,
            ),
            BatchNorm2d(out_features),
            LeakyReLU(),
        )

    def _make_stage(self, features, out_features, size):
        # Pool to a (size x size) grid, project channels, normalize, activate.
        return nn.Sequential(
            nn.AdaptiveAvgPool2d(output_size=(size, size)),
            nn.Conv2d(features, out_features, kernel_size=1, bias=False),
            BatchNorm2d(out_features),
            LeakyReLU(),
        )

    def forward(self, feats):
        height, width = feats.size(2), feats.size(3)
        pyramid = []
        for stage in self.stages:
            pyramid.append(
                F.interpolate(
                    input=stage(feats),
                    size=(height, width),
                    mode="bilinear",
                    align_corners=True,
                )
            )
        pyramid.append(feats)
        return self.bottleneck(torch.cat(pyramid, 1))
145
+
146
+
147
class ASPPModule(nn.Module):
    """Atrous Spatial Pyramid Pooling context module.

    Five parallel branches (global pool + 1x1 conv, a 1x1 conv, and three
    dilated 3x3 convs) are concatenated and fused by a 1x1 bottleneck conv.

    Reference:
        Chen, Liang-Chieh, et al. *"Rethinking Atrous Convolution for Semantic Image Segmentation."*
    """

    def __init__(
        self, features, inner_features=256, out_features=512, dilations=(12, 24, 36)
    ):
        super(ASPPModule, self).__init__()

        # Image-level branch: global average pool then 1x1 projection.
        self.conv1 = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(
                features,
                inner_features,
                kernel_size=1,
                padding=0,
                dilation=1,
                bias=False,
            ),
            # InPlaceABNSync(inner_features)
            BatchNorm2d(inner_features),
            LeakyReLU(),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(
                features,
                inner_features,
                kernel_size=1,
                padding=0,
                dilation=1,
                bias=False,
            ),
            BatchNorm2d(inner_features),
            LeakyReLU(),
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(
                features,
                inner_features,
                kernel_size=3,
                padding=dilations[0],
                dilation=dilations[0],
                bias=False,
            ),
            BatchNorm2d(inner_features),
            LeakyReLU(),
        )
        self.conv4 = nn.Sequential(
            nn.Conv2d(
                features,
                inner_features,
                kernel_size=3,
                padding=dilations[1],
                dilation=dilations[1],
                bias=False,
            ),
            BatchNorm2d(inner_features),
            LeakyReLU(),
        )
        self.conv5 = nn.Sequential(
            nn.Conv2d(
                features,
                inner_features,
                kernel_size=3,
                padding=dilations[2],
                dilation=dilations[2],
                bias=False,
            ),
            BatchNorm2d(inner_features),
            LeakyReLU(),
        )

        self.bottleneck = nn.Sequential(
            nn.Conv2d(
                inner_features * 5,
                out_features,
                kernel_size=1,
                padding=0,
                dilation=1,
                bias=False,
            ),
            # BUG FIX: this was BatchNorm2d(inner_features). The conv above
            # outputs `out_features` channels, so normalizing over
            # `inner_features` raises a channel-mismatch error whenever
            # inner_features != out_features (true with the defaults 256/512).
            BatchNorm2d(out_features),
            LeakyReLU(),
            nn.Dropout2d(0.1),
        )

    def forward(self, x):
        _, _, h, w = x.size()

        # Broadcast the image-level feature back to the input resolution.
        feat1 = F.interpolate(
            self.conv1(x), size=(h, w), mode="bilinear", align_corners=True
        )

        feat2 = self.conv2(x)
        feat3 = self.conv3(x)
        feat4 = self.conv4(x)
        feat5 = self.conv5(x)
        out = torch.cat((feat1, feat2, feat3, feat4, feat5), 1)

        bottle = self.bottleneck(out)
        return bottle
250
+
251
+
252
class Edge_Module(nn.Module):
    """Edge Learning Branch.

    Projects three backbone feature maps to a common ``mid_fea`` channel
    width, upsamples the second and third to the resolution of the first,
    and returns their channel-wise concatenation.
    """

    def __init__(self, in_fea=[256, 512, 1024], mid_fea=256, out_fea=2):
        super(Edge_Module, self).__init__()

        self.conv1 = nn.Sequential(
            nn.Conv2d(
                in_fea[0], mid_fea, kernel_size=1, padding=0, dilation=1, bias=False
            ),
            BatchNorm2d(mid_fea),
            LeakyReLU(),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(
                in_fea[1], mid_fea, kernel_size=1, padding=0, dilation=1, bias=False
            ),
            BatchNorm2d(mid_fea),
            LeakyReLU(),
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(
                in_fea[2], mid_fea, kernel_size=1, padding=0, dilation=1, bias=False
            ),
            BatchNorm2d(mid_fea),
            LeakyReLU(),
        )
        # Edge-map head. Its outputs are not used by forward() anymore, but
        # the module is kept so existing checkpoints still load cleanly.
        self.conv4 = nn.Conv2d(
            mid_fea, out_fea, kernel_size=3, padding=1, dilation=1, bias=True
        )
        # self.conv5 = nn.Conv2d(out_fea * 3, out_fea, kernel_size=1, padding=0, dilation=1, bias=True)

    def forward(self, x1, x2, x3):
        """Return a (N, 3 * mid_fea, H1, W1) edge-feature tensor, where
        (H1, W1) is the spatial size of ``x1``.
        """
        _, _, h, w = x1.size()

        edge1_fea = self.conv1(x1)
        edge2_fea = self.conv2(x2)
        edge3_fea = self.conv3(x3)

        # Dead code removed: the original also ran self.conv4 on edge2_fea /
        # edge3_fea and interpolated those edge maps, but discarded them —
        # only the concatenated features were ever returned.
        edge2_fea = F.interpolate(
            edge2_fea, size=(h, w), mode="bilinear", align_corners=True
        )
        edge3_fea = F.interpolate(
            edge3_fea, size=(h, w), mode="bilinear", align_corners=True
        )

        edge_fea = torch.cat([edge1_fea, edge2_fea, edge3_fea], dim=1)
        return edge_fea
311
+
312
+
313
class Decoder_Module(nn.Module):
    """Parsing Branch Decoder Module.

    Upsamples a 512-channel context feature to the spatial size of a
    256-channel low-level feature, concatenates the two (256 + 48 = 304
    channels), and refines the result to 256 channels.

    ``num_classes`` is currently unused because the segmentation head
    (conv4) is disabled; the parameter is kept for caller compatibility.
    """

    def __init__(self, num_classes):
        super(Decoder_Module, self).__init__()
        # Context projection: 512 -> 256.
        self.conv1 = nn.Sequential(
            nn.Conv2d(512, 256, kernel_size=1, padding=0, dilation=1, bias=False),
            BatchNorm2d(256),
            LeakyReLU(),
        )
        # Low-level projection: 256 -> 48.
        self.conv2 = nn.Sequential(
            nn.Conv2d(
                256, 48, kernel_size=1, stride=1, padding=0, dilation=1, bias=False
            ),
            BatchNorm2d(48),
            LeakyReLU(),
        )
        # Fusion of the concatenated 304-channel tensor down to 256.
        self.conv3 = nn.Sequential(
            nn.Conv2d(304, 256, kernel_size=1, padding=0, dilation=1, bias=False),
            BatchNorm2d(256),
            LeakyReLU(),
            nn.Conv2d(256, 256, kernel_size=1, padding=0, dilation=1, bias=False),
            BatchNorm2d(256),
            LeakyReLU(),
        )

        # self.conv4 = nn.Conv2d(256, num_classes, kernel_size=1, padding=0, dilation=1, bias=True)

    def forward(self, xt, xl):
        _, _, h, w = xl.size()
        context = self.conv1(xt)
        context = F.interpolate(
            context, size=(h, w), mode="bilinear", align_corners=True
        )
        low_level = self.conv2(xl)
        fused = self.conv3(torch.cat([context, low_level], dim=1))
        return fused
354
+
355
+
356
class ResNet(nn.Module):
    """CE2P-style human-parsing network on a dilated ResNet backbone.

    forward() returns only the fused parsing logits; the separate parsing
    and edge heads present in the original CE2P are commented out upstream.
    """

    def __init__(self, block, layers, num_classes):
        self.inplanes = 128
        super(ResNet, self).__init__()
        # Deep stem: three 3x3 convs (first one strided) instead of one 7x7.
        self.conv1 = conv3x3(3, 64, stride=2)
        self.bn1 = BatchNorm2d(64)
        self.relu1 = nn.ReLU(inplace=False)
        self.conv2 = conv3x3(64, 64)
        self.bn2 = BatchNorm2d(64)
        self.relu2 = nn.ReLU(inplace=False)
        self.conv3 = conv3x3(64, 128)
        self.bn3 = BatchNorm2d(128)
        self.relu3 = nn.ReLU(inplace=False)

        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Backbone stages; layer4 keeps stride 1 and uses dilation 2 so the
        # output stride stays at 16.
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(
            block, 512, layers[3], stride=1, dilation=2, multi_grid=(1, 1, 1)
        )

        self.context_encoding = PSPModule(2048, 512)

        self.edge = Edge_Module()
        self.decoder = Decoder_Module(num_classes)

        # Fuses decoder features (256) with edge features (768) -> logits.
        self.fushion = nn.Sequential(
            nn.Conv2d(1024, 256, kernel_size=1, padding=0, dilation=1, bias=False),
            BatchNorm2d(256),
            LeakyReLU(),
            nn.Dropout2d(0.1),
            nn.Conv2d(
                256, num_classes, kernel_size=1, padding=0, dilation=1, bias=True
            ),
        )

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1, multi_grid=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            # 1x1 projection so the shortcut matches the block's output.
            downsample = nn.Sequential(
                nn.Conv2d(
                    self.inplanes,
                    planes * block.expansion,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                BatchNorm2d(planes * block.expansion, affine=affine_par),
            )

        def grid_for(index):
            # multi_grid may be a tuple of per-block multipliers or a scalar.
            if isinstance(multi_grid, tuple):
                return multi_grid[index % len(multi_grid)]
            return 1

        layers = [
            block(
                self.inplanes,
                planes,
                stride,
                dilation=dilation,
                downsample=downsample,
                multi_grid=grid_for(0),
            )
        ]
        self.inplanes = planes * block.expansion
        layers.extend(
            block(
                self.inplanes,
                planes,
                dilation=dilation,
                multi_grid=grid_for(i),
            )
            for i in range(1, blocks)
        )

        return nn.Sequential(*layers)

    def forward(self, x):
        # Stem.
        x = self.relu1(self.bn1(self.conv1(x)))
        x = self.relu2(self.bn2(self.conv2(x)))
        x = self.relu3(self.bn3(self.conv3(x)))
        x = self.maxpool(x)
        # Backbone.
        c2 = self.layer1(x)
        c3 = self.layer2(c2)
        c4 = self.layer3(c3)
        c5 = self.layer4(c4)
        # Context encoding + parsing decoder.
        context = self.context_encoding(c5)
        parsing_fea = self.decoder(context, c2)
        # Edge branch.
        edge_fea = self.edge(c2, c3, c4)
        # Fusion branch produces the final logits.
        fusion_result = self.fushion(torch.cat([parsing_fea, edge_fea], dim=1))
        return fusion_result
455
+
456
+
457
def initialize_pretrained_model(
    model, settings, pretrained="./models/resnet101-imagenet.pth"
):
    """Attach preprocessing metadata and optionally load backbone weights.

    Copies ``input_space``/``input_size``/``input_range``/``mean``/``std``
    from ``settings`` onto ``model`` as attributes. If ``pretrained`` is not
    None, loads the checkpoint at that path into ``model``, skipping any
    ``fc.*`` entries (the ImageNet classifier head).
    """
    model.input_space = settings["input_space"]
    model.input_size = settings["input_size"]
    model.input_range = settings["input_range"]
    model.mean = settings["mean"]
    model.std = settings["std"]

    if pretrained is not None:
        # map_location="cpu" lets GPU-saved checkpoints load on CPU-only
        # hosts; load_state_dict then places weights on the model's device.
        saved_state_dict = torch.load(pretrained, map_location="cpu")
        new_params = model.state_dict().copy()
        for key, value in saved_state_dict.items():
            # Skip the classifier head; the original joined the split key
            # back together (".".join(key.split("."))), which is just `key`.
            if key.split(".")[0] != "fc":
                new_params[key] = value
        model.load_state_dict(new_params)
474
+
475
+
476
def resnet101(num_classes=20, pretrained="./models/resnet101-imagenet.pth"):
    """Build the ResNet-101 parsing model, optionally loading ImageNet weights.

    Pass ``pretrained=None`` to skip weight loading entirely.
    """
    imagenet_settings = pretrained_settings["resnet101"]["imagenet"]
    net = ResNet(Bottleneck, [3, 4, 23, 3], num_classes)
    initialize_pretrained_model(net, imagenet_settings, pretrained)
    return net
481
+