| """ |
| @author: |
| @Date: 2021/07/17 |
| @description: Use the feature extractor proposed by HorizonNet |
| """ |
|
|
| import numpy as np |
| import math |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import torchvision.models as models |
| import functools |
| from models.base_model import BaseModule |
|
|
# Backbone names accepted by the `Resnet` encoder (it asserts membership and
# looks the name up on `torchvision.models`).
ENCODER_RESNET = [
    'resnet18',
    'resnet34',
    'resnet50',
    'resnet101',
    'resnet152',
    'resnext50_32x4d',
    'resnext101_32x8d',
]

# Backbone names accepted by the `Densenet` encoder.
ENCODER_DENSENET = [
    'densenet121',
    'densenet169',
    'densenet161',
    'densenet201',
]
|
|
|
|
def lr_pad(x, padding=1):
    """Wrap-pad the last axis of ``x`` instead of zero padding it.

    The right-most ``padding`` columns are prepended and the left-most
    ``padding`` columns are appended, so the left and right borders become
    neighbours of each other.

    Args:
        x: Tensor to pad along its last axis (4-D feature maps in this file).
        padding: Number of columns copied to each side. ``0`` is a no-op.

    Returns:
        The padded tensor; for ``padding == 0`` the input itself is returned.
    """
    if padding == 0:
        # ``x[..., -0:]`` is the *whole* axis, so without this guard a zero
        # padding would duplicate the input instead of leaving it unchanged.
        return x
    # Concatenate on the same (last) axis that the slices address; for 4-D
    # input this is identical to ``dim=3``.
    return torch.cat([x[..., -padding:], x, x[..., :padding]], dim=-1)
|
|
|
|
class LR_PAD(nn.Module):
    """Module wrapper around ``lr_pad`` so wrap padding can sit in a layer stack."""

    def __init__(self, padding=1):
        super().__init__()
        # Number of columns wrapped around on each side of the input.
        self.padding = padding

    def forward(self, x):
        width = self.padding
        return lr_pad(x, width)
|
|
|
|
def wrap_lr_pad(net):
    """Swap horizontal zero padding for wrap padding in every ``nn.Conv2d`` of ``net``.

    Each convolution whose width padding is non-zero has that padding set to
    0 and is replaced, in its parent module, by
    ``nn.Sequential(LR_PAD(w_pad), conv)``. Vertical padding is left alone.
    ``net`` is modified in place; nothing is returned.

    Note that the replaced convolution becomes child ``'1'`` of the new
    ``nn.Sequential``, so its parameter names gain an extra ``.1`` component.

    Args:
        net: Root module whose sub-modules are rewritten.
    """
    # Snapshot the (name, module) pairs first: the loop rewires the module
    # tree, which should not happen underneath a live ``named_modules()``
    # generator.
    for name, m in list(net.named_modules()):
        if not isinstance(m, nn.Conv2d):
            continue
        if not name:
            # ``net`` itself is the convolution. It has no parent to rebind it
            # in, so leave it untouched rather than stripping its padding.
            continue
        if m.padding[1] == 0:
            continue
        w_pad = int(m.padding[1])
        m.padding = (m.padding[0], 0)

        # Walk the dotted path down to the direct parent of the convolution.
        *parents, leaf = name.split('.')
        root = functools.reduce(getattr, parents, net)
        setattr(root, leaf, nn.Sequential(LR_PAD(w_pad), m))
|
|
|
|
| ''' |
| Encoder |
| ''' |
|
|
|
|
class Resnet(nn.Module):
    """torchvision ResNet/ResNeXt trunk that returns the outputs of its four stages."""

    def __init__(self, backbone='resnet50', pretrained=True):
        super().__init__()
        assert backbone in ENCODER_RESNET
        build = getattr(models, backbone)
        self.encoder = build(pretrained=pretrained)
        # The classification head is not used by ``forward``.
        del self.encoder.fc, self.encoder.avgpool

    def forward(self, x):
        """Return ``[layer1, layer2, layer3, layer4]`` outputs for input ``x``."""
        enc = self.encoder
        # Stem: conv -> bn -> relu -> maxpool.
        x = enc.maxpool(enc.relu(enc.bn1(enc.conv1(x))))

        features = []
        for stage in (enc.layer1, enc.layer2, enc.layer3, enc.layer4):
            x = stage(x)
            features.append(x)
        return features

    def list_blocks(self):
        """Group the encoder's children into (stem, layer1, ..., layer4) lists."""
        children = list(self.encoder.children())
        stem = children[:4]
        stages = [children[i:i + 1] for i in range(4, 8)]
        return (stem, *stages)
