| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | |
| | from torch.nn import SyncBatchNorm as BatchNorm2d |
| | from functools import partial |
| | import re |
| | from models.base_models.resnet import resnet101, resnet18, resnet50 |
| | from utils.seg_opr.conv_2_5d import Conv2_5D_depth, Conv2_5D_disp |
| |
|
class DeepLabV3p_r18(nn.Module):
    """DeepLabV3+ semantic segmentation network with a ResNet-18 encoder.

    The last backbone stage (layer4) is converted from strided to dilated
    convolutions so the encoder keeps a higher-resolution feature map.
    """

    def __init__(self, num_classes, config):
        super(DeepLabV3p_r18, self).__init__()
        self.norm_layer = BatchNorm2d
        self.backbone = resnet18(config.pretrained_model_r18, norm_layer=self.norm_layer,
                                 bn_eps=config.bn_eps,
                                 bn_momentum=config.bn_momentum,
                                 deep_stem=False, stem_width=64)
        # Replace strided convs in layer4 with dilated ones; the dilation
        # rate doubles for each successive child block.
        self.dilate = 2
        for m in self.backbone.layer4.children():
            m.apply(partial(self._nostride_dilate, dilate=self.dilate))
            self.dilate *= 2

        self.head = Head('r18', num_classes, self.norm_layer, config.bn_momentum)
        self.classifier = nn.Conv2d(256, num_classes, kernel_size=1, bias=True)

        # Newly created (non-pretrained) layers: they get their own optimizer
        # parameter groups (see get_params) and fresh Kaiming initialization.
        self.business_layer = [self.head, self.classifier]
        # NOTE: the classifier used to be re-initialized a second time on its
        # own; it is already covered by business_layer, so one pass suffices.
        init_weight(self.business_layer, nn.init.kaiming_normal_,
                    BatchNorm2d, config.bn_eps, config.bn_momentum,
                    mode='fan_in', nonlinearity='relu')

    def forward(self, data, get_sup_loss=False, gt=None, criterion=None):
        """Run the network; returns logits at input resolution.

        In training mode with get_sup_loss=True, also returns the supervised
        loss computed on the labelled prefix of the batch.
        """
        # NOTE(review): callers appear to pass the image batch wrapped in a
        # sequence; only the first element is used -- confirm against caller.
        data = data[0]
        blocks = self.backbone(data)
        v3plus_feature = self.head(blocks)
        pred = self.classifier(v3plus_feature)

        # Upsample logits to the input spatial resolution.
        b, c, h, w = data.shape
        pred = F.interpolate(pred, size=(h, w), mode='bilinear', align_corners=True)
        if self.training and get_sup_loss:
            return pred, self.get_sup_loss(pred, gt, criterion)
        return pred

    def get_sup_loss(self, pred, gt, criterion):
        # Only the first gt.shape[0] samples carry labels (semi-supervised
        # batches may append unlabelled images after the labelled ones).
        pred = pred[:gt.shape[0]]
        return criterion(pred, gt)

    def _nostride_dilate(self, m, dilate):
        """Convert a strided 3x3 conv into a stride-1 dilated conv in place."""
        if not isinstance(m, nn.Conv2d):
            return
        if m.stride == (2, 2):
            m.stride = (1, 1)
            if m.kernel_size == (3, 3):
                m.dilation = (dilate, dilate)
                m.padding = (dilate, dilate)
        elif m.kernel_size == (3, 3):
            m.dilation = (dilate, dilate)
            m.padding = (dilate, dilate)

    def get_params(self):
        """Return [backbone decay, all no-decay, head+classifier decay] groups."""
        param_groups = [[], [], []]
        enc, enc_no_decay = group_weight(self.backbone, self.norm_layer)
        param_groups[0].extend(enc)
        param_groups[1].extend(enc_no_decay)
        for module in (self.head, self.classifier):
            decay, no_decay = group_weight(module, self.norm_layer)
            param_groups[2].extend(decay)
            param_groups[1].extend(no_decay)
        return param_groups
