Mr7Explorer commited on
Commit
50fe11b
·
verified ·
1 Parent(s): 4353d5b

Create aspp.py

Browse files
Files changed (1) hide show
  1. aspp.py +119 -0
aspp.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from models.modules.deform_conv import DeformableConv2d
5
+ from config import Config
6
+
7
+
8
+ config = Config()
9
+
10
+
11
+ class _ASPPModule(nn.Module):
12
+ def __init__(self, in_channels, planes, kernel_size, padding, dilation):
13
+ super(_ASPPModule, self).__init__()
14
+ self.atrous_conv = nn.Conv2d(in_channels, planes, kernel_size=kernel_size,
15
+ stride=1, padding=padding, dilation=dilation, bias=False)
16
+ self.bn = nn.BatchNorm2d(planes) if config.batch_size > 1 else nn.Identity()
17
+ self.relu = nn.ReLU(inplace=True)
18
+
19
+ def forward(self, x):
20
+ x = self.atrous_conv(x)
21
+ x = self.bn(x)
22
+
23
+ return self.relu(x)
24
+
25
+
26
+ class ASPP(nn.Module):
27
+ def __init__(self, in_channels=64, out_channels=None, output_stride=16):
28
+ super(ASPP, self).__init__()
29
+ self.down_scale = 1
30
+ if out_channels is None:
31
+ out_channels = in_channels
32
+ self.in_channelster = 256 // self.down_scale
33
+ if output_stride == 16:
34
+ dilations = [1, 6, 12, 18]
35
+ elif output_stride == 8:
36
+ dilations = [1, 12, 24, 36]
37
+ else:
38
+ raise NotImplementedError
39
+
40
+ self.aspp1 = _ASPPModule(in_channels, self.in_channelster, 1, padding=0, dilation=dilations[0])
41
+ self.aspp2 = _ASPPModule(in_channels, self.in_channelster, 3, padding=dilations[1], dilation=dilations[1])
42
+ self.aspp3 = _ASPPModule(in_channels, self.in_channelster, 3, padding=dilations[2], dilation=dilations[2])
43
+ self.aspp4 = _ASPPModule(in_channels, self.in_channelster, 3, padding=dilations[3], dilation=dilations[3])
44
+
45
+ self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)),
46
+ nn.Conv2d(in_channels, self.in_channelster, 1, stride=1, bias=False),
47
+ nn.BatchNorm2d(self.in_channelster) if config.batch_size > 1 else nn.Identity(),
48
+ nn.ReLU(inplace=True))
49
+ self.conv1 = nn.Conv2d(self.in_channelster * 5, out_channels, 1, bias=False)
50
+ self.bn1 = nn.BatchNorm2d(out_channels) if config.batch_size > 1 else nn.Identity()
51
+ self.relu = nn.ReLU(inplace=True)
52
+ self.dropout = nn.Dropout(0.5)
53
+
54
+ def forward(self, x):
55
+ x1 = self.aspp1(x)
56
+ x2 = self.aspp2(x)
57
+ x3 = self.aspp3(x)
58
+ x4 = self.aspp4(x)
59
+ x5 = self.global_avg_pool(x)
60
+ x5 = F.interpolate(x5, size=x1.size()[2:], mode='bilinear', align_corners=True)
61
+ x = torch.cat((x1, x2, x3, x4, x5), dim=1)
62
+
63
+ x = self.conv1(x)
64
+ x = self.bn1(x)
65
+ x = self.relu(x)
66
+
67
+ return self.dropout(x)
68
+
69
+
70
+ ##################### Deformable
71
+ class _ASPPModuleDeformable(nn.Module):
72
+ def __init__(self, in_channels, planes, kernel_size, padding):
73
+ super(_ASPPModuleDeformable, self).__init__()
74
+ self.atrous_conv = DeformableConv2d(in_channels, planes, kernel_size=kernel_size,
75
+ stride=1, padding=padding, bias=False)
76
+ self.bn = nn.BatchNorm2d(planes) if config.batch_size > 1 else nn.Identity()
77
+ self.relu = nn.ReLU(inplace=True)
78
+
79
+ def forward(self, x):
80
+ x = self.atrous_conv(x)
81
+ x = self.bn(x)
82
+
83
+ return self.relu(x)
84
+
85
+
86
+ class ASPPDeformable(nn.Module):
87
+ def __init__(self, in_channels, out_channels=None, parallel_block_sizes=[1, 3, 7]):
88
+ super(ASPPDeformable, self).__init__()
89
+ self.down_scale = 1
90
+ if out_channels is None:
91
+ out_channels = in_channels
92
+ self.in_channelster = 256 // self.down_scale
93
+
94
+ self.aspp1 = _ASPPModuleDeformable(in_channels, self.in_channelster, 1, padding=0)
95
+ self.aspp_deforms = nn.ModuleList([
96
+ _ASPPModuleDeformable(in_channels, self.in_channelster, conv_size, padding=int(conv_size//2)) for conv_size in parallel_block_sizes
97
+ ])
98
+
99
+ self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)),
100
+ nn.Conv2d(in_channels, self.in_channelster, 1, stride=1, bias=False),
101
+ nn.BatchNorm2d(self.in_channelster) if config.batch_size > 1 else nn.Identity(),
102
+ nn.ReLU(inplace=True))
103
+ self.conv1 = nn.Conv2d(self.in_channelster * (2 + len(self.aspp_deforms)), out_channels, 1, bias=False)
104
+ self.bn1 = nn.BatchNorm2d(out_channels) if config.batch_size > 1 else nn.Identity()
105
+ self.relu = nn.ReLU(inplace=True)
106
+ self.dropout = nn.Dropout(0.5)
107
+
108
+ def forward(self, x):
109
+ x1 = self.aspp1(x)
110
+ x_aspp_deforms = [aspp_deform(x) for aspp_deform in self.aspp_deforms]
111
+ x5 = self.global_avg_pool(x)
112
+ x5 = F.interpolate(x5, size=x1.size()[2:], mode='bilinear', align_corners=True)
113
+ x = torch.cat((x1, *x_aspp_deforms, x5), dim=1)
114
+
115
+ x = self.conv1(x)
116
+ x = self.bn1(x)
117
+ x = self.relu(x)
118
+
119
+ return self.dropout(x)