Anupam202224 commited on
Commit
bd24b82
·
verified ·
1 Parent(s): 5138b32

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +39 -139
model.py CHANGED
@@ -1,21 +1,14 @@
1
  import torch
2
  import torch.nn as nn
3
  import torch.nn.functional as F
4
- import torchvision
5
 
6
- from resnet import Resnet18
7
- # from modules.bn import InPlaceABNSync as BatchNorm2d
8
 
9
 
10
  class ConvBNReLU(nn.Module):
11
- def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, *args, **kwargs):
12
  super(ConvBNReLU, self).__init__()
13
- self.conv = nn.Conv2d(in_chan,
14
- out_chan,
15
- kernel_size = ks,
16
- stride = stride,
17
- padding = padding,
18
- bias = False)
19
  self.bn = nn.BatchNorm2d(out_chan)
20
  self.init_weight()
21
 
@@ -25,13 +18,13 @@ class ConvBNReLU(nn.Module):
25
  return x
26
 
27
  def init_weight(self):
28
- for ly in self.children():
29
- if isinstance(ly, nn.Conv2d):
30
- nn.init.kaiming_normal_(ly.weight, a=1)
31
- if not ly.bias is None: nn.init.constant_(ly.bias, 0)
32
 
33
  class BiSeNetOutput(nn.Module):
34
- def __init__(self, in_chan, mid_chan, n_classes, *args, **kwargs):
35
  super(BiSeNetOutput, self).__init__()
36
  self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1)
37
  self.conv_out = nn.Conv2d(mid_chan, n_classes, kernel_size=1, bias=False)
@@ -43,28 +36,23 @@ class BiSeNetOutput(nn.Module):
43
  return x
44
 
45
  def init_weight(self):
46
- for ly in self.children():
47
- if isinstance(ly, nn.Conv2d):
48
- nn.init.kaiming_normal_(ly.weight, a=1)
49
- if not ly.bias is None: nn.init.constant_(ly.bias, 0)
50
 
51
  def get_params(self):
52
- wd_params, nowd_params = [], []
53
- for name, module in self.named_modules():
54
- if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
55
- wd_params.append(module.weight)
56
- if not module.bias is None:
57
- nowd_params.append(module.bias)
58
- elif isinstance(module, nn.BatchNorm2d):
59
- nowd_params += list(module.parameters())
60
  return wd_params, nowd_params
61
 
62
 
63
  class AttentionRefinementModule(nn.Module):
64
- def __init__(self, in_chan, out_chan, *args, **kwargs):
65
  super(AttentionRefinementModule, self).__init__()
66
  self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1)
67
- self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size= 1, bias=False)
68
  self.bn_atten = nn.BatchNorm2d(out_chan)
69
  self.sigmoid_atten = nn.Sigmoid()
70
  self.init_weight()
@@ -79,14 +67,13 @@ class AttentionRefinementModule(nn.Module):
79
  return out
80
 
81
  def init_weight(self):
82
- for ly in self.children():
83
- if isinstance(ly, nn.Conv2d):
84
- nn.init.kaiming_normal_(ly.weight, a=1)
85
- if not ly.bias is None: nn.init.constant_(ly.bias, 0)
86
 
87
 
88
  class ContextPath(nn.Module):
89
- def __init__(self, *args, **kwargs):
90
  super(ContextPath, self).__init__()
91
  self.resnet = Resnet18()
92
  self.arm16 = AttentionRefinementModule(256, 128)
@@ -95,8 +82,6 @@ class ContextPath(nn.Module):
95
  self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
96
  self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0)
97
 
98
- self.init_weight()
99
-
100
  def forward(self, x):
101
  H0, W0 = x.size()[2:]
102
  feat8, feat16, feat32 = self.resnet(x)
@@ -118,77 +103,22 @@ class ContextPath(nn.Module):
118
  feat16_up = F.interpolate(feat16_sum, (H8, W8), mode='nearest')
119
  feat16_up = self.conv_head16(feat16_up)