| |
|
class DeepLabV3p_r50(nn.Module):
    """DeepLabV3+ semantic segmentation network with a ResNet-50 encoder.

    The last backbone stage (layer4) is converted from strided to dilated
    convolutions so the encoder keeps a higher-resolution feature map.
    """

    def __init__(self, num_classes, config):
        super(DeepLabV3p_r50, self).__init__()
        self.norm_layer = BatchNorm2d
        self.backbone = resnet50(config.pretrained_model_r50, norm_layer=self.norm_layer,
                                 bn_eps=config.bn_eps,
                                 bn_momentum=config.bn_momentum,
                                 deep_stem=True, stem_width=64)
        # Replace strided convs in layer4 with dilated ones; the dilation
        # rate doubles for each successive child block.
        self.dilate = 2
        for m in self.backbone.layer4.children():
            m.apply(partial(self._nostride_dilate, dilate=self.dilate))
            self.dilate *= 2

        self.head = Head('r50', num_classes, self.norm_layer, config.bn_momentum)
        self.classifier = nn.Conv2d(256, num_classes, kernel_size=1, bias=True)

        # Newly created (non-pretrained) layers: they get their own optimizer
        # parameter groups (see get_params) and fresh Kaiming initialization.
        self.business_layer = [self.head, self.classifier]
        # NOTE: the classifier used to be re-initialized a second time on its
        # own; it is already covered by business_layer, so one pass suffices.
        init_weight(self.business_layer, nn.init.kaiming_normal_,
                    BatchNorm2d, config.bn_eps, config.bn_momentum,
                    mode='fan_in', nonlinearity='relu')

    def forward(self, data, get_sup_loss=False, gt=None, criterion=None):
        """Run the network; returns logits at input resolution.

        In training mode with get_sup_loss=True, also returns the supervised
        loss computed on the labelled prefix of the batch.
        """
        # NOTE(review): callers appear to pass the image batch wrapped in a
        # sequence; only the first element is used -- confirm against caller.
        data = data[0]
        blocks = self.backbone(data)
        v3plus_feature = self.head(blocks)
        pred = self.classifier(v3plus_feature)

        # Upsample logits to the input spatial resolution.
        b, c, h, w = data.shape
        pred = F.interpolate(pred, size=(h, w), mode='bilinear', align_corners=True)
        if self.training and get_sup_loss:
            return pred, self.get_sup_loss(pred, gt, criterion)
        return pred

    def get_sup_loss(self, pred, gt, criterion):
        # Only the first gt.shape[0] samples carry labels (semi-supervised
        # batches may append unlabelled images after the labelled ones).
        pred = pred[:gt.shape[0]]
        return criterion(pred, gt)

    def get_params(self):
        """Return [backbone decay, all no-decay, head+classifier decay] groups."""
        param_groups = [[], [], []]
        enc, enc_no_decay = group_weight(self.backbone, self.norm_layer)
        param_groups[0].extend(enc)
        param_groups[1].extend(enc_no_decay)
        for module in (self.head, self.classifier):
            decay, no_decay = group_weight(module, self.norm_layer)
            param_groups[2].extend(decay)
            param_groups[1].extend(no_decay)
        return param_groups

    def _nostride_dilate(self, m, dilate):
        """Convert a strided 3x3 conv into a stride-1 dilated conv in place."""
        if not isinstance(m, nn.Conv2d):
            return
        if m.stride == (2, 2):
            m.stride = (1, 1)
            if m.kernel_size == (3, 3):
                m.dilation = (dilate, dilate)
                m.padding = (dilate, dilate)
        elif m.kernel_size == (3, 3):
            m.dilation = (dilate, dilate)
            m.padding = (dilate, dilate)
