|
|
|
|
|
|
| """
|
| @Author : Peike Li
|
| @Contact : peike.li@yahoo.com
|
| @File : psp.py
|
| @Time : 8/4/19 3:36 PM
|
| @Desc : Pyramid pooling module (PSPModule) from "Pyramid scene parsing network".
|
| @License : This source code is licensed under the license found in the
|
| LICENSE file in the root directory of this source tree.
|
| """
|
|
|
| import torch
|
| import torch.nn as nn
|
| from torch.nn import functional as F
|
|
|
| from modules import InPlaceABNSync
|
|
|
|
|
class PSPModule(nn.Module):
    """Pyramid pooling module.

    Reference:
        Zhao, Hengshuang, et al. *"Pyramid scene parsing network."*

    The input feature map is adaptively average-pooled to one square grid per
    entry of ``sizes``. Each pooled map goes through a bias-free 1x1
    convolution and an ``InPlaceABNSync`` layer (project-local, imported from
    ``modules``), is bilinearly resized back to the input resolution, and is
    concatenated with the input along the channel dimension. A bias-free 3x3
    convolution followed by ``InPlaceABNSync`` then fuses the result down to
    ``out_features`` channels.

    Args:
        features: Number of channels of the input feature map.
        out_features: Number of channels produced by every pyramid stage and
            by the final bottleneck.
        sizes: Iterable of pooling grid sizes; one stage is built per entry,
            pooling to ``(size, size)``.
    """

    def __init__(self, features, out_features=512, sizes=(1, 2, 3, 6)):
        super(PSPModule, self).__init__()

        # Materialize once so any iterable (including a generator) can be both
        # iterated and counted below.
        sizes = tuple(sizes)

        self.stages = nn.ModuleList(
            [self._make_stage(features, out_features, size) for size in sizes]
        )

        # The bottleneck sees the input channels plus one `out_features`-wide
        # map per pyramid stage.
        fused_channels = features + len(sizes) * out_features
        self.bottleneck = nn.Sequential(
            nn.Conv2d(fused_channels, out_features, kernel_size=3, padding=1, dilation=1,
                      bias=False),
            InPlaceABNSync(out_features),
        )

    def _make_stage(self, features, out_features, size):
        """Build one pyramid stage: pool to ``size x size``, 1x1 conv, ABN.

        Args:
            features: Number of input channels of the stage.
            out_features: Number of output channels of the stage.
            size: Side length of the square adaptive-average-pooling grid.

        Returns:
            An ``nn.Sequential`` of the pooling, convolution and
            ``InPlaceABNSync`` layers, in that order.
        """
        prior = nn.AdaptiveAvgPool2d(output_size=(size, size))
        conv = nn.Conv2d(features, out_features, kernel_size=1, bias=False)
        bn = InPlaceABNSync(out_features)
        return nn.Sequential(prior, conv, bn)

    def forward(self, feats):
        """Apply pyramid pooling to ``feats`` and fuse the result.

        Args:
            feats: 4-D feature tensor; dimensions 2 and 3 are used as the
                spatial height and width, dimension 1 as channels.

        Returns:
            The bottleneck output computed on the channel-wise concatenation
            of all upsampled stage outputs followed by ``feats`` itself.
        """
        h, w = feats.size(2), feats.size(3)

        # Every stage output is resized back to the input's spatial size so
        # the maps can be concatenated along the channel dimension.
        priors = [
            F.interpolate(input=stage(feats), size=(h, w), mode='bilinear', align_corners=True)
            for stage in self.stages
        ]
        # The untouched input goes last, matching the channel layout the
        # bottleneck convolution was sized for.
        priors.append(feats)

        bottle = self.bottleneck(torch.cat(priors, 1))
        return bottle