120
 
121
- return feat8, feat16_up, feat32_up # x8, x8, x16
122
-
123
- def init_weight(self):
124
- for ly in self.children():
125
- if isinstance(ly, nn.Conv2d):
126
- nn.init.kaiming_normal_(ly.weight, a=1)
127
- if not ly.bias is None: nn.init.constant_(ly.bias, 0)
128
-
129
- def get_params(self):
130
- wd_params, nowd_params = [], []
131
- for name, module in self.named_modules():
132
- if isinstance(module, (nn.Linear, nn.Conv2d)):
133
- wd_params.append(module.weight)
134
- if not module.bias is None:
135
- nowd_params.append(module.bias)
136
- elif isinstance(module, nn.BatchNorm2d):
137
- nowd_params += list(module.parameters())
138
- return wd_params, nowd_params
139
-
140
-
141
- ### This is not used, since I replace this with the resnet feature with the same size
142
- class SpatialPath(nn.Module):
143
- def __init__(self, *args, **kwargs):
144
- super(SpatialPath, self).__init__()
145
- self.conv1 = ConvBNReLU(3, 64, ks=7, stride=2, padding=3)
146
- self.conv2 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
147
- self.conv3 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
148
- self.conv_out = ConvBNReLU(64, 128, ks=1, stride=1, padding=0)
149
- self.init_weight()
150
-
151
- def forward(self, x):
152
- feat = self.conv1(x)
153
- feat = self.conv2(feat)
154
- feat = self.conv3(feat)
155
- feat = self.conv_out(feat)
156
- return feat
157
 
158
  def init_weight(self):
159
  for ly in self.children():
160
  if isinstance(ly, nn.Conv2d):
161
  nn.init.kaiming_normal_(ly.weight, a=1)
162
- if not ly.bias is None: nn.init.constant_(ly.bias, 0)
163
-
164
- def get_params(self):
165
- wd_params, nowd_params = [], []
166
- for name, module in self.named_modules():
167
- if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
168
- wd_params.append(module.weight)
169
- if not module.bias is None:
170
- nowd_params.append(module.bias)
171
- elif isinstance(module, nn.BatchNorm2d):
172
- nowd_params += list(module.parameters())
173
- return wd_params, nowd_params
174
 
175
 
176
  class FeatureFusionModule(nn.Module):
177
- def __init__(self, in_chan, out_chan, *args, **kwargs):
178
  super(FeatureFusionModule, self).__init__()
179
  self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0)
180
- self.conv1 = nn.Conv2d(out_chan,
181
- out_chan//4,
182
- kernel_size = 1,
183
- stride = 1,
184
- padding = 0,
185
- bias = False)
186
- self.conv2 = nn.Conv2d(out_chan//4,
187
- out_chan,
188
- kernel_size = 1,
189
- stride = 1,
190
- padding = 0,
191
- bias = False)
192
  self.relu = nn.ReLU(inplace=True)
193
  self.sigmoid = nn.Sigmoid()
194
  self.init_weight()
@@ -206,38 +136,27 @@ class FeatureFusionModule(nn.Module):
206
  return feat_out
207
 
208
  def init_weight(self):
209
- for ly in self.children():
210
- if isinstance(ly, nn.Conv2d):
211
- nn.init.kaiming_normal_(ly.weight, a=1)
212
- if not ly.bias is None: nn.init.constant_(ly.bias, 0)
213
-
214
- def get_params(self):
215
- wd_params, nowd_params = [], []
216
- for name, module in self.named_modules():
217
- if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
218
- wd_params.append(module.weight)
219
- if not module.bias is None:
220
- nowd_params.append(module.bias)
221
- elif isinstance(module, nn.BatchNorm2d):
222
- nowd_params += list(module.parameters())
223
- return wd_params, nowd_params
224
 
225
 
226
  class BiSeNet(nn.Module):
227
- def __init__(self, n_classes, *args, **kwargs):
228
  super(BiSeNet, self).__init__()