|
|
|
|
class Densenet(nn.Module):
    """torchvision DenseNet trunk that returns four intermediate feature maps."""

    def __init__(self, backbone='densenet169', pretrained=True):
        super().__init__()
        assert backbone in ENCODER_DENSENET
        build = getattr(models, backbone)
        self.encoder = build(pretrained=pretrained)
        self.final_relu = nn.ReLU(inplace=True)
        # The classifier is not used by ``forward``.
        del self.encoder.classifier

    def forward(self, x):
        """Run ``encoder.features`` child by child and pick outputs 4, 6, 8 and 11.

        Presumably these are the three first dense blocks and the final norm
        in torchvision's layout; the last one is passed through a ReLU.
        """
        outputs = []
        for layer in self.encoder.features.children():
            x = layer(x)
            outputs.append(x)
        last = self.final_relu(outputs[11])
        return [outputs[4], outputs[6], outputs[8], last]

    def list_blocks(self):
        """Split the children of ``encoder.features`` into five consecutive groups."""
        layers = list(self.encoder.features.children())
        # Slice boundaries: [0:4], [4:6], [6:8], [8:10], [10:].
        bounds = (0, 4, 6, 8, 10, None)
        return tuple(layers[lo:hi] for lo, hi in zip(bounds, bounds[1:]))
