from contextlib import ExitStack

import torch
import torch.nn.functional as F
from torch import nn

from isegm.model import ops
from .basic_blocks import SeparableConv2d
from .resnet import ResNetBackbone


class DeepLabV3Plus(nn.Module):
    """DeepLabV3+ decoder over a ResNet backbone: ASPP on the deepest
    feature map, fused with a projected stride-4 skip connection."""

    def __init__(
        self,
        backbone="resnet50",
        norm_layer=nn.BatchNorm2d,
        backbone_norm_layer=None,
        ch=256,
        project_dropout=0.5,
        inference_mode=False,
        **kwargs
    ):
        super(DeepLabV3Plus, self).__init__()
        if backbone_norm_layer is None:
            backbone_norm_layer = norm_layer

        self.backbone_name = backbone
        self.norm_layer = norm_layer
        self.backbone_norm_layer = backbone_norm_layer
        self.inference_mode = False
        self.ch = ch
        self.aspp_in_channels = 2048
        self.skip_project_in_channels = 256  # layer 1 out_channels

        self._kwargs = kwargs
        if backbone == "resnet34":
            # resnet34 is shallower: 512-channel c4 and 64-channel c1.
            self.aspp_in_channels = 512
            self.skip_project_in_channels = 64

        self.backbone = ResNetBackbone(
            backbone=self.backbone_name,
            pretrained_base=False,
            norm_layer=self.backbone_norm_layer,
            **kwargs
        )

        self.head = _DeepLabHead(
            in_channels=ch + 32,
            mid_channels=ch,
            out_channels=ch,
            norm_layer=self.norm_layer,
        )
        self.skip_project = _SkipProject(
            self.skip_project_in_channels, 32, norm_layer=self.norm_layer
        )
        self.aspp = _ASPP(
            in_channels=self.aspp_in_channels,
            atrous_rates=[12, 24, 36],
            out_channels=ch,
            project_dropout=project_dropout,
            norm_layer=self.norm_layer,
        )

        if inference_mode:
            self.set_prediction_mode()

    def load_pretrained_weights(self):
        # Build a pretrained twin of the backbone and copy its weights over;
        # keys absent from the pretrained model keep their current values.
        pretrained = ResNetBackbone(
            backbone=self.backbone_name,
            pretrained_base=True,
            norm_layer=self.backbone_norm_layer,
            **self._kwargs
        )
        backbone_state_dict = self.backbone.state_dict()
        pretrained_state_dict = pretrained.state_dict()

        backbone_state_dict.update(pretrained_state_dict)
        self.backbone.load_state_dict(backbone_state_dict)

        if self.inference_mode:
            for param in self.backbone.parameters():
                param.requires_grad = False

    def set_prediction_mode(self):
        self.inference_mode = True
        self.eval()

    def forward(self, x, additional_features=None):
        with ExitStack() as stack:
            if self.inference_mode:
                stack.enter_context(torch.no_grad())

            c1, _, c3, c4 = self.backbone(x, additional_features)
            c1 = self.skip_project(c1)

            # Decoder: upsample the ASPP output to the stride-4 skip and fuse.
            x = self.aspp(c4)
            x = F.interpolate(x, c1.size()[2:], mode="bilinear", align_corners=True)
            x = torch.cat((x, c1), dim=1)
            x = self.head(x)

        return (x,)
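
# A minimal usage sketch (an addition, not part of the original module). It
# assumes the RITM-style ResNetBackbone, whose forward returns four feature
# maps with c1 at 1/4 of the input resolution, so a 320x320 input yields a
# ch=256 output at 80x80:
#
#   model = DeepLabV3Plus(backbone="resnet50", inference_mode=True)
#   model.load_pretrained_weights()
#   out, = model(torch.zeros(1, 3, 320, 320))  # out.shape == (1, 256, 80, 80)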


class _SkipProject(nn.Module):
    def __init__(self, in_channels, out_channels, norm_layer=nn.BatchNorm2d):
        super(_SkipProject, self).__init__()
        _activation = ops.select_activation_function("relu")

        # 1x1 projection of the low-level (stride-4) features before fusion.
        self.skip_project = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
            norm_layer(out_channels),
            _activation(),
        )

    def forward(self, x):
        return self.skip_project(x)


class _DeepLabHead(nn.Module):
    def __init__(
        self, out_channels, in_channels, mid_channels=256, norm_layer=nn.BatchNorm2d
    ):
        super(_DeepLabHead, self).__init__()

        # Two depthwise-separable 3x3 convs followed by a 1x1 output conv.
        self.block = nn.Sequential(
            SeparableConv2d(
                in_channels=in_channels,
                out_channels=mid_channels,
                dw_kernel=3,
                dw_padding=1,
                activation="relu",
                norm_layer=norm_layer,
            ),
            SeparableConv2d(
                in_channels=mid_channels,
                out_channels=mid_channels,
                dw_kernel=3,
                dw_padding=1,
                activation="relu",
                norm_layer=norm_layer,
            ),
            nn.Conv2d(
                in_channels=mid_channels, out_channels=out_channels, kernel_size=1
            ),
        )

    def forward(self, x):
        return self.block(x)


class _ASPP(nn.Module):
    def __init__(
        self,
        in_channels,
        atrous_rates,
        out_channels=256,
        project_dropout=0.5,
        norm_layer=nn.BatchNorm2d,
    ):
        super(_ASPP, self).__init__()

        # Branch 0: plain 1x1 convolution.
        b0 = nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1,
                bias=False,
            ),
            norm_layer(out_channels),
            nn.ReLU(),
        )

        # Branches 1-3: 3x3 convolutions with increasing dilation rates.
        rate1, rate2, rate3 = tuple(atrous_rates)
        b1 = _ASPPConv(in_channels, out_channels, rate1, norm_layer)
        b2 = _ASPPConv(in_channels, out_channels, rate2, norm_layer)
        b3 = _ASPPConv(in_channels, out_channels, rate3, norm_layer)

        # Branch 4: global average pooling for image-level context.
        b4 = _AsppPooling(in_channels, out_channels, norm_layer=norm_layer)

        self.concurent = nn.ModuleList([b0, b1, b2, b3, b4])

        # Fuse the concatenated branches back down to out_channels.
        project = [
            nn.Conv2d(
                in_channels=5 * out_channels,
                out_channels=out_channels,
                kernel_size=1,
                bias=False,
            ),
            norm_layer(out_channels),
            nn.ReLU(),
        ]
        if project_dropout > 0:
            project.append(nn.Dropout(project_dropout))
        self.project = nn.Sequential(*project)

    def forward(self, x):
        x = torch.cat([block(x) for block in self.concurent], dim=1)
        return self.project(x)
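
# Shape sketch for _ASPP (illustrative, not part of the original file): every
# branch preserves the spatial size and emits out_channels maps, so the
# concatenation has 5 * out_channels channels before the 1x1 projection:
#
#   aspp = _ASPP(in_channels=2048, atrous_rates=[12, 24, 36], out_channels=256)
#   y = aspp(torch.zeros(1, 2048, 20, 20))  # y.shape == (1, 256, 20, 20)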


class _AsppPooling(nn.Module):
    def __init__(self, in_channels, out_channels, norm_layer):
        super(_AsppPooling, self).__init__()

        self.gap = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1,
                bias=False,
            ),
            norm_layer(out_channels),
            nn.ReLU(),
        )

    def forward(self, x):
        # Pool to 1x1, project, then broadcast back to the input resolution.
        pool = self.gap(x)
        return F.interpolate(pool, x.size()[2:], mode="bilinear", align_corners=True)


def _ASPPConv(in_channels, out_channels, atrous_rate, norm_layer):
    # padding == dilation keeps the spatial size for a 3x3 kernel.
    block = nn.Sequential(
        nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            padding=atrous_rate,
            dilation=atrous_rate,
            bias=False,
        ),
        norm_layer(out_channels),
        nn.ReLU(),
    )
    return block
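

if __name__ == "__main__":
    # Smoke test (an addition, not in the original file). The expected shape
    # assumes the RITM ResNetBackbone, which returns (c1, c2, c3, c4) with c1
    # at 1/4 of the input resolution; no pretrained weights are needed here.
    model = DeepLabV3Plus(backbone="resnet50")
    model.set_prediction_mode()
    with torch.no_grad():
        (out,) = model(torch.zeros(1, 3, 320, 320))
    print(out.shape)  # expected: torch.Size([1, 256, 80, 80])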