229
  self.cp = ContextPath()
230
- ## here self.sp is deleted
231
  self.ffm = FeatureFusionModule(256, 256)
232
  self.conv_out = BiSeNetOutput(256, 256, n_classes)
233
  self.conv_out16 = BiSeNetOutput(128, 64, n_classes)
234
  self.conv_out32 = BiSeNetOutput(128, 64, n_classes)
235
- self.init_weight()
236
 
237
  def forward(self, x):
238
  H, W = x.size()[2:]
239
- feat_res8, feat_cp8, feat_cp16 = self.cp(x) # here return res3b1 feature
240
- feat_sp = feat_res8 # use res3b1 feature to replace spatial path feature
241
  feat_fuse = self.ffm(feat_sp, feat_cp8)
242
 
243
  feat_out = self.conv_out(feat_fuse)
@@ -253,28 +172,9 @@ class BiSeNet(nn.Module):
253
  for ly in self.children():
254
  if isinstance(ly, nn.Conv2d):
255
  nn.init.kaiming_normal_(ly.weight, a=1)
256
- if not ly.bias is None: nn.init.constant_(ly.bias, 0)
 
257
 
258
  def get_params(self):
259
  wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = [], [], [], []
260
- for name, child in self.named_children():
261
- child_wd_params, child_nowd_params = child.get_params()
262
- if isinstance(child, FeatureFusionModule) or isinstance(child, BiSeNetOutput):
263
- lr_mul_wd_params += child_wd_params
264
- lr_mul_nowd_params += child_nowd_params
265
- else:
266
- wd_params += child_wd_params
267
- nowd_params += child_nowd_params
268
- return wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params
269
-
270
-
271
- if __name__ == "__main__":
272
- net = BiSeNet(19)
273
- #net.cuda()
274
- net.eval()
275
- in_ten = torch.randn(16, 3, 640, 480)
276
- out, out16, out32 = net(in_ten)
277
- print(out.shape)
278
-
279
- net.get_params()
280
-
 
1
  import torch
2
  import torch.nn as nn
3
  import torch.nn.functional as F
 
4
 
5
+ from resnet import Resnet18 # Ensure that the Resnet18 class is correctly defined in this module
 
6
 
7
 
8
  class ConvBNReLU(nn.Module):
9
+ def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1):
10
  super(ConvBNReLU, self).__init__()
11
+ self.conv = nn.Conv2d(in_chan, out_chan, kernel_size=ks, stride=stride, padding=padding, bias=False)
 
 
 
 
 
12
  self.bn = nn.BatchNorm2d(out_chan)
13
  self.init_weight()
14
 
 
18
  return x
19
 
20
  def init_weight(self):
21
+ nn.init.kaiming_normal_(self.conv.weight, a=1)
22
+ if self.conv.bias is not None:
23
+ nn.init.constant_(self.conv.bias, 0)
24
+
25
 
26
  class BiSeNetOutput(nn.Module):
27
+ def __init__(self, in_chan, mid_chan, n_classes):
28
  super(BiSeNetOutput, self).__init__()
29
  self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1)
30
  self.conv_out = nn.Conv2d(mid_chan, n_classes, kernel_size=1, bias=False)
 
36
  return x
37
 
38
  def init_weight(self):
39
+ nn.init.kaiming_normal_(self.conv_out.weight, a=1)
40
+ if self.conv_out.bias is not None:
41
+ nn.init.constant_(self.conv_out.bias, 0)
 
42
 
43
  def get_params(self):
44
+ wd_params = [self.conv_out.weight]
45
+ nowd_params = []
46
+ if self.conv_out.bias is not None:
47
+ nowd_params.append(self.conv_out.bias)
 
 
 
 
48
  return wd_params, nowd_params
49
 
50
 
51
  class AttentionRefinementModule(nn.Module):
52
+ def __init__(self, in_chan, out_chan):
53
  super(AttentionRefinementModule, self).__init__()
