Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn.functional as F | |
| import torch.nn as nn | |
| import math | |
| from model.deep_lab_model.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d | |
| import torch.utils.model_zoo as model_zoo | |
| def conv_bn(inp, oup, stride, BatchNorm): | |
| return nn.Sequential( | |
| nn.Conv2d(inp, oup, 3, stride, 1, bias=False), | |
| BatchNorm(oup), | |
| nn.ReLU6(inplace=True) | |
| ) | |
| def fixed_padding(inputs, kernel_size, dilation): | |
| kernel_size_effective = kernel_size + (kernel_size - 1) * (dilation - 1) | |
| pad_total = kernel_size_effective - 1 | |
| pad_beg = pad_total // 2 | |
| pad_end = pad_total - pad_beg | |
| padded_inputs = F.pad(inputs, (pad_beg, pad_end, pad_beg, pad_end)) | |
| return padded_inputs | |
| class InvertedResidual(nn.Module): | |
| def __init__(self, inp, oup, stride, dilation, expand_ratio, BatchNorm): | |
| super(InvertedResidual, self).__init__() | |
| self.stride = stride | |
| assert stride in [1, 2] | |
| hidden_dim = round(inp * expand_ratio) | |
| self.use_res_connect = self.stride == 1 and inp == oup | |
| self.kernel_size = 3 | |
| self.dilation = dilation | |
| if expand_ratio == 1: | |
| self.conv = nn.Sequential( | |
| # dw | |
| nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 0, dilation, groups=hidden_dim, bias=False), | |
| BatchNorm(hidden_dim), | |
| nn.ReLU6(inplace=True), | |
| # pw-linear | |
| nn.Conv2d(hidden_dim, oup, 1, 1, 0, 1, 1, bias=False), | |
| BatchNorm(oup), | |
| ) | |
| else: | |
| self.conv = nn.Sequential( | |
| # pw | |
| nn.Conv2d(inp, hidden_dim, 1, 1, 0, 1, bias=False), | |
| BatchNorm(hidden_dim), | |
| nn.ReLU6(inplace=True), | |
| # dw | |
| nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 0, dilation, groups=hidden_dim, bias=False), | |
| BatchNorm(hidden_dim), | |
| nn.ReLU6(inplace=True), | |
| # pw-linear | |
| nn.Conv2d(hidden_dim, oup, 1, 1, 0, 1, bias=False), | |
| BatchNorm(oup), | |
| ) | |
| def forward(self, x): | |
| x_pad = fixed_padding(x, self.kernel_size, dilation=self.dilation) | |
| if self.use_res_connect: | |
| x = x + self.conv(x_pad) | |
| else: | |
| x = self.conv(x_pad) | |
| return x | |
| class MobileNetV2(nn.Module): | |
| def __init__(self, output_stride=8, BatchNorm=None, width_mult=1., pretrained=True): | |
| super(MobileNetV2, self).__init__() | |
| block = InvertedResidual | |
| input_channel = 32 | |
| current_stride = 1 | |
| rate = 1 | |
| interverted_residual_setting = [ | |
| # t, c, n, s | |
| [1, 16, 1, 1], | |
| [6, 24, 2, 2], | |
| [6, 32, 3, 2], | |
| [6, 64, 4, 2], | |
| [6, 96, 3, 1], | |
| [6, 160, 3, 2], | |
| [6, 320, 1, 1], | |
| ] | |
| # building first layer | |
| input_channel = int(input_channel * width_mult) | |
| self.features = [conv_bn(3, input_channel, 2, BatchNorm)] | |
| current_stride *= 2 | |
| # building inverted residual blocks | |
| for t, c, n, s in interverted_residual_setting: | |
| if current_stride == output_stride: | |
| stride = 1 | |
| dilation = rate | |
| rate *= s | |
| else: | |
| stride = s | |
| dilation = 1 | |
| current_stride *= s | |
| output_channel = int(c * width_mult) | |
| for i in range(n): | |
| if i == 0: | |
| self.features.append(block(input_channel, output_channel, stride, dilation, t, BatchNorm)) | |
| else: | |
| self.features.append(block(input_channel, output_channel, 1, dilation, t, BatchNorm)) | |
| input_channel = output_channel | |
| self.features = nn.Sequential(*self.features) | |
| self._initialize_weights() | |
| if pretrained: | |
| self._load_pretrained_model() | |
| self.low_level_features = self.features[0:4] | |
| self.high_level_features = self.features[4:] | |
| def forward(self, x): | |
| low_level_feat = self.low_level_features(x) | |
| x = self.high_level_features(low_level_feat) | |
| return x, low_level_feat | |
| def _load_pretrained_model(self): | |
| pretrain_dict = model_zoo.load_url('http://jeff95.me/models/mobilenet_v2-6a65762b.pth') | |
| model_dict = {} | |
| state_dict = self.state_dict() | |
| for k, v in pretrain_dict.items(): | |
| if k in state_dict: | |
| model_dict[k] = v | |
| state_dict.update(model_dict) | |
| self.load_state_dict(state_dict) | |
| def _initialize_weights(self): | |
| for m in self.modules(): | |
| if isinstance(m, nn.Conv2d): | |
| # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels | |
| # m.weight.data.normal_(0, math.sqrt(2. / n)) | |
| torch.nn.init.kaiming_normal_(m.weight) | |
| elif isinstance(m, SynchronizedBatchNorm2d): | |
| m.weight.data.fill_(1) | |
| m.bias.data.zero_() | |
| elif isinstance(m, nn.BatchNorm2d): | |
| m.weight.data.fill_(1) | |
| m.bias.data.zero_() | |
| if __name__ == "__main__": | |
| input = torch.rand(1, 3, 512, 512) | |
| model = MobileNetV2(output_stride=16, BatchNorm=nn.BatchNorm2d) | |
| output, low_level_feat = model(input) | |
| print(output.size()) | |
| print(low_level_feat.size()) | |