vrevar
Add application file
04c78c7
# Adapted from monodepth2
# https://github.com/nianticlabs/monodepth2/blob/master/networks/depth_decoder.py
#
# Copyright Niantic 2019. Patent Pending. All rights reserved.
#
# This software is licensed under the terms of the Monodepth2 licence
# which allows for non-commercial use only, the full terms of which are made
# available in the LICENSE file.
from __future__ import absolute_import, division, print_function
from collections import OrderedDict
from easydict import EasyDict
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import torch.utils.model_zoo as model_zoo
class ConvBlock(torch.nn.Module):
"""Layer to perform a convolution followed by ELU."""
def __init__(self, in_channels, out_channels, bn=False, dropout=0.0):
super(ConvBlock, self).__init__()
self.block = nn.Sequential(
Conv3x3(in_channels, out_channels),
nn.BatchNorm2d(out_channels) if bn else nn.Identity(),
nn.ELU(inplace=True),
# Pay attention: 2d version of dropout is used
nn.Dropout2d(dropout) if dropout > 0 else nn.Identity())
def forward(self, x):
out = self.block(x)
return out
class Conv3x3(nn.Module):
"""Layer to pad and convolve input with 3x3 kernels."""
def __init__(self, in_channels, out_channels, use_refl=True):
super(Conv3x3, self).__init__()
if use_refl:
self.pad = nn.ReflectionPad2d(1)
else:
self.pad = nn.ZeroPad2d(1)
self.conv = nn.Conv2d(int(in_channels), int(out_channels), 3)
def forward(self, x):
out = self.pad(x)
out = self.conv(out)
return out
def upsample(x):
"""Upsample input tensor by a factor of 2."""
return F.interpolate(x, scale_factor=2, mode="nearest")
class ResNetMultiImageInput(models.ResNet):
"""Constructs a resnet model with varying number of input images.
Adapted from https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
"""
def __init__(self, block, layers, num_classes=1000, in_channels=3):
super(ResNetMultiImageInput, self).__init__(block, layers)
self.inplanes = 64
self.conv1 = nn.Conv2d(
in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
self.bn1 = nn.BatchNorm2d(64)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, 64, layers[0])
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
def resnet_multiimage_input(num_layers, pretrained=False, in_channels=3):
"""Constructs a ResNet model.
Args:
num_layers (int): Number of resnet layers. Must be 18 or 50
pretrained (bool): If True, returns a model pre-trained on ImageNet
in_channels (int): Number of input channels
"""
assert num_layers in [18, 50], "Can only run with 18 or 50 layer resnet"
blocks = {18: [2, 2, 2, 2], 50: [3, 4, 6, 3]}[num_layers]
block_type = {18: models.resnet.BasicBlock, 50: models.resnet.Bottleneck}[num_layers]
model = ResNetMultiImageInput(block_type, blocks, in_channels=in_channels)
if pretrained:
print('loading imagnet weights on resnet...')
loaded = model_zoo.load_url(models.resnet.model_urls['resnet{}'.format(num_layers)])
# loaded['conv1.weight'] = torch.cat(
# (loaded['conv1.weight'], loaded['conv1.weight']), 1)
# diff = model.load_state_dict(loaded, strict=False)
return model
class ResnetEncoder(nn.Module):
"""Pytorch module for a resnet encoder
"""
def __init__(self, num_layers, pretrained, in_channels=3):
super(ResnetEncoder, self).__init__()
self.num_ch_enc = np.array([64, 64, 128, 256, 512])
resnets = {18: models.resnet18,
34: models.resnet34,
50: models.resnet50,
101: models.resnet101,
152: models.resnet152}
if num_layers not in resnets:
raise ValueError("{} is not a valid number of resnet layers".format(num_layers))
if in_channels > 3:
self.encoder = resnet_multiimage_input(num_layers, pretrained, in_channels)
else:
weights = models.ResNet101_Weights.IMAGENET1K_V1 if pretrained else None
self.encoder = resnets[num_layers](weights=weights)
if num_layers > 34:
self.num_ch_enc[1:] *= 4
def forward(self, x):
self.features = []
# input_image, normals = xx
# x = (input_image - 0.45) / 0.225
# x = torch.cat((input_image, normals),1)
x = self.encoder.conv1(x)
x = self.encoder.bn1(x)
self.features.append(self.encoder.relu(x))
self.features.append(self.encoder.layer1(self.encoder.maxpool(self.features[-1])))
self.features.append(self.encoder.layer2(self.features[-1]))
self.features.append(self.encoder.layer3(self.features[-1]))
self.features.append(self.encoder.layer4(self.features[-1]))
return self.features
class Decoder(nn.Module):
def __init__(self, num_ch_enc, scales=range(4), num_output_channels=1, use_skips=True,
kaiming_init=False, return_feats=False):
super().__init__()
self.num_output_channels = num_output_channels
self.use_skips = use_skips
self.upsample_mode = 'nearest'
self.scales = scales
self.return_feats = return_feats
self.num_ch_enc = num_ch_enc
self.num_ch_dec = np.array([16, 32, 64, 128, 256])
# decoder
self.convs = OrderedDict()
for i in range(4, -1, -1):
# upconv_0
num_ch_in = self.num_ch_enc[-1] if i == 4 else self.num_ch_dec[i + 1]
num_ch_out = self.num_ch_dec[i]
self.convs[("upconv", i, 0)] = ConvBlock(num_ch_in, num_ch_out)
# upconv_1
num_ch_in = self.num_ch_dec[i]
if self.use_skips and i > 0:
num_ch_in += self.num_ch_enc[i - 1]
num_ch_out = self.num_ch_dec[i]
self.convs[("upconv", i, 1)] = ConvBlock(num_ch_in, num_ch_out)
# for s in self.scales:
self.convs[("dispconv", 0)] = Conv3x3(self.num_ch_dec[0], self.num_output_channels)
self.decoder = nn.ModuleList(list(self.convs.values()))
# self.sigmoid = nn.Sigmoid()
if kaiming_init:
print('init weights of decoder')
for m in self.children():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight)
if m.bias is not None:
m.bias.data.fill_(0.01)
def forward(self, input_features):
x = input_features[-1]
for i in range(4, -1, -1):
x = self.convs[("upconv", i, 0)](x)
x = [upsample(x)]
if self.use_skips and i > 0:
x += [input_features[i - 1]]
x = torch.cat(x, 1)
x = self.convs[("upconv", i, 1)](x)
# assert self.scales[0] == 0
final_conv = self.convs[("dispconv", 0)]
out = final_conv(x)
if self.return_feats:
return out, input_features[-1]
return out
class MultiHeadDecoder(nn.Module):
def __init__(self, num_ch_enc, tasks, return_feats, use_skips):
super().__init__()
self.decoders = nn.ModuleDict({k:
Decoder(num_ch_enc=num_ch_enc,
num_output_channels=num_ch,
scales=[0],
kaiming_init=False,
use_skips=use_skips,
return_feats=return_feats)
for k, num_ch in tasks.items()})
def forward(self, x):
y = EasyDict({k: v(x) for k, v in self.decoders.items()})
return y
class DenseMTL(nn.Module):
def __init__(self, encoder, decoder):
super().__init__()
self.encoder = encoder
self.decoder = decoder
def forward(self, x):
return self.decoder(self.encoder(x))