import torch
import torch.nn as nn
import torch.nn.functional as F

# Channel counts of the multi-view feature maps that DPTHead concatenates.
# With num_scales == 2, ``mv_features`` is a pair: element 0 is concatenated
# through a projection sized for 64 extra channels, element 1 through one
# sized for 128. With a single scale, ``mv_features`` is one 128-channel map.
_MV_CHANNELS_FIRST_SCALE = 64
_MV_CHANNELS = 128


def _make_scratch(in_shape, out_shape, groups=1, expand=False):
    """Build the container of 3x3 convs that unify per-level channel counts.

    Args:
        in_shape (sequence of int): input channels per level; three levels are
            always built, a fourth only when ``len(in_shape) >= 4``.
        out_shape (int): base number of output channels.
        groups (int): ``groups`` argument forwarded to every conv.
        expand (bool): if True, level ``k`` (0-based) outputs
            ``out_shape * 2**k`` channels instead of ``out_shape``.

    Returns:
        nn.Module: bare module holding ``layer1_rn`` .. ``layer3_rn`` (and
        ``layer4_rn`` when a fourth level exists), all bias-free.
    """
    scratch = nn.Module()
    num_levels = 4 if len(in_shape) >= 4 else 3

    for level in range(num_levels):
        width = out_shape * (2 ** level) if expand else out_shape
        # setattr on an nn.Module registers the conv as a submodule, exactly
        # like a plain attribute assignment would.
        setattr(
            scratch,
            f"layer{level + 1}_rn",
            nn.Conv2d(
                in_shape[level],
                width,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=False,
                groups=groups,
            ),
        )

    return scratch


class ResidualConvUnit(nn.Module):
    """Residual convolution module.

    Computes ``x + conv2(act(conv1(act(x))))`` with optional batch norm after
    each convolution.
    """

    def __init__(self, features, activation, bn):
        """Init.

        Args:
            features (int): number of input and output channels
            activation (nn.Module): activation applied before each conv
            bn (bool): whether to apply BatchNorm2d after each conv
        """
        super().__init__()

        self.bn = bn
        self.groups = 1

        self.conv1 = nn.Conv2d(
            features,
            features,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=True,
            groups=self.groups,
        )
        self.conv2 = nn.Conv2d(
            features,
            features,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=True,
            groups=self.groups,
        )

        if self.bn:
            self.bn1 = nn.BatchNorm2d(features)
            self.bn2 = nn.BatchNorm2d(features)

        self.activation = activation
        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, x):
        """Forward pass.

        Args:
            x (tensor): input

        Returns:
            tensor: output, same shape as ``x``
        """
        out = self.activation(x)
        out = self.conv1(out)
        if self.bn:
            out = self.bn1(out)

        out = self.activation(out)
        out = self.conv2(out)
        if self.bn:
            out = self.bn2(out)

        # The former ``groups > 1`` branch called an undefined ``conv_merge``
        # and could never run because ``groups`` is fixed to 1; it was removed.
        return self.skip_add.add(out, x)


class FeatureFusionBlock(nn.Module):
    """Feature fusion block.

    Optionally adds a refined skip input to the main input, refines the sum,
    resizes it bilinearly and applies a 1x1 output convolution.
    """

    def __init__(
        self,
        features,
        activation,
        deconv=False,
        bn=False,
        expand=False,
        align_corners=True,
        size=None,
    ):
        """Init.

        Args:
            features (int): number of features
            activation (nn.Module): activation used by the residual units
            deconv (bool): stored on the module; not read by this class
            bn (bool): whether the residual units use batch norm
            expand (bool): if True the output conv halves the channel count
            align_corners (bool): forwarded to the bilinear interpolation
            size (tuple or None): default output size used when ``forward``
                is called without ``size``
        """
        super().__init__()

        self.deconv = deconv
        self.align_corners = align_corners
        self.groups = 1
        self.expand = expand

        out_features = features // 2 if self.expand else features

        self.out_conv = nn.Conv2d(
            features,
            out_features,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True,
            groups=1,
        )

        self.resConfUnit1 = ResidualConvUnit(features, activation, bn)
        self.resConfUnit2 = ResidualConvUnit(features, activation, bn)

        self.skip_add = nn.quantized.FloatFunctional()
        self.size = size

    def forward(self, *xs, size=None):
        """Forward pass.

        Args:
            *xs (tensor): one or two inputs. With two, the second is passed
                through ``resConfUnit1`` and added to the first.
            size (tuple or None): explicit output size. Takes precedence over
                ``self.size``; when both are None the input is upsampled 2x.

        Returns:
            tensor: output
        """
        output = xs[0]

        if len(xs) == 2:
            res = self.resConfUnit1(xs[1])
            output = self.skip_add.add(output, res)

        output = self.resConfUnit2(output)

        if size is not None:
            modifier = {"size": size}
        elif self.size is not None:
            modifier = {"size": self.size}
        else:
            modifier = {"scale_factor": 2}

        output = F.interpolate(
            output, **modifier, mode="bilinear", align_corners=self.align_corners
        )
        return self.out_conv(output)


