Spaces:
Sleeping
Sleeping
| """ | |
| Backbones supported by torchvision. | |
| """ | |
| from collections import OrderedDict | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import torchvision | |
class TVDeeplabRes101Encoder(nn.Module):
    """
    FCN-ResNet101 encoder built from torchvision's ``deeplabv3_resnet101``.

    The torchvision model's backbone is reused as the feature extractor.
    By default its ``'out'`` feature map is reduced from 2048 to 256
    channels with a 1x1 convolution. No ASPP is used by default, as it was
    found empirically to hurt performance; it can be enabled with
    ``use_aspp=True``.
    """

    def __init__(self, use_coco_init: bool, aux_dim_keep: int = 64, use_aspp: bool = False):
        """
        Args:
            use_coco_init: passed to torchvision as ``pretrained``; selects
                between the pretrained (MS-COCO) initialisation and the
                non-pretrained one.
            aux_dim_keep: number of leading channels of the backbone's
                ``'aux'`` feature map returned as low-level features.
            use_aspp: if True, high-level features are produced by the first
                two layers of the torchvision classifier head instead of the
                1x1 ``localconv`` reduction.
        """
        super().__init__()

        # aux_loss=True makes the backbone also return the 'aux' feature map
        # that forward(low_level=True) reads. torchvision only enables it
        # implicitly when pretrained weights are requested, so leaving it as
        # None made the non-pretrained encoder fail with KeyError on
        # fts['aux']. The auxiliary classifier head itself is not kept.
        _model = torchvision.models.segmentation.deeplabv3_resnet101(
            pretrained=use_coco_init,
            progress=True,
            num_classes=21,
            aux_loss=True,
        )

        if use_coco_init:
            print("###### NETWORK: Using ms-coco initialization ######")
        else:
            print("###### NETWORK: Training from scratch ######")

        self.aux_dim_keep = aux_dim_keep

        # Sub-modules are registered in the same order as before so that
        # state_dict keys and parameter ordering stay unchanged.
        self.backbone = _model.backbone
        # Reduce the 2048-channel backbone output to 256 channels.
        self.localconv = nn.Conv2d(2048, 256, kernel_size=1, stride=1, bias=False)
        # Not used by forward(); kept so existing checkpoints still load.
        self.asppconv = nn.Conv2d(256, 256, kernel_size=1, bias=False)

        # First two layers of the torchvision classifier head: the ASPP
        # module and the convolution that follows it.
        _classifier = _model.classifier
        self.aspp_out = nn.Sequential(_classifier[0], _classifier[1])
        self.use_aspp = use_aspp

    def forward(self, x_in: torch.Tensor, low_level: bool):
        """
        Args:
            x_in: input batch fed to the backbone.
            low_level: whether to also return the aggregated low-level
                features (the first ``aux_dim_keep`` channels of the
                backbone's ``'aux'`` output).

        Returns:
            ``high_level_fts`` when ``low_level`` is false, otherwise the
            tuple ``(high_level_fts, low_level_fts)``.
        """
        fts = self.backbone(x_in)

        if self.use_aspp:
            high_level_fts = self.aspp_out(fts['out'])
        else:
            high_level_fts = self.localconv(fts['out'])

        if not low_level:
            return high_level_fts

        # Keep only the leading aux_dim_keep channels of the aux feature map.
        low_level_fts = fts['aux'][:, :self.aux_dim_keep]
        return high_level_fts, low_level_fts