Spaces:
Sleeping
Sleeping
Update networks/decoder.py
Browse files- networks/decoder.py +49 -1
networks/decoder.py
CHANGED
|
@@ -5,7 +5,55 @@ import torch.nn as nn
|
|
| 5 |
from torch import Tensor
|
| 6 |
import torch.nn.functional as F
|
| 7 |
|
| 8 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
|
| 10 |
class Decoder(nn.Module):
|
| 11 |
def __init__(self,
|
|
|
|
| 5 |
from torch import Tensor
|
| 6 |
import torch.nn.functional as F
|
| 7 |
|
| 8 |
+
class PPM(nn.ModuleList):
|
| 9 |
+
"""Pooling Pyramid Module used in PSPNet.
|
| 10 |
+
|
| 11 |
+
Args:
|
| 12 |
+
pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
|
| 13 |
+
Module.
|
| 14 |
+
in_channels (int): Input channels.
|
| 15 |
+
channels (int): Channels after modules, before conv_seg.
|
| 16 |
+
conv_cfg (dict|None): Config of conv layers.
|
| 17 |
+
norm_cfg (dict|None): Config of norm layers.
|
| 18 |
+
act_cfg (dict): Config of activation layers.
|
| 19 |
+
align_corners (bool): align_corners argument of F.interpolate.
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
def __init__(self, pool_scales, in_channels, channels, conv_cfg, norm_cfg,
|
| 23 |
+
act_cfg, align_corners, **kwargs):
|
| 24 |
+
super(PPM, self).__init__()
|
| 25 |
+
self.pool_scales = pool_scales
|
| 26 |
+
self.align_corners = align_corners
|
| 27 |
+
self.in_channels = in_channels
|
| 28 |
+
self.channels = channels
|
| 29 |
+
self.conv_cfg = conv_cfg
|
| 30 |
+
self.norm_cfg = norm_cfg
|
| 31 |
+
self.act_cfg = act_cfg
|
| 32 |
+
for pool_scale in pool_scales:
|
| 33 |
+
self.append(
|
| 34 |
+
nn.Sequential(
|
| 35 |
+
nn.AdaptiveAvgPool2d(pool_scale),
|
| 36 |
+
ConvModule(
|
| 37 |
+
self.in_channels,
|
| 38 |
+
self.channels,
|
| 39 |
+
1,
|
| 40 |
+
conv_cfg=self.conv_cfg,
|
| 41 |
+
norm_cfg=self.norm_cfg,
|
| 42 |
+
act_cfg=self.act_cfg,
|
| 43 |
+
**kwargs)))
|
| 44 |
+
|
| 45 |
+
def forward(self, x):
|
| 46 |
+
"""Forward function."""
|
| 47 |
+
ppm_outs = []
|
| 48 |
+
for ppm in self:
|
| 49 |
+
ppm_out = ppm(x)
|
| 50 |
+
upsampled_ppm_out = resize(
|
| 51 |
+
ppm_out,
|
| 52 |
+
size=x.size()[2:],
|
| 53 |
+
mode='bilinear',
|
| 54 |
+
align_corners=self.align_corners)
|
| 55 |
+
ppm_outs.append(upsampled_ppm_out)
|
| 56 |
+
return ppm_outs
|
| 57 |
|
| 58 |
class Decoder(nn.Module):
|
| 59 |
def __init__(self,
|