def _make_fusion_block(features, use_bn, size=None):
    """Create a FeatureFusionBlock with the settings used by DPTHead."""
    return FeatureFusionBlock(
        features,
        nn.ReLU(False),
        deconv=False,
        bn=use_bn,
        expand=False,
        align_corners=True,
        size=size,
    )


def _concat_in_channels(downsample_factor, num_scales, cnn_ch, out_ch, depth_ch):
    """Return the input widths of the three concat-projection convs.

    Each entry is the channel count of the tensor that DPTHead's concat
    fusion builds for the corresponding level, for the given mode.
    """
    if downsample_factor == 4 and num_scales == 2:
        return [
            cnn_ch[0] + out_ch[0],
            cnn_ch[1] + out_ch[1] + _MV_CHANNELS_FIRST_SCALE + depth_ch,
            cnn_ch[2] + out_ch[2] + _MV_CHANNELS,
        ]
    if downsample_factor == 2 and num_scales == 2:
        return [
            cnn_ch[0] + cnn_ch[1] + out_ch[0] + _MV_CHANNELS_FIRST_SCALE + depth_ch,
            cnn_ch[2] + out_ch[1] + _MV_CHANNELS,
            out_ch[2],
        ]
    if downsample_factor == 4 and num_scales == 1:
        return [
            cnn_ch[0] + cnn_ch[1] + out_ch[0],
            cnn_ch[2] + out_ch[1] + _MV_CHANNELS + depth_ch,
            out_ch[2],
        ]
    return [
        cnn_ch[0] + out_ch[0],
        cnn_ch[1] + out_ch[1],
        cnn_ch[2] + out_ch[2] + _MV_CHANNELS + depth_ch,
    ]


def _cat_channels(tensors):
    """Concatenate along dim 1; a single tensor is returned as-is (no copy)."""
    if len(tensors) == 1:
        return tensors[0]
    return torch.cat(tensors, dim=1)


