Mr7Explorer commited on
Commit
6636ac4
·
verified ·
1 Parent(s): 44c3d89

Create decoder_blocks.py

Browse files
Files changed (1) hide show
  1. decoder_blocks.py +65 -0
decoder_blocks.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from models.modules.aspp import ASPP, ASPPDeformable
4
+ from config import Config
5
+
6
+
7
+ config = Config()
8
+
9
+
10
+ class BasicDecBlk(nn.Module):
11
+ def __init__(self, in_channels=64, out_channels=64, inter_channels=64):
12
+ super(BasicDecBlk, self).__init__()
13
+ inter_channels = in_channels // 4 if config.dec_channels_inter == 'adap' else 64
14
+ self.conv_in = nn.Conv2d(in_channels, inter_channels, 3, 1, padding=1)
15
+ self.relu_in = nn.ReLU(inplace=True)
16
+ if config.dec_att == 'ASPP':
17
+ self.dec_att = ASPP(in_channels=inter_channels)
18
+ elif config.dec_att == 'ASPPDeformable':
19
+ self.dec_att = ASPPDeformable(in_channels=inter_channels)
20
+ self.conv_out = nn.Conv2d(inter_channels, out_channels, 3, 1, padding=1)
21
+ self.bn_in = nn.BatchNorm2d(inter_channels) if config.batch_size > 1 else nn.Identity()
22
+ self.bn_out = nn.BatchNorm2d(out_channels) if config.batch_size > 1 else nn.Identity()
23
+
24
+ def forward(self, x):
25
+ x = self.conv_in(x)
26
+ x = self.bn_in(x)
27
+ x = self.relu_in(x)
28
+ if hasattr(self, 'dec_att'):
29
+ x = self.dec_att(x)
30
+ x = self.conv_out(x)
31
+ x = self.bn_out(x)
32
+ return x
33
+
34
+
35
+ class ResBlk(nn.Module):
36
+ def __init__(self, in_channels=64, out_channels=None, inter_channels=64):
37
+ super(ResBlk, self).__init__()
38
+ if out_channels is None:
39
+ out_channels = in_channels
40
+ inter_channels = in_channels // 4 if config.dec_channels_inter == 'adap' else 64
41
+
42
+ self.conv_in = nn.Conv2d(in_channels, inter_channels, 3, 1, padding=1)
43
+ self.bn_in = nn.BatchNorm2d(inter_channels) if config.batch_size > 1 else nn.Identity()
44
+ self.relu_in = nn.ReLU(inplace=True)
45
+
46
+ if config.dec_att == 'ASPP':
47
+ self.dec_att = ASPP(in_channels=inter_channels)
48
+ elif config.dec_att == 'ASPPDeformable':
49
+ self.dec_att = ASPPDeformable(in_channels=inter_channels)
50
+
51
+ self.conv_out = nn.Conv2d(inter_channels, out_channels, 3, 1, padding=1)
52
+ self.bn_out = nn.BatchNorm2d(out_channels) if config.batch_size > 1 else nn.Identity()
53
+
54
+ self.conv_resi = nn.Conv2d(in_channels, out_channels, 1, 1, 0)
55
+
56
+ def forward(self, x):
57
+ _x = self.conv_resi(x)
58
+ x = self.conv_in(x)
59
+ x = self.bn_in(x)
60
+ x = self.relu_in(x)
61
+ if hasattr(self, 'dec_att'):
62
+ x = self.dec_att(x)
63
+ x = self.conv_out(x)
64
+ x = self.bn_out(x)
65
+ return x + _x