|
|
import os |
|
|
import sys |
|
|
import numpy as np |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.init as init |
|
|
import torch.nn.functional as F |
|
|
from torch.distributions import Beta |
|
|
|
|
|
from core.extractor import ResidualBlock |
|
|
from core.confidence import EfficientUNetSimple |
|
|
from core.utils.utils import sv_intermediate_results |
|
|
|
|
|
|
|
|
|
|
|
class FusionDepth(nn.Module):
    """Fuse a disparity map, a depth map and a disparity residual into a
    single 1-channel output.

    The three 1-channel inputs are concatenated and pushed through a small
    U-shaped network: full-resolution conv -> stride-2 residual blocks ->
    transposed-conv upsample -> fusion conv, with a skip connection at the
    full resolution.
    """

    def __init__(self, args, norm_fn='batch'):
        """
        Args:
            args: experiment configuration namespace (stored; not read here).
            norm_fn: normalization type forwarded to ResidualBlock
                (e.g. 'batch').
        """
        super(FusionDepth, self).__init__()
        self.args = args
        self.norm_fn = norm_fn

        # Full-resolution features from the 3 stacked input channels.
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 4, kernel_size=3, padding=1, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(4, 4, kernel_size=3, padding=1, bias=True),
        )
        # Half-resolution context branch.
        self.down = nn.Sequential(
            ResidualBlock(4, 2*4, self.norm_fn, stride=2),
            ResidualBlock(2*4, 2*4, self.norm_fn, stride=1)
        )
        self.up = nn.ConvTranspose2d(2*4, 4, kernel_size=2, stride=2)
        # Fuse the skip features with the upsampled context into 1 channel.
        self.conv2 = nn.Sequential(
            nn.Conv2d(8, 4, kernel_size=3, padding=1, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(4, 1, kernel_size=3, padding=1, bias=True),
        )

    def forward(self, disp, depth, delta_disp):
        """Return a 1-channel fused map from the three inputs.

        Args:
            disp: disparity map — assumed (N, 1, H, W); conv1 expects 3
                channels total after concatenation — TODO confirm.
            depth: depth map, same shape as ``disp``.
            delta_disp: disparity residual, same shape as ``disp``.

        Returns:
            (N, 1, H, W) fused output tensor.
        """
        # NOTE: a dead local `x = disp` was removed here; it was never used.
        x1 = self.conv1(torch.cat([disp, depth, delta_disp], dim=1))
        # Down/up round-trip enlarges the receptive field; H and W should be
        # even so the transposed conv restores the input resolution.
        x2 = self.up(self.down(x1))
        x3 = self.conv2(torch.cat([x1, x2], dim=1))
        return x3
|
|
|
|
|
|
|
|
class UpdateHistory(nn.Module):
    """Refresh a history feature map with the current disparity features.

    The disparity features are projected channel-wise (1x1 conv), stacked
    with the history, and fused back down to the history channel width by
    a single 3x3 convolution.
    """

    def __init__(self, args, in_chans1, in_chans2):
        """
        Args:
            args: configuration namespace (accepted for interface
                uniformity; not used by this module).
            in_chans1: channel count of the history tensor.
            in_chans2: channel count of the disparity feature tensor.
        """
        super(UpdateHistory, self).__init__()
        # Channel-wise projection of the disparity features.
        self.conv = nn.Conv2d(in_chans2, in_chans2, kernel_size=1, stride=1, padding=0)
        # Fuse [history, projected disparity] back to the history width.
        self.update = nn.Sequential(
            nn.Conv2d(in_chans1 + in_chans2, in_chans1, kernel_size=3, stride=1, padding=1),
        )

    def forward(self, his, disp):
        """Return the updated history (same channel count as ``his``)."""
        projected = self.conv(disp)
        stacked = torch.cat([his, projected], dim=1)
        return self.update(stacked)
|
|
|
|
|
|
|
|
class BetaModulator(nn.Module):
    """Predict per-pixel Beta(alpha, beta) parameters from two LBP feature
    maps and draw a modulation map in (0, 1) from that distribution.

    While training, a reparameterized (differentiable) sample is drawn;
    at evaluation time the distribution mean is used instead.
    """

    def __init__(self, args, lbp_dim, hidden_dim=None, norm_fn='batch'):
        """
        Args:
            args: configuration namespace; must provide `modulation_ratio`.
            lbp_dim: channel count of each LBP input.
            hidden_dim: internal width; defaults to `lbp_dim`.
            norm_fn: normalization type forwarded to ResidualBlock.
        """
        super(BetaModulator, self).__init__()
        self.norm_fn = norm_fn
        self.modulation_ratio = args.modulation_ratio

        hidden_dim = lbp_dim if hidden_dim is None else hidden_dim

        # Full-resolution encoder over the concatenated LBP inputs.
        self.conv1 = nn.Sequential(
            nn.Conv2d(lbp_dim * 2, hidden_dim * 2, kernel_size=3, padding=1, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden_dim * 2, hidden_dim * 2, kernel_size=3, padding=1, bias=True),
        )
        # Half-resolution context branch; widen to at least 64 channels.
        down_dim = 64 if hidden_dim * 2 < 64 else 128
        self.down = nn.Sequential(
            ResidualBlock(hidden_dim * 2, down_dim, self.norm_fn, stride=2),
            ResidualBlock(down_dim, 128, self.norm_fn, stride=1)
        )
        self.up = nn.ConvTranspose2d(128, hidden_dim * 2, kernel_size=2, stride=2)
        # Softplus keeps both outputs positive; the "+1" in forward() then
        # guarantees alpha, beta > 1 (a unimodal Beta distribution).
        self.conv2 = nn.Sequential(
            nn.Conv2d(hidden_dim * 4, hidden_dim, kernel_size=3, padding=1, bias=False),
            nn.Softplus(),
            nn.Conv2d(hidden_dim, 2, kernel_size=1, padding=0, bias=False),
            nn.Softplus(),
        )

    def forward(self, lbp_disp, lbp_depth, out_distribution=False):
        """Return the modulation map; optionally also the Beta distribution.

        Args:
            lbp_disp: LBP features of the disparity branch.
            lbp_depth: LBP features of the depth branch.
            out_distribution: when True, also return the Beta distribution.
        """
        feats = self.conv1(torch.cat([lbp_disp, lbp_depth], dim=1))
        context = self.up(self.down(feats))
        beta_paras = self.conv2(torch.cat([feats, context], dim=1)) + 1

        alpha, beta = torch.split(beta_paras, 1, dim=1)
        distribution = Beta(alpha, beta)

        # Differentiable sample during training; deterministic mean at eval.
        modulation = distribution.rsample() if self.training else distribution.mean

        if out_distribution:
            return modulation, distribution
        return modulation
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class RefinementMonStereo(nn.Module):
    """Refine a stereo disparity map using a monocular depth prior.

    A confidence map is estimated from the cost volume (optionally also
    from the hidden state and Beta-distribution statistics); the mono depth
    is affinely mapped into the disparity domain (depth * a + b) and
    blended per pixel with the stereo disparity according to that
    confidence.  A convex-upsampling mask for the refined disparity is
    predicted as well.
    """

    def __init__(self, args, norm_fn='batch', hidden_dim=32):
        """
        Args:
            args: configuration namespace; reads `corr_levels`,
                `corr_radius`, `conf_from_fea`, `refine_unet`,
                `refine_pool`, `n_downsample` (and `vis_inter`/`sv_root`
                in forward).
            norm_fn: accepted for interface uniformity; not used here.
            hidden_dim: channel count of the GRU hidden state.
        """
        super(RefinementMonStereo, self).__init__()
        self.args = args

        # Channel count of the flattened correlation pyramid lookup.
        corr_channel = self.args.corr_levels * (self.args.corr_radius*2 + 1)
        # Confidence input: cost volume alone, or additionally the hidden
        # state plus the Beta distribution's mean and variance (2 extra chans).
        if not args.conf_from_fea:
            conf_in_dim = corr_channel
        else:
            conf_in_dim = corr_channel + hidden_dim + 2
        self.conf_estimate = nn.Sequential(
            nn.Conv2d(conf_in_dim, 128, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 1, 1, padding=0),)
        self.norm_conf = nn.Sigmoid()

        # Estimator of the affine registration parameters (a, b) that map
        # mono depth into the disparity domain.
        if self.args.refine_unet:
            self.mono_params_estimate = EfficientUNetSimple(num_classes=2)
        else:
            self.mono_params_estimate = nn.Sequential(
                nn.Conv2d(2, 32, 3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(32, 32, 3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(32, 2, 1, padding=0))
        # Optionally pool (a, b) to one global pair per image.  NOTE(review):
        # for the Sequential branch the appended pool is executed last; for
        # the UNet branch add_module only registers it — confirm intent.
        if self.args.refine_pool:
            self.mono_params_estimate.add_module("global_avg_pool", nn.AdaptiveAvgPool2d((1, 1)))

        # Convex-upsampling mask: 9 weights per pixel for each position of
        # the (factor x factor) upsampled grid.
        factor = 2**self.args.n_downsample
        self.mask = nn.Sequential(
            nn.Conv2d(hidden_dim+1, 256, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, (factor**2)*9, 1, padding=0))

    def forward(self, disp, depth, hidden, cost_volume, Beta_distribution=None):
        """Blend the stereo disparity with the registered mono depth.

        Args:
            disp: current disparity estimate — assumed (N, 1, H, W); TODO
                confirm against caller.
            depth: monocular depth prior, same spatial size as `disp`.
            hidden: (N, hidden_dim, H, W) recurrent hidden state.
            cost_volume: (N, corr_channel, H, W) correlation features.
            Beta_distribution: required when `args.conf_from_fea` is set;
                its mean and variance feed the confidence estimator.

        Returns:
            Tuple (disp, up_mask, depth_registered, conf) where `conf` is
            the raw (pre-sigmoid) confidence logits.
        """
        if not self.args.conf_from_fea:
            conf = self.conf_estimate(cost_volume)
        else:
            conf = self.conf_estimate( torch.cat([cost_volume,hidden,Beta_distribution.mean,Beta_distribution.variance], dim=1) )
        conf_normed = self.norm_conf(conf)

        # Affine registration of the mono depth: depth * a + b.
        mono_params = self.mono_params_estimate( torch.cat([disp, depth], dim=1) )
        a, b = torch.split(mono_params, 1, dim=1)
        depth_registered = depth * a + b

        # Confidence-weighted blend: keep stereo where confident, fall back
        # to the registered mono depth elsewhere.
        disp = disp * conf_normed + (1-conf_normed) * depth_registered

        up_mask= self.mask( torch.cat([hidden, disp], dim=1) )

        # Optional dump of intermediates for visualization / debugging.
        if self.args is not None and hasattr(self.args, "vis_inter") and self.args.vis_inter:
            sv_intermediate_results(disp, f"disp_refine", self.args.sv_root)
            sv_intermediate_results(depth_registered, f"depth_registered", self.args.sv_root)
            sv_intermediate_results(conf_normed, f"conf", self.args.sv_root)
            sv_intermediate_results(a, f"a", self.args.sv_root)
            sv_intermediate_results(b, f"b", self.args.sv_root)

        return disp, up_mask, depth_registered, conf