|
class DeepLabV3p_r101(nn.Module):
    """DeepLabV3+ semantic segmentation network with a ResNet-101 encoder.

    The last backbone stage (layer4) is converted from strided to dilated
    convolutions so the encoder keeps a higher-resolution feature map.
    """

    def __init__(self, num_classes, config):
        super(DeepLabV3p_r101, self).__init__()
        self.norm_layer = BatchNorm2d
        self.backbone = resnet101(config.pretrained_model_r101, norm_layer=self.norm_layer,
                                  bn_eps=config.bn_eps,
                                  bn_momentum=config.bn_momentum,
                                  deep_stem=True, stem_width=64)
        # Replace strided convs in layer4 with dilated ones; the dilation
        # rate doubles for each successive child block.
        self.dilate = 2
        for m in self.backbone.layer4.children():
            m.apply(partial(self._nostride_dilate, dilate=self.dilate))
            self.dilate *= 2

        # ResNet-101 shares ResNet-50's channel widths (2048 deep / 256 low),
        # so the 'r50' head configuration is reused here on purpose.
        self.head = Head('r50', num_classes, self.norm_layer, config.bn_momentum)
        self.classifier = nn.Conv2d(256, num_classes, kernel_size=1, bias=True)

        # Newly created (non-pretrained) layers: they get their own optimizer
        # parameter groups (see get_params) and fresh Kaiming initialization.
        self.business_layer = [self.head, self.classifier]
        # NOTE: the classifier used to be re-initialized a second time on its
        # own; it is already covered by business_layer, so one pass suffices.
        init_weight(self.business_layer, nn.init.kaiming_normal_,
                    BatchNorm2d, config.bn_eps, config.bn_momentum,
                    mode='fan_in', nonlinearity='relu')

    def forward(self, data, get_sup_loss=False, gt=None, criterion=None):
        """Run the network; returns logits at input resolution.

        In training mode with get_sup_loss=True, also returns the supervised
        loss computed on the labelled prefix of the batch.
        """
        # NOTE(review): callers appear to pass the image batch wrapped in a
        # sequence; only the first element is used -- confirm against caller.
        data = data[0]
        blocks = self.backbone(data)
        v3plus_feature = self.head(blocks)
        pred = self.classifier(v3plus_feature)

        # Upsample logits to the input spatial resolution.
        b, c, h, w = data.shape
        pred = F.interpolate(pred, size=(h, w), mode='bilinear', align_corners=True)
        if self.training and get_sup_loss:
            return pred, self.get_sup_loss(pred, gt, criterion)
        return pred

    def get_sup_loss(self, pred, gt, criterion):
        # Only the first gt.shape[0] samples carry labels (semi-supervised
        # batches may append unlabelled images after the labelled ones).
        pred = pred[:gt.shape[0]]
        return criterion(pred, gt)

    def get_params(self):
        """Return [backbone decay, all no-decay, head+classifier decay] groups."""
        param_groups = [[], [], []]
        enc, enc_no_decay = group_weight(self.backbone, self.norm_layer)
        param_groups[0].extend(enc)
        param_groups[1].extend(enc_no_decay)
        for module in (self.head, self.classifier):
            decay, no_decay = group_weight(module, self.norm_layer)
            param_groups[2].extend(decay)
            param_groups[1].extend(no_decay)
        return param_groups

    def _nostride_dilate(self, m, dilate):
        """Convert a strided 3x3 conv into a stride-1 dilated conv in place."""
        if not isinstance(m, nn.Conv2d):
            return
        if m.stride == (2, 2):
            m.stride = (1, 1)
            if m.kernel_size == (3, 3):
                m.dilation = (dilate, dilate)
                m.padding = (dilate, dilate)
        elif m.kernel_size == (3, 3):
            m.dilation = (dilate, dilate)
            m.padding = (dilate, dilate)
| |
|
| |
|
class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling with an image-level pooling branch.

    Four parallel convolutions (one 1x1 plus three dilated 3x3) are
    concatenated, batch-normalized and reduced to ``out_channels``; a global
    average-pooled context branch is then added before the final
    normalization and activation.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 dilation_rates=(12, 24, 36),
                 hidden_channels=256,
                 norm_act=nn.BatchNorm2d,
                 pooling_size=None):
        super(ASPP, self).__init__()
        # NOTE(review): pooling_size is stored but _global_pooling always
        # performs full global average pooling; the local-pooling variant was
        # apparently never implemented here -- confirm before relying on it.
        self.pooling_size = pooling_size

        self.map_convs = nn.ModuleList([
            nn.Conv2d(in_channels, hidden_channels, 1, bias=False),
            nn.Conv2d(in_channels, hidden_channels, 3, bias=False, dilation=dilation_rates[0],
                      padding=dilation_rates[0]),
            nn.Conv2d(in_channels, hidden_channels, 3, bias=False, dilation=dilation_rates[1],
                      padding=dilation_rates[1]),
            nn.Conv2d(in_channels, hidden_channels, 3, bias=False, dilation=dilation_rates[2],
                      padding=dilation_rates[2])
        ])
        self.map_bn = norm_act(hidden_channels * 4)

        self.global_pooling_conv = nn.Conv2d(in_channels, hidden_channels, 1, bias=False)
        self.global_pooling_bn = norm_act(hidden_channels)

        self.red_conv = nn.Conv2d(hidden_channels * 4, out_channels, 1, bias=False)
        self.pool_red_conv = nn.Conv2d(hidden_channels, out_channels, 1, bias=False)
        self.red_bn = norm_act(out_channels)

        self.leak_relu = nn.LeakyReLU()

    def forward(self, x):
        # Parallel atrous branches, concatenated along channels.
        out = torch.cat([m(x) for m in self.map_convs], dim=1)
        out = self.map_bn(out)
        out = self.leak_relu(out)
        out = self.red_conv(out)

        # Image-level context branch: reduces x to (B, C, 1, 1).
        pool = self._global_pooling(x)
        pool = self.global_pooling_conv(pool)
        pool = self.global_pooling_bn(pool)
        pool = self.leak_relu(pool)
        pool = self.pool_red_conv(pool)

        # The original code tiled `pool` to the spatial size of `out` with
        # .repeat(1, 1, H, W); elementwise broadcasting of the (B, C, 1, 1)
        # tensor yields identical sums without materializing the tile.
        out = out + pool
        out = self.red_bn(out)
        out = self.leak_relu(out)
        return out

    def _global_pooling(self, x):
        """Global average pool x to a (B, C, 1, 1) tensor."""
        pool = x.view(x.size(0), x.size(1), -1).mean(dim=-1)
        return pool.view(x.size(0), x.size(1), 1, 1)
