Spaces:
Running
Running
| import math | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from model.deep_lab_model.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d | |
class Decoder(nn.Module):
    """Decoder head that fuses a coarse feature map with low-level features.

    The low-level features are projected to 48 channels with a 1x1
    convolution, the coarse map is bilinearly resized to the same spatial
    size, both are concatenated along the channel axis and passed through a
    small convolutional stack that ends in a sigmoid.
    """

    def __init__(self, num_classes, backbone, BatchNorm):
        """Build the projection layers and the fusion stack.

        Args:
            num_classes: Number of output channels of the final 1x1
                convolution.
            backbone: One of ``'resnet'``, ``'drn'``, ``'xception'`` or
                ``'mobilenet'``; selects how many channels the low-level
                feature map is expected to have (256, 256, 128, 24).
            BatchNorm: Callable invoked with a channel count to create each
                normalisation layer.

        Raises:
            NotImplementedError: If ``backbone`` is none of the names above.
        """
        super().__init__()

        if backbone in ('resnet', 'drn'):
            skip_channels = 256
        elif backbone == 'xception':
            skip_channels = 128
        elif backbone == 'mobilenet':
            skip_channels = 24
        else:
            raise NotImplementedError

        reduced_channels = 48
        hidden_channels = 256
        # The fusion stack takes 304 channels: the 48 projected low-level
        # channels plus the channels of ``x``, which therefore has to carry 256.
        fused_channels = 304

        # 1x1 projection of the low-level features.
        self.conv1 = nn.Conv2d(skip_channels, reduced_channels, 1, bias=False)
        self.bn1 = BatchNorm(reduced_channels)
        self.relu = nn.ReLU()

        # Layers are created in the same order they are applied.
        fusion_layers = [
            nn.Conv2d(fused_channels, hidden_channels,
                      kernel_size=3, stride=1, padding=1, bias=False),
            BatchNorm(hidden_channels),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Conv2d(hidden_channels, hidden_channels,
                      kernel_size=3, stride=1, padding=1, bias=False),
            BatchNorm(hidden_channels),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Conv2d(hidden_channels, num_classes, kernel_size=1, stride=1),
            nn.Sigmoid(),
        ]
        self.last_conv = nn.Sequential(*fusion_layers)

        self._init_weight()

    def forward(self, x, low_level_feat):
        """Fuse ``x`` with ``low_level_feat`` and return the sigmoid output.

        Args:
            x: Coarse feature map; resized to the spatial size of
                ``low_level_feat`` (assumed to be laid out as N, C, H, W --
                confirm against the caller).
            low_level_feat: Low-level feature map whose channel count matches
                the backbone chosen at construction time.

        Returns:
            The output of ``last_conv`` at the spatial size of
            ``low_level_feat``.
        """
        skip = self.relu(self.bn1(self.conv1(low_level_feat)))

        # Bring the coarse map up to the resolution of the projected skip.
        target_size = skip.shape[2:]
        resized = F.interpolate(x, size=target_size, mode='bilinear',
                                align_corners=True)

        fused = torch.cat([resized, skip], dim=1)
        return self.last_conv(fused)

    def _init_weight(self):
        """Re-initialise convolution weights and reset batch-norm affine terms.

        Convolution weights get Kaiming-normal initialisation (biases are left
        untouched); batch-norm layers get weight 1 and bias 0.
        """
        norm_types = (SynchronizedBatchNorm2d, nn.BatchNorm2d)
        for layer in self.modules():
            if isinstance(layer, nn.Conv2d):
                nn.init.kaiming_normal_(layer.weight)
            elif isinstance(layer, norm_types):
                layer.weight.data.fill_(1)
                layer.bias.data.zero_()
def build_decoder(num_classes, backbone, BatchNorm):
    """Create a ``Decoder`` from the given settings.

    Args:
        num_classes: Number of output channels of the decoder.
        backbone: Backbone name forwarded to ``Decoder``.
        BatchNorm: Normalisation-layer factory forwarded to ``Decoder``.

    Returns:
        A newly constructed ``Decoder`` instance.
    """
    decoder = Decoder(
        num_classes=num_classes,
        backbone=backbone,
        BatchNorm=BatchNorm,
    )
    return decoder