| | import torch |
| | from torch import nn |
| | from torch.nn import functional as F |
| |
|
| | from .utils import _SimpleSegmentationModel |
| | from ._deeplab import ASPPConv, ASPPPooling, ASPP, AtrousSeparableConvolution |
| | from .enhanced_modules import EOANetModule |
| |
|
| |
|
class EnhancedDeepLabV3(_SimpleSegmentationModel):
    """Enhanced DeepLabV3 segmentation model.

    Combines the DeepLabV3 pipeline with Normalized Multi-Scale Attention and
    Entropy-Optimized Gating. The class body adds nothing of its own: all
    behaviour is inherited from ``_SimpleSegmentationModel``.
    """
| |
|
| |
|
class EnhancedDeepLabHeadV3Plus(nn.Module):
    """DeepLabV3+ style decoder head with an optional EOANet stage.

    The head projects the low-level backbone feature to 48 channels, runs ASPP
    (optionally followed by ``EOANetModule``) on the high-level feature,
    upsamples that result to the low-level resolution, concatenates both and
    classifies every pixel.

    Args:
        in_channels: Channel count of ``feature['out']`` fed to ASPP.
        low_level_channels: Channel count of ``feature['low_level']``.
        num_classes: Number of output channels of the final 1x1 convolution.
        aspp_dilate: Dilation rates handed to ``ASPP``. ``None`` selects
            ``[12, 24, 36]``.
        use_eoaNet: When true, an ``EOANetModule`` is applied to the ASPP
            output. When false, no ``eoaNet`` attribute is created.
        msa_scales: Forwarded to ``EOANetModule`` as ``scales``. ``None``
            selects ``[1, 2, 4]``.
        eog_beta: Forwarded to ``EOANetModule`` as ``beta``.
    """

    def __init__(self, in_channels, low_level_channels, num_classes, aspp_dilate=None,
                 use_eoaNet=True, msa_scales=None, eog_beta=0.5):
        super().__init__()
        # Build the default lists per call so no list object is shared between
        # instances (the classic mutable-default-argument pitfall).
        if aspp_dilate is None:
            aspp_dilate = [12, 24, 36]
        if msa_scales is None:
            msa_scales = [1, 2, 4]
        self.use_eoaNet = use_eoaNet

        low_level_width = 48
        # The classifier below only works if ASPP emits this many channels.
        aspp_width = 256

        self.project = nn.Sequential(
            nn.Conv2d(low_level_channels, low_level_width, 1, bias=False),
            nn.BatchNorm2d(low_level_width),
            nn.ReLU(inplace=True),
        )

        self.aspp = ASPP(in_channels, aspp_dilate)

        if self.use_eoaNet:
            self.eoaNet = EOANetModule(aspp_width, scales=msa_scales, beta=eog_beta)

        self.classifier = nn.Sequential(
            # Input is the channel-wise concat of both branches (48 + 256 = 304).
            nn.Conv2d(low_level_width + aspp_width, 256, 3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, num_classes, 1),
        )
        self._init_weight()

    def forward(self, feature):
        """Compute per-pixel class scores.

        Args:
            feature: Mapping with the keys ``'low_level'`` and ``'out'``.

        Returns:
            The classifier output at the spatial size of the projected
            low-level feature.
        """
        low_level_feature = self.project(feature['low_level'])
        output_feature = self.aspp(feature['out'])

        if self.use_eoaNet:
            output_feature = self.eoaNet(output_feature)

        # Bring the high-level branch to the low-level resolution before concat.
        output_feature = F.interpolate(
            output_feature,
            size=low_level_feature.shape[2:],
            mode='bilinear',
            align_corners=False,
        )
        return self.classifier(torch.cat([low_level_feature, output_feature], dim=1))

    def _init_weight(self):
        """Re-initialise every conv and norm layer found in ``self.modules()``.

        Conv2d weights get ``kaiming_normal_`` with torch defaults (conv biases
        are left untouched). BatchNorm2d / GroupNorm layers get weight 1 and
        bias 0.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                # Norm layers built with affine=False have weight/bias set to
                # None; skip them instead of crashing inside nn.init.
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
| |
|
| |
|
class EnhancedDeepLabHead(nn.Module):
    """DeepLabV3 style head: ASPP, optional EOANet stage, then a classifier.

    Unlike the V3+ variant there is no low-level branch and no upsampling; the
    output keeps the spatial size produced by the ASPP / EOANet stage.

    Args:
        in_channels: Channel count of ``feature['out']`` fed to ASPP.
        num_classes: Number of output channels of the final 1x1 convolution.
        aspp_dilate: Dilation rates handed to ``ASPP``. ``None`` selects
            ``[12, 24, 36]``.
        use_eoaNet: When true, an ``EOANetModule`` is applied to the ASPP
            output. When false, no ``eoaNet`` attribute is created.
        msa_scales: Forwarded to ``EOANetModule`` as ``scales``. ``None``
            selects ``[1, 2, 4]``.
        eog_beta: Forwarded to ``EOANetModule`` as ``beta``.
    """

    def __init__(self, in_channels, num_classes, aspp_dilate=None,
                 use_eoaNet=True, msa_scales=None, eog_beta=0.5):
        super().__init__()
        # Build the default lists per call so no list object is shared between
        # instances (the classic mutable-default-argument pitfall).
        if aspp_dilate is None:
            aspp_dilate = [12, 24, 36]
        if msa_scales is None:
            msa_scales = [1, 2, 4]
        self.use_eoaNet = use_eoaNet

        # The classifier below only works if ASPP emits this many channels.
        aspp_width = 256

        self.aspp = ASPP(in_channels, aspp_dilate)

        if self.use_eoaNet:
            self.eoaNet = EOANetModule(aspp_width, scales=msa_scales, beta=eog_beta)

        self.classifier = nn.Sequential(
            nn.Conv2d(aspp_width, 256, 3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, num_classes, 1),
        )
        self._init_weight()

    def forward(self, feature):
        """Compute per-pixel class scores.

        Args:
            feature: Mapping with the key ``'out'``.

        Returns:
            The classifier output for the (optionally EOANet-refined) ASPP
            features.
        """
        output = self.aspp(feature['out'])

        if self.use_eoaNet:
            output = self.eoaNet(output)

        return self.classifier(output)

    def _init_weight(self):
        """Re-initialise every conv and norm layer found in ``self.modules()``.

        Conv2d weights get ``kaiming_normal_`` with torch defaults (conv biases
        are left untouched). BatchNorm2d / GroupNorm layers get weight 1 and
        bias 0.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                # Norm layers built with affine=False have weight/bias set to
                # None; skip them instead of crashing inside nn.init.
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
| |
|
| |
|
def convert_to_separable_conv(module):
    """Recursively swap spatial ``nn.Conv2d`` layers for separable ones.

    A ``nn.Conv2d`` whose first kernel dimension is greater than 1 is replaced
    by a new ``AtrousSeparableConvolution`` built from its channel counts,
    kernel size, stride, padding and dilation. Every other module is kept and
    its children are converted in place.

    Args:
        module: Root ``nn.Module`` to convert.

    Returns:
        ``module`` itself (with converted children), or a new
        ``AtrousSeparableConvolution`` when ``module`` was a convertible conv.
        The replacement is freshly constructed; the original weights are not
        copied over.
    """
    converted = module
    if isinstance(module, nn.Conv2d) and module.kernel_size[0] > 1:
        # NOTE(review): groups and padding_mode are not forwarded - confirm
        # that grouped convolutions never reach this function.
        converted = AtrousSeparableConvolution(
            module.in_channels,
            module.out_channels,
            module.kernel_size,
            module.stride,
            module.padding,
            module.dilation,
            # Pass a flag, not the bias Parameter: truth-testing a multi-element
            # tensor raises, so convs with a bias could not be converted before.
            module.bias is not None,
        )
    for name, child in module.named_children():
        converted.add_module(name, convert_to_separable_conv(child))
    return converted