PubAccount commited on
Commit
d68b77d
·
verified ·
1 Parent(s): 472dac7

Update networks/decoder.py

Browse files
Files changed (1) hide show
  1. networks/decoder.py +49 -1
networks/decoder.py CHANGED
@@ -5,7 +5,55 @@ import torch.nn as nn
5
  from torch import Tensor
6
  import torch.nn.functional as F
7
 
8
- from mmseg.models.decode_heads.psp_head import PPM
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
  class Decoder(nn.Module):
11
  def __init__(self,
 
5
  from torch import Tensor
6
  import torch.nn.functional as F
7
 
8
+ class PPM(nn.ModuleList):
9
+ """Pooling Pyramid Module used in PSPNet.
10
+
11
+ Args:
12
+ pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
13
+ Module.
14
+ in_channels (int): Input channels.
15
+ channels (int): Channels after modules, before conv_seg.
16
+ conv_cfg (dict|None): Config of conv layers.
17
+ norm_cfg (dict|None): Config of norm layers.
18
+ act_cfg (dict): Config of activation layers.
19
+ align_corners (bool): align_corners argument of F.interpolate.
20
+ """
21
+
22
+ def __init__(self, pool_scales, in_channels, channels, conv_cfg, norm_cfg,
23
+ act_cfg, align_corners, **kwargs):
24
+ super(PPM, self).__init__()
25
+ self.pool_scales = pool_scales
26
+ self.align_corners = align_corners
27
+ self.in_channels = in_channels
28
+ self.channels = channels
29
+ self.conv_cfg = conv_cfg
30
+ self.norm_cfg = norm_cfg
31
+ self.act_cfg = act_cfg
32
+ for pool_scale in pool_scales:
33
+ self.append(
34
+ nn.Sequential(
35
+ nn.AdaptiveAvgPool2d(pool_scale),
36
+ ConvModule(
37
+ self.in_channels,
38
+ self.channels,
39
+ 1,
40
+ conv_cfg=self.conv_cfg,
41
+ norm_cfg=self.norm_cfg,
42
+ act_cfg=self.act_cfg,
43
+ **kwargs)))
44
+
45
+ def forward(self, x):
46
+ """Forward function."""
47
+ ppm_outs = []
48
+ for ppm in self:
49
+ ppm_out = ppm(x)
50
+ upsampled_ppm_out = resize(
51
+ ppm_out,
52
+ size=x.size()[2:],
53
+ mode='bilinear',
54
+ align_corners=self.align_corners)
55
+ ppm_outs.append(upsampled_ppm_out)
56
+ return ppm_outs
57
 
58
  class Decoder(nn.Module):
59
  def __init__(self,