| |
|
| |
|
class Head(nn.Module):
    """DeepLabV3+ decoder head.

    Applies ASPP to the deepest backbone features, upsamples the result to
    the low-level feature resolution, concatenates it with a channel-reduced
    low-level map, and refines the fusion with two 3x3 convolutions.
    """

    def __init__(self, base_model, classify_classes, norm_act=nn.BatchNorm2d, bn_momentum=0.0003):
        super(Head, self).__init__()
        self.classify_classes = classify_classes

        # Channel widths depend on the backbone family.
        if base_model == 'r18':
            aspp_in, low_in = 512, 64
        elif base_model == 'r50':
            aspp_in, low_in = 2048, 256
        else:
            raise Exception(f"Head not implemented for {base_model}")

        self.aspp = ASPP(aspp_in, 256, [6, 12, 18], norm_act=norm_act)
        self.reduce = nn.Sequential(
            nn.Conv2d(low_in, 48, 1, bias=False),
            norm_act(48, momentum=bn_momentum),
            nn.ReLU(),
        )

        # 256 (ASPP) + 48 (low-level) = 304 input channels.
        self.last_conv = nn.Sequential(nn.Conv2d(304, 256, kernel_size=3, stride=1, padding=1, bias=False),
                                       norm_act(256, momentum=bn_momentum),
                                       nn.ReLU(),
                                       nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False),
                                       norm_act(256, momentum=bn_momentum),
                                       nn.ReLU(),
                                       )

    def forward(self, f_list):
        """Fuse the last and first feature maps in f_list into a 256-ch map."""
        aspp_out = self.aspp(f_list[-1])

        low = f_list[0]
        target_size = (low.size(2), low.size(3))
        low = self.reduce(low)

        upsampled = F.interpolate(aspp_out, size=target_size, mode='bilinear', align_corners=True)
        fused = torch.cat((upsampled, low), dim=1)
        return self.last_conv(fused)
| | |
| |
|
def group_weight(module, norm_layer):
    """Split a module's parameters into weight-decay and no-decay groups.

    Linear/convolution weights receive weight decay; every bias and all
    normalization-layer parameters are exempt.  Returns
    ``(group_decay, group_no_decay)`` and asserts that together they cover
    every parameter of ``module``.
    """
    group_decay = []
    group_no_decay = []
    conv_like = (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.ConvTranspose2d, nn.ConvTranspose3d)
    for m in module.modules():
        if isinstance(m, (nn.Linear,) + conv_like):
            group_decay.append(m.weight)
            if m.bias is not None:
                group_no_decay.append(m.bias)
        elif isinstance(m, (Conv2_5D_depth, Conv2_5D_disp)):
            # 2.5D convolutions carry three weight tensors.
            group_decay.extend((m.weight_0, m.weight_1, m.weight_2))
            if m.bias is not None:
                group_no_decay.append(m.bias)
        elif isinstance(m, (norm_layer, nn.BatchNorm1d, nn.BatchNorm2d,
                            nn.BatchNorm3d, nn.GroupNorm)):
            if m.weight is not None:
                group_no_decay.append(m.weight)
            if m.bias is not None:
                group_no_decay.append(m.bias)
        elif isinstance(m, nn.Embedding):
            # Bug fix: the embedding's weight tensor (not the module object)
            # must be appended so the optimizer receives a Parameter.
            group_decay.append(m.weight)
        # NOTE: a former `isinstance(m, nn.Parameter)` branch was removed --
        # Module.modules() only yields Modules, so it was unreachable.
    assert len(list(module.parameters())) == len(group_decay) + len(
        group_no_decay)
    return group_decay, group_no_decay
| |
|
def __init_weight(feature, conv_init, norm_layer, bn_eps, bn_momentum,
                  **kwargs):
    """Initialize one module tree: conv weights via conv_init, norm layers
    reset to weight=1 / bias=0 with the given eps and momentum."""
    for _, m in feature.named_modules():
        if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
            conv_init(m.weight, **kwargs)
        elif isinstance(m, (Conv2_5D_depth, Conv2_5D_disp)):
            # 2.5D convolutions carry three weight tensors.
            for w in (m.weight_0, m.weight_1, m.weight_2):
                conv_init(w, **kwargs)
        elif isinstance(m, norm_layer):
            m.eps = bn_eps
            m.momentum = bn_momentum
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)
| |
|
| |
|
def init_weight(module_list, conv_init, norm_layer, bn_eps, bn_momentum,
                **kwargs):
    """Initialize a single module or every module in a list (see __init_weight)."""
    features = module_list if isinstance(module_list, list) else [module_list]
    for feature in features:
        __init_weight(feature, conv_init, norm_layer, bn_eps, bn_momentum,
                      **kwargs)