|
|
|
|
| ''' |
| Decoder |
| ''' |
|
|
|
|
class ConvCompressH(nn.Module):
    """Conv-BN-ReLU unit that halves the feature height (stride 2 vertically, 1 horizontally)."""

    def __init__(self, in_c, out_c, ks=3):
        super().__init__()
        # An odd kernel is required so that ``ks // 2`` padding is symmetric.
        assert ks % 2 == 1
        conv = nn.Conv2d(in_c, out_c, kernel_size=ks, stride=(2, 1), padding=ks // 2)
        norm = nn.BatchNorm2d(out_c)
        act = nn.ReLU(inplace=True)
        self.layers = nn.Sequential(conv, norm, act)

    def forward(self, x):
        out = self.layers(x)
        return out
|
|
|
|
class GlobalHeightConv(nn.Module):
    """Compress a feature map along its height and resample it to a target width.

    Four ``ConvCompressH`` units are stacked, taking the channels
    ``in_c -> in_c // 2 -> in_c // 2 -> in_c // 4 -> out_c``.
    """

    def __init__(self, in_c, out_c):
        super().__init__()
        self.layer = nn.Sequential(
            ConvCompressH(in_c, in_c // 2),
            ConvCompressH(in_c // 2, in_c // 2),
            ConvCompressH(in_c // 2, in_c // 4),
            ConvCompressH(in_c // 4, out_c),
        )

    def forward(self, x, out_w):
        """Apply the compression stack and resize the result to width ``out_w``.

        Args:
            x: Input feature map, indexed as 4-D (width on axis 3).
            out_w: Width of the returned feature map.

        Returns:
            Tensor with the compressed height and a last axis of size ``out_w``.
        """
        x = self.layer(x)
        height, width = x.shape[2], x.shape[3]

        factor = out_w // width
        if factor == 0:
            # Target is narrower than the feature map. There are no border
            # columns to trim (``x[..., 0:-0]`` would be empty), so resize
            # directly.
            return F.interpolate(x, size=(height, out_w), mode='bilinear', align_corners=False)

        # Wrap one column around on each side so the interpolation at the
        # left/right borders blends with the opposite border, then cut off
        # the ``factor`` extra columns this adds on each side.
        x = torch.cat([x[..., -1:], x, x[..., :1]], 3)
        x = F.interpolate(x, size=(height, out_w + 2 * factor), mode='bilinear', align_corners=False)
        return x[..., factor:-factor]
|
|
|
|
class GlobalHeightStage(nn.Module):
    """Fuse four encoder feature maps into one column-wise feature tensor."""

    def __init__(self, c1, c2, c3, c4, out_scale=8):
        """Build one ``GlobalHeightConv`` per encoder block.

        Each block with ``c`` input channels is reduced to ``c // out_scale``
        channels.
        """
        super().__init__()
        self.cs = c1, c2, c3, c4
        self.out_scale = out_scale
        self.ghc_lst = nn.ModuleList(
            [GlobalHeightConv(c, c // out_scale) for c in (c1, c2, c3, c4)]
        )

    def forward(self, conv_list, out_w):
        """Reduce each of the four feature maps and stack them per column.

        Every reduced map is flattened to ``(batch, channels * height, out_w)``
        and the four results are concatenated along axis 1.
        """
        assert len(conv_list) == 4
        bs = conv_list[0].shape[0]

        columns = []
        for ghc, feat in zip(self.ghc_lst, conv_list):
            reduced = ghc(feat, out_w)
            columns.append(reduced.reshape(bs, -1, out_w))
        return torch.cat(columns, dim=1)
|
|
|
|
class HorizonNetFeatureExtractor(nn.Module):
    """Backbone plus height compression that yields one feature vector per image column group.

    The input is normalised, passed through a ResNet/DenseNet encoder, and the
    four encoder feature maps are fused by ``GlobalHeightStage``. All padded
    ``nn.Conv2d`` layers are rewritten by ``wrap_lr_pad`` so the left and right
    image borders are treated as adjacent.
    """

    # Per-channel mean / std applied to the first three input channels,
    # shaped (1, 3, 1, 1) for broadcasting. The values match the commonly used
    # ImageNet statistics, which presumably means inputs are expected in
    # [0, 1] -- confirm against the data loader.
    x_mean = torch.FloatTensor(np.array([0.485, 0.456, 0.406])[None, :, None, None])
    x_std = torch.FloatTensor(np.array([0.229, 0.224, 0.225])[None, :, None, None])

    def __init__(self, backbone='resnet50', pretrained=True):
        """Build the encoder and the height-reduction stage.

        Args:
            backbone: Encoder name; must start with ``'res'`` or ``'dense'``.
            pretrained: Forwarded to the encoder constructor (default ``True``,
                the value that was previously hard-coded).

        Raises:
            NotImplementedError: If ``backbone`` has an unsupported prefix.
        """
        super().__init__()
        self.out_scale = 8
        self.step_cols = 4

        # Encoder
        if backbone.startswith('res'):
            self.feature_extractor = Resnet(backbone, pretrained=pretrained)
        elif backbone.startswith('dense'):
            self.feature_extractor = Densenet(backbone, pretrained=pretrained)
        else:
            raise NotImplementedError('Unsupported backbone: {!r}'.format(backbone))

        # Probe the channel count of each encoder block with a dummy forward.
        # Run it in eval mode: in training mode the BatchNorm layers would fold
        # the all-zero dummy image into their running statistics.
        self.feature_extractor.eval()
        with torch.no_grad():
            dummy = torch.zeros(1, 3, 512, 1024)
            c1, c2, c3, c4 = [b.shape[1] for b in self.feature_extractor(dummy)]
        self.feature_extractor.train()
        # Presumably the channel count of the fused feature for 512-pixel-high
        # inputs -- confirm against the consumer of ``c_last``.
        self.c_last = (c1 * 8 + c2 * 4 + c3 * 2 + c4 * 1) // self.out_scale

        # Height compression of the four encoder feature maps.
        self.reduce_height_module = GlobalHeightStage(c1, c2, c3, c4, self.out_scale)
        self.x_mean.requires_grad = False
        self.x_std.requires_grad = False
        wrap_lr_pad(self)

    def _prepare_x(self, x):
        """Return a normalised copy of ``x``; only channels ``0..2`` are changed."""
        x = x.clone()
        if self.x_mean.device != x.device:
            # Cache the statistics on the input's device for later calls.
            self.x_mean = self.x_mean.to(x.device)
            self.x_std = self.x_std.to(x.device)
        x[:, :3] = (x[:, :3] - self.x_mean) / self.x_std

        return x

    def forward(self, x):
        """Extract fused features with width ``x.shape[3] // step_cols``."""
        x = self._prepare_x(x)
        conv_list = self.feature_extractor(x)
        x = self.reduce_height_module(conv_list, x.shape[3] // self.step_cols)
        return x
|
|
|
|
if __name__ == '__main__':
    # Smoke test: run the extractor on a demo image and print the feature shape.
    from PIL import Image

    extractor = HorizonNetFeatureExtractor()
    with Image.open("../../src/demo.png") as im:
        # Force three channels (a PNG may be RGBA or grayscale) and scale to
        # [0, 1], the range the normalisation constants presumably expect.
        img = np.asarray(im.convert("RGB"), dtype=np.float32) / 255.0
    # HWC -> CHW, then add the batch axis. The name avoids shadowing the
    # builtin ``input``.
    batch = torch.from_numpy(np.ascontiguousarray(img.transpose((2, 0, 1)))[None])
    with torch.no_grad():
        feature = extractor(batch)
    print(feature.shape)
|
|