class DPTHead(nn.Module):
    """DPT-style decoder head with CNN / multi-view / depth feature fusion.

    Four input feature maps are projected and resized to 4x, 2x, 1x and about
    1/2x of their input resolution, the first three are fused with auxiliary
    features, and the levels are merged coarse-to-fine by fusion blocks.
    Unless ``return_feature`` is set, a small conv stack whose last layer is
    zero-initialised maps the result to one channel.
    """

    def __init__(
        self,
        in_channels,
        features=256,
        use_bn=False,
        out_channels=(256, 512, 1024, 1024),
        use_clstoken=False,
        concat_cnn_features=True,
        concat_mv_features=True,
        cnn_feature_channels=(64, 96, 128),
        concat_features=True,
        downsample_factor=8,
        return_feature=False,
        num_scales=1,
        latent_downsample=None,
        latent_feature_no_concat=False,
    ):
        """Init.

        Args:
            in_channels (int): channels of every input feature map
            features (int): channel width of the fusion path
            use_bn (bool): whether fusion blocks use batch norm
            out_channels (sequence of int): per-level widths after projection
            use_clstoken (bool): if True ``readout_projects`` is built; the
                forward pass of this class does not use it
            concat_cnn_features (bool): additive mode only (``concat_features``
                False): add projected CNN features to levels 1-3
            concat_mv_features (bool): additive mode only: add projected
                multi-view features to level 3
            cnn_feature_channels (sequence of int): channels of the three CNN
                feature maps
            concat_features (bool): fuse by channel concatenation followed by
                a 1x1 conv instead of by addition
            downsample_factor (int): together with ``num_scales`` selects
                which features are concatenated at which level
            return_feature (bool): return fused features instead of the
                1-channel prediction; ``depth`` is then not concatenated
            num_scales (int): 1 or 2; with 2, ``mv_features`` is a pair
            latent_downsample (int or None): 8 skips building ``refinenet3``;
                4 makes ``forward`` resize and concatenate all fusion outputs
            latent_feature_no_concat (bool): skip building ``refinenet1`` and
                ``refinenet2``; with ``return_feature`` enables early returns
        """
        super().__init__()

        self.use_clstoken = use_clstoken
        self.concat_cnn_features = concat_cnn_features
        self.concat_mv_features = concat_mv_features
        self.concat_features = concat_features
        self.downsample_factor = downsample_factor
        self.return_feature = return_feature
        self.num_scales = num_scales
        self.latent_downsample = latent_downsample
        self.latent_feature_no_concat = latent_feature_no_concat

        if self.concat_features:
            # depth contributes one extra channel unless only features are returned
            depth_channel = 0 if self.return_feature else 1
            in_widths = _concat_in_channels(
                self.downsample_factor,
                num_scales,
                cnn_feature_channels,
                out_channels,
                depth_channel,
            )
            self.concat_projects = nn.ModuleList(
                [
                    nn.Conv2d(in_width, out_channels[level], 1)
                    for level, in_width in enumerate(in_widths)
                ]
            )
        else:
            if self.concat_cnn_features:
                self.cnn_projects = nn.ModuleList(
                    [
                        nn.Conv2d(cnn_width, out_channels[level], 1)
                        for level, cnn_width in enumerate(cnn_feature_channels)
                    ]
                )

            if self.concat_mv_features:
                self.mv_projects = nn.Conv2d(_MV_CHANNELS, out_channels[2], 1)

        self.projects = nn.ModuleList(
            [
                nn.Conv2d(
                    in_channels=in_channels,
                    out_channels=out_channel,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                )
                for out_channel in out_channels
            ]
        )

        # Per-level resampling: 4x up, 2x up, identity, stride-2 down.
        self.resize_layers = nn.ModuleList(
            [
                nn.ConvTranspose2d(
                    in_channels=out_channels[0],
                    out_channels=out_channels[0],
                    kernel_size=4,
                    stride=4,
                    padding=0,
                ),
                nn.ConvTranspose2d(
                    in_channels=out_channels[1],
                    out_channels=out_channels[1],
                    kernel_size=2,
                    stride=2,
                    padding=0,
                ),
                nn.Identity(),
                nn.Conv2d(
                    in_channels=out_channels[3],
                    out_channels=out_channels[3],
                    kernel_size=3,
                    stride=2,
                    padding=1,
                ),
            ]
        )

        if use_clstoken:
            self.readout_projects = nn.ModuleList(
                [
                    nn.Sequential(nn.Linear(2 * in_channels, in_channels), nn.GELU())
                    for _ in self.projects
                ]
            )

        self.scratch = _make_scratch(
            out_channels,
            features,
            groups=1,
            expand=False,
        )

        self.scratch.stem_transpose = None

        if not self.latent_feature_no_concat:
            self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
            self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
        if self.latent_downsample != 8:
            self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
        self.scratch.refinenet4 = _make_fusion_block(features, use_bn)

        # not used: refinenet4 is only ever called with a single input
        del self.scratch.refinenet4.resConfUnit1

        head_features_1 = features
        head_features_2 = 16

        if not self.return_feature:
            self.scratch.output_conv = nn.Sequential(
                nn.Conv2d(
                    head_features_1,
                    head_features_1 // 2,
                    3,
                    1,
                    1,
                    padding_mode="replicate",
                ),
                nn.GELU(),
                nn.Conv2d(
                    head_features_1 // 2,
                    head_features_2,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    padding_mode="replicate",
                ),
                nn.GELU(),
                nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0),
            )

            # init delta depth as zero
            nn.init.zeros_(self.scratch.output_conv[-1].weight)
            nn.init.zeros_(self.scratch.output_conv[-1].bias)

    def _fuse_by_concat(self, layers, cnn_features, mv_features, depth):
        """Concatenate auxiliary features per level and project back with 1x1 convs."""
        layer_1, layer_2, layer_3 = layers

        if not self.return_feature:
            assert depth is not None, "depth is required unless return_feature=True"
        # depth is concatenated only when the head predicts an output
        extra = () if self.return_feature else (depth,)

        if self.num_scales == 2 and self.downsample_factor in (2, 4):
            # checked before indexing so a single tensor is not silently sliced
            assert isinstance(mv_features, (list, tuple))

        if self.downsample_factor == 4 and self.num_scales == 2:
            inputs_1 = (cnn_features[0], layer_1)
            inputs_2 = (cnn_features[1], layer_2, mv_features[0], *extra)
            inputs_3 = (cnn_features[2], layer_3, mv_features[1])
        elif self.downsample_factor == 2 and self.num_scales == 2:
            inputs_1 = (
                cnn_features[0],
                cnn_features[1],
                mv_features[0],
                *extra,
                layer_1,
            )
            inputs_2 = (cnn_features[2], layer_2, mv_features[1])
            inputs_3 = (layer_3,)
        elif self.downsample_factor == 4 and self.num_scales == 1:
            inputs_1 = (cnn_features[0], cnn_features[1], layer_1)
            inputs_2 = (cnn_features[2], layer_2, mv_features, *extra)
            inputs_3 = (layer_3,)
        else:
            inputs_1 = (cnn_features[0], layer_1)
            inputs_2 = (cnn_features[1], layer_2)
            inputs_3 = (cnn_features[2], layer_3, mv_features, *extra)

        layer_1 = self.concat_projects[0](_cat_channels(inputs_1))
        layer_2 = self.concat_projects[1](_cat_channels(inputs_2))
        layer_3 = self.concat_projects[2](_cat_channels(inputs_3))
        return layer_1, layer_2, layer_3

    def _fuse_by_add(self, layers, cnn_features, mv_features):
        """Add projected CNN / multi-view features to the matching levels."""
        layer_1, layer_2, layer_3 = layers

        if self.concat_cnn_features:
            assert cnn_features is not None
            assert len(cnn_features) == 3
            projected = [
                self.cnn_projects[i](feature) for i, feature in enumerate(cnn_features)
            ]
            layer_1 = layer_1 + projected[0]
            layer_2 = layer_2 + projected[1]
            layer_3 = layer_3 + projected[2]

        if self.concat_mv_features:
            layer_3 = layer_3 + self.mv_projects(mv_features)

        return layer_1, layer_2, layer_3

    def forward(
        self,
        out_features,
        downsample_factor=8,
        cnn_features=None,
        mv_features=None,
        depth=None,
    ):
        """Forward pass.

        Args:
            out_features (sequence of tensor): exactly four feature maps with
                ``in_channels`` channels each
            downsample_factor (int): unused; the value given to ``__init__``
                is what selects the fusion mode. Kept for call compatibility.
            cnn_features (sequence of tensor): three CNN feature maps whose
                spatial sizes must match the levels they are fused with
            mv_features (tensor or sequence of tensor): multi-view features; a
                pair when ``num_scales == 2``
            depth (tensor): required in concat mode unless ``return_feature``
                is set

        Returns:
            tensor: the prediction from ``output_conv``, or a feature map when
            ``return_feature`` is True
        """
        resized = [
            self.resize_layers[i](self.projects[i](x))
            for i, x in enumerate(out_features)
        ]

        # original note: 1/2, 1/4, 1/8, 1/16 (presumably w.r.t. the image - confirm)
        layer_1, layer_2, layer_3, layer_4 = resized

        if self.concat_features:
            layer_1, layer_2, layer_3 = self._fuse_by_concat(
                (layer_1, layer_2, layer_3), cnn_features, mv_features, depth
            )
        else:
            layer_1, layer_2, layer_3 = self._fuse_by_add(
                (layer_1, layer_2, layer_3), cnn_features, mv_features
            )

        layer_1_rn = self.scratch.layer1_rn(layer_1)
        layer_2_rn = self.scratch.layer2_rn(layer_2)
        layer_3_rn = self.scratch.layer3_rn(layer_3)
        layer_4_rn = self.scratch.layer4_rn(layer_4)

        early_exit = self.latent_feature_no_concat and self.return_feature

        # each fusion block resizes its output to the next finer level's size
        path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
        if early_exit and self.latent_downsample == 8:
            return path_4

        path_3 = self.scratch.refinenet3(
            path_4, layer_3_rn, size=layer_2_rn.shape[2:]
        )
        if early_exit and self.latent_downsample == 4:
            return path_3

        path_2 = self.scratch.refinenet2(
            path_3, layer_2_rn, size=layer_1_rn.shape[2:]
        )
        path_1 = self.scratch.refinenet1(path_2, layer_1_rn)

        if self.latent_downsample == 4:
            # all resize to the resolution of path_3
            path_4 = F.interpolate(
                path_4, scale_factor=2, mode="bilinear", align_corners=True
            )
            path_2 = F.interpolate(
                path_2, scale_factor=0.5, mode="bilinear", align_corners=True
            )
            path_1 = F.interpolate(
                path_1, scale_factor=0.25, mode="bilinear", align_corners=True
            )
            # concat all
            path_1 = torch.cat((path_4, path_3, path_2, path_1), dim=1)

        if self.return_feature:
            return path_1

        return self.scratch.output_conv(path_1)


def _demo():
    """Smoke test: build a default head and run one forward pass."""
    # fall back to CPU so the demo also runs on machines without a GPU
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    c = 384
    model = DPTHead(
        in_channels=c,
        concat_cnn_features=True,
        concat_mv_features=True,
    ).to(device)
    print(model)

    h, w = 16, 32
    x = torch.randn(2, c, h, w, device=device)
    out_features = [x] * 4

    cnn_features = [
        torch.randn(2, 64, h * 4, w * 4, device=device),
        torch.randn(2, 96, h * 2, w * 2, device=device),
        torch.randn(2, 128, h, w, device=device),
    ]
    mv_features = torch.randn(2, _MV_CHANNELS, h, w, device=device)
    # the default configuration concatenates a 1-channel depth map at level 3
    depth = torch.randn(2, 1, h, w, device=device)

    with torch.no_grad():
        out = model(
            out_features,
            cnn_features=cnn_features,
            mv_features=mv_features,
            depth=depth,
        )

    print(out.shape)


if __name__ == "__main__":
    _demo()