54
  self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1)
55
+ self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size=1, bias=False)
56
  self.bn_atten = nn.BatchNorm2d(out_chan)
57
  self.sigmoid_atten = nn.Sigmoid()
58
  self.init_weight()
 
67
  return out
68
 
69
  def init_weight(self):
70
+ nn.init.kaiming_normal_(self.conv_atten.weight, a=1)
71
+ if self.conv_atten.bias is not None:
72
+ nn.init.constant_(self.conv_atten.bias, 0)
 
73
 
74
 
75
  class ContextPath(nn.Module):
76
+ def __init__(self):
77
  super(ContextPath, self).__init__()
78
  self.resnet = Resnet18()
79
  self.arm16 = AttentionRefinementModule(256, 128)
 
82
  self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
83
  self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0)
84
 
 
 
85
  def forward(self, x):
86
  H0, W0 = x.size()[2:]
87
  feat8, feat16, feat32 = self.resnet(x)
 
103
  feat16_up = F.interpolate(feat16_sum, (H8, W8), mode='nearest')
104
  feat16_up = self.conv_head16(feat16_up)
105
 
106
+ return feat8, feat16_up, feat32_up
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
 
108
  def init_weight(self):
109
  for ly in self.children():
110
  if isinstance(ly, nn.Conv2d):
111
  nn.init.kaiming_normal_(ly.weight, a=1)
112
+ if ly.bias is not None:
113
+ nn.init.constant_(ly.bias, 0)
 
 
 
 
 
 
 
 
 
 
114
 
115
 
116
  class FeatureFusionModule(nn.Module):
117
+ def __init__(self, in_chan, out_chan):
118
  super(FeatureFusionModule, self).__init__()
119
  self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0)
120
+ self.conv1 = nn.Conv2d(out_chan, out_chan // 4, kernel_size=1, stride=1, padding=0, bias=False)
121
+ self.conv2 = nn.Conv2d(out_chan // 4, out_chan, kernel_size=1, stride=1, padding=0, bias=False)
 
 
 
 
 
 
 
 
 
 
122
  self.relu = nn.ReLU(inplace=True)
123
  self.sigmoid = nn.Sigmoid()
124
  self.init_weight()
 
136
  return feat_out
137
 
138
  def init_weight(self):
139
+ nn.init.kaiming_normal_(self.conv1.weight, a=1)
140
+ if self.conv1.bias is not None:
141
+ nn.init.constant_(self.conv1.bias, 0)
142
+ nn.init.kaiming_normal_(self.conv2.weight, a=1)
143
+ if self.conv2.bias is not None:
144
+ nn.init.constant_(self.conv2.bias, 0)
 
 
 
 
 
 
 
 
 
145
 
146
 
147
  class BiSeNet(nn.Module):
148
+ def __init__(self, n_classes):
149
  super(BiSeNet, self).__init__()
150
  self.cp = ContextPath()
 
151
  self.ffm = FeatureFusionModule(256, 256)
152
  self.conv_out = BiSeNetOutput(256, 256, n_classes)
153
  self.conv_out16 = BiSeNetOutput(128, 64, n_classes)
154
  self.conv_out32 = BiSeNetOutput(128, 64, n_classes)
 
155
 
156
  def forward(self, x):
157
  H, W = x.size()[2:]
158
+ feat_res8, feat_cp8, feat_cp16 = self.cp(x)
159
+ feat_sp = feat_res8 # Using res3b1 feature as spatial path feature
160
  feat_fuse = self.ffm(feat_sp, feat_cp8)
161
 
162
  feat_out = self.conv_out(feat_fuse)
 
172
  for ly in self.children():
173
  if isinstance(ly, nn.Conv2d):
174
  nn.init.kaiming_normal_(ly.weight, a=1)
175
+ if ly.bias is not None:
176
+ nn.init.constant_(ly.bias, 0)
177
 
178
  def get_params(self):
179
  wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = [], [], [], []
180
+ for name, child in