import torch
import torch.nn as nn
import torch.nn.functional as F

from .backbone import CNNEncoder
from .vit_fpn import ViTFeaturePyramid
from .mv_transformer import (
    MultiViewFeatureTransformer,
    batch_features_camera_parameters,
)
from .matching import warp_with_pose_depth_candidates
from .utils import mv_feature_add_position
from .dpt_head import DPTHead
from .ldm_unet.unet import UNetModel, AttentionBlock
from einops import rearrange
from .dinov2.dinov2 import DINOv2


class MultiViewUniMatch(nn.Module):
    """Multi-view depth network fusing matching cues with monocular features.

    The module owns a CNN encoder, a multi-view feature transformer and a
    DINOv2 ViT. For every scale it builds a cost volume by warping the
    other views' transformer features with a set of depth candidates,
    concatenates it with the CNN / transformer / ViT features, regresses a
    probability over the candidates with a UNet, and finally upsamples the
    expected depth (optionally with a learned ``DPTHead`` residual).
    """

    def __init__(
        self,
        num_scales=1,
        feature_channels=128,
        upsample_factor=8,
        lowest_feature_resolution=8,
        num_head=1,
        ffn_dim_expansion=4,
        num_transformer_layers=6,
        num_depth_candidates=128,
        vit_type="vits",
        unet_channels=128,
        unet_channel_mult=[1, 1, 1],
        unet_num_res_blocks=1,
        unet_attn_resolutions=[4],
        grid_sample_disable_cudnn=False,
        only_features=False,
        sample_log_depth=False,
        bilinear_upsample_depth=False,
        no_upsample_depth=False,
        use_amp=False,
        return_raw_mono_features=False,
        max_mono_vit_input_size=560,  # constrain the input resolution to vit
        use_checkpointing=False,
        **kwargs,
    ):
        """Build the encoders, per-scale regressors and the upsampler.

        Args:
            num_scales: number of scales; one regressor / depth head is
                built per scale, and feature pyramids are only created
                when this is greater than 1.
            feature_channels: output dimension of the CNN encoder and
                model dimension of the multi-view transformer.
            upsample_factor: forwarded to ``CNNEncoder`` and ``DPTHead`` as
                ``downsample_factor``; also the factor of the final
                bilinear depth upsampling.
            lowest_feature_resolution: forwarded to ``CNNEncoder`` as
                ``lowest_scale``; when equal to 4 the monocular features
                are upsampled by 2 in ``forward``.
            num_head: number of attention heads of the transformer.
            ffn_dim_expansion: FFN expansion of the transformer.
            num_transformer_layers: number of transformer layers.
            num_depth_candidates: number of depth candidates at the first
                scale; scale ``i`` uses ``num_depth_candidates // 4**i``.
            vit_type: one of ``"vits"``, ``"vitb"``, ``"vitl"``.
            unet_channels: UNet width at the first scale (halved per scale).
            unet_channel_mult: UNet channel multipliers at the first scale;
                one extra ``1`` is appended per additional scale.
            unet_num_res_blocks: residual blocks per UNet level.
            unet_attn_resolutions: UNet attention resolutions at the first
                scale; doubled per additional scale.
            grid_sample_disable_cudnn: forwarded to
                ``warp_with_pose_depth_candidates``.
            only_features: if True, no regressor / upsampler is built and
                ``forward`` returns right after feature extraction.
            sample_log_depth: sample candidates in log depth instead of
                inverse depth.
            bilinear_upsample_depth: use plain bilinear upsampling (no
                learned residual) for the final depth.
            no_upsample_depth: return the final depth at feature
                resolution without any upsampling.
            use_amp: run the heavy sub-networks under bfloat16 CUDA
                autocast.
            return_raw_mono_features: additionally return the ViT features
                before they are resized.
            max_mono_vit_input_size: maximum image width fed to the ViT.
            use_checkpointing: forwarded to the transformer and DINOv2.
            **kwargs: accepted and ignored.
        """
        super().__init__()

        # Copy the list arguments: the defaults are shared mutable objects
        # and the first-scale lists are handed to UNetModel as-is.
        unet_channel_mult = list(unet_channel_mult)
        unet_attn_resolutions = list(unet_attn_resolutions)

        # CNN
        self.feature_channels = feature_channels
        self.num_scales = num_scales
        self.lowest_feature_resolution = lowest_feature_resolution
        self.upsample_factor = upsample_factor
        self.only_features = only_features
        self.bilinear_upsample_depth = bilinear_upsample_depth
        self.no_upsample_depth = no_upsample_depth
        self.return_raw_mono_features = return_raw_mono_features
        self.max_mono_vit_input_size = max_mono_vit_input_size
        self.use_amp = use_amp

        # sample depth in the log scale instead of the inverse depth
        self.sample_log_depth = sample_log_depth

        # monocular backbones: final
        self.vit_type = vit_type

        # cost volume
        self.num_depth_candidates = num_depth_candidates

        # upsampler
        vit_feature_channel_dict = {"vits": 384, "vitb": 768, "vitl": 1024}
        vit_feature_channel = vit_feature_channel_dict[vit_type]

        # CNN
        self.backbone = CNNEncoder(
            output_dim=feature_channels,
            num_output_scales=num_scales,
            downsample_factor=upsample_factor,
            lowest_scale=lowest_feature_resolution,
            return_all_scales=True,
        )

        # Transformer
        self.transformer = MultiViewFeatureTransformer(
            num_layers=num_transformer_layers,
            d_model=feature_channels,
            nhead=num_head,
            ffn_dim_expansion=ffn_dim_expansion,
            use_checkpointing=use_checkpointing,
        )

        if self.num_scales > 1:
            # generate multi-scale features
            # NOTE(review): in_channels is hard-coded to 128 while the
            # transformer width is `feature_channels` - confirm they agree.
            self.mv_pyramid = ViTFeaturePyramid(
                in_channels=128,
                scale_factors=[2**i for i in range(self.num_scales)],
            )

        # monodepth
        encoder = vit_type

        # DINOv2 is loaded from the local implementation; the torch.hub
        # alternative would be
        #   torch.hub.load("facebookresearch/dinov2",
        #                  "dinov2_{:}14".format(encoder))
        self.pretrained = DINOv2(
            encoder,
            use_checkpointing=use_checkpointing,
        )

        del self.pretrained.mask_token  # unused

        if self.num_scales > 1:
            # generate multi-scale features
            self.mono_pyramid = ViTFeaturePyramid(
                in_channels=vit_feature_channel,
                scale_factors=[2**i for i in range(self.num_scales)],
            )

        if self.only_features:
            return

        # UNet regressor
        self.regressor = nn.ModuleList()
        self.regressor_residual = nn.ModuleList()
        self.depth_head = nn.ModuleList()

        for i in range(self.num_scales):
            curr_depth_candidates = num_depth_candidates // (4**i)
            # NOTE(review): the 128-based channel counts below are
            # hard-coded; presumably they mirror the encoder outputs for
            # feature_channels=128 - confirm before changing that argument.
            cnn_feature_channels = 128 - (32 * i)
            mv_transformer_feature_channels = 128 // (2**i)
            mono_feature_channels = vit_feature_channel // (2**i)

            # concat(cost volume, cnn feature, mv feature, mono feature)
            in_channels = (
                curr_depth_candidates
                + cnn_feature_channels
                + mv_transformer_feature_channels
                + mono_feature_channels
            )

            # unet channels
            channels = unet_channels // (2**i)

            # unet channel mult & unet_attn_resolutions
            if i > 0:
                unet_channel_mult = unet_channel_mult + [1]
                unet_attn_resolutions = [x * 2 for x in unet_attn_resolutions]

            # unet
            modules = [
                nn.Conv2d(in_channels, channels, 3, 1, 1),
                nn.GroupNorm(8, channels),
                nn.GELU(),
            ]

            modules.append(
                UNetModel(
                    image_size=None,
                    in_channels=channels,
                    model_channels=channels,
                    out_channels=channels,
                    num_res_blocks=unet_num_res_blocks,
                    attention_resolutions=unet_attn_resolutions,
                    channel_mult=unet_channel_mult,
                    num_head_channels=32,
                    dims=2,
                    postnorm=False,
                    num_frames=2,
                    use_cross_view_self_attn=True,
                )
            )

            modules.append(nn.Conv2d(channels, channels, 3, 1, 1))

            self.regressor.append(nn.Sequential(*modules))

            # regressor residual
            self.regressor_residual.append(nn.Conv2d(in_channels, channels, 1))

            # depth head
            self.depth_head.append(
                nn.Sequential(
                    nn.Conv2d(
                        channels, channels * 2, 3, 1, 1, padding_mode="replicate"
                    ),
                    nn.GELU(),
                    nn.Conv2d(
                        channels * 2,
                        curr_depth_candidates,
                        3,
                        1,
                        1,
                        padding_mode="replicate",
                    ),
                )
            )

        # upsampler
        model_configs = {
            "vits": {
                "in_channels": 384,
                "features": 32,
                "out_channels": [48, 96, 192, 384],
            },
            "vitb": {
                "in_channels": 768,
                "features": 48,
                "out_channels": [96, 192, 384, 768],
            },
            "vitl": {
                "in_channels": 1024,
                "features": 64,
                "out_channels": [128, 256, 512, 1024],
            },
        }

        if not self.bilinear_upsample_depth and not self.no_upsample_depth:
            self.upsampler = DPTHead(
                **model_configs[vit_type],
                downsample_factor=upsample_factor,
                num_scales=num_scales,
            )

        self.grid_sample_disable_cudnn = grid_sample_disable_cudnn

    def _autocast(self):
        """Return the bfloat16 CUDA autocast context, enabled by ``use_amp``."""
        return torch.amp.autocast(
            device_type="cuda", enabled=self.use_amp, dtype=torch.bfloat16
        )

    def normalize_images(self, images):
        """Normalize image to match the pretrained UniMatch model.

        The per-channel mean / std are broadcast over every leading
        dimension, with the channel axis third from the end.

        images: (B, V, C, H, W)
        """
        shape = [*[1] * (images.dim() - 3), 3, 1, 1]
        mean = torch.tensor([0.485, 0.456, 0.406]).reshape(*shape).to(images.device)
        std = torch.tensor([0.229, 0.224, 0.225]).reshape(*shape).to(images.device)

        return (images - mean) / std

    def extract_feature(self, images):
        """Run the CNN encoder on all views and return its feature list.

        Args:
            images: [B, V, C, H, W]; batch and view axes are merged before
                the encoder is applied.

        Returns:
            The encoder outputs in reversed order; per the original
            comments, the encoder yields resolutions high to low, so the
            result is ordered low to high.
        """
        concat = rearrange(images, "b v c h w -> (b v) c h w")
        # list of [BV, C, H, W], resolution from high to low
        features = self.backbone(concat)
        # reverse: resolution from low to high
        return features[::-1]

    def forward(
        self,
        images,
        attn_splits_list=None,
        intrinsics=None,
        min_depth=1.0 / 0.5,  # inverse depth range
        max_depth=1.0 / 100,
        num_depth_candidates=128,
        extrinsics=None,
        nn_matrix=None,
        **kwargs,
    ):
        """Extract features and, unless ``only_features``, predict depth.

        Args:
            images: [B, V, C, H, W] un-normalized images.
            attn_splits_list: sequence whose first element is the number of
                attention splits; required despite the ``None`` default.
            intrinsics: [B, V, 3, 3] intrinsics normalized by image width
                and height (first row is scaled by W, second by H here);
                required despite the ``None`` default.
            min_depth: tensor viewed as [B*V]; lower end of the sampling
                range, in inverse depth (see the original comment). The
                float default is a placeholder only - a tensor is required.
            max_depth: tensor viewed as [B*V]; upper end of the range.
            num_depth_candidates: ignored; the value given to ``__init__``
                is used instead.
            extrinsics: [B, V, 4, 4] camera poses (c2w per the original
                comment); required unless ``only_features`` is set.
            nn_matrix: forwarded to the transformer and to
                ``batch_features_camera_parameters``.
            **kwargs: accepted and ignored.

        Returns:
            dict with the feature entries ``features_cnn_all_scales``,
            ``features_cnn``, ``features_mv``,
            ``features_mono_intermediate``, ``features_mono`` (plus
            ``raw_mono_features`` if enabled) and, unless
            ``only_features``, ``depth_preds`` (list of [B, V, H, W]
            depths) and ``match_probs`` (list of [BV, D, H, W]).

        Raises:
            ValueError: if a required input is None.
            TypeError: if ``min_depth`` / ``max_depth`` are not tensors.
        """
        # Fail early with a clear message instead of an obscure
        # AttributeError / TypeError deep inside the computation.
        if attn_splits_list is None:
            raise ValueError("attn_splits_list is required")
        if intrinsics is None:
            raise ValueError("intrinsics is required")
        if not (torch.is_tensor(min_depth) and torch.is_tensor(max_depth)):
            raise TypeError("min_depth and max_depth must be tensors of shape [B, V]")

        results_dict = {}
        depth_preds = []
        match_probs = []

        # first normalize images
        images = self.normalize_images(images)
        b, v, _, ori_h, ori_w = images.shape

        # update the num_views in unet attention, useful for random input views
        if not self.only_features:
            set_num_views(self.regressor, num_views=v)

        # NOTE: in this codebase, intrinsics are normalized by image width and height
        # in unimatch's codebase: https://github.com/autonomousvision/unimatch, no normalization
        intrinsics = intrinsics.clone()
        intrinsics[:, :, 0] *= ori_w
        intrinsics[:, :, 1] *= ori_h

        # max_depth, min_depth: [B, V] -> [BV]
        max_depth = max_depth.view(-1)
        min_depth = min_depth.view(-1)

        if self.sample_log_depth:
            # inverse depth to depth
            min_depth, max_depth = 1.0 / max_depth, 1.0 / min_depth
            min_depth, max_depth = torch.log(min_depth), torch.log(max_depth)

        # list of features, resolution low to high
        # list of [BV, C, H, W]
        with self._autocast():
            features_list_cnn = self.extract_feature(images)
        features_list_cnn_all_scales = features_list_cnn
        features_list_cnn = features_list_cnn[: self.num_scales]

        results_dict["features_cnn_all_scales"] = features_list_cnn_all_scales
        results_dict["features_cnn"] = features_list_cnn

        # mv transformer features
        # add position to features
        attn_splits = attn_splits_list[0]

        # [BV, C, H, W]
        features_cnn_pos = mv_feature_add_position(
            features_list_cnn[0], attn_splits, self.feature_channels
        )

        # list of [B, C, H, W]
        features_list = list(
            torch.unbind(
                rearrange(features_cnn_pos, "(b v) c h w -> b v c h w", b=b, v=v),
                dim=1,
            )
        )

        with self._autocast():
            # wide feature maps override the requested number of splits
            feature_width = features_list[0].shape[-1]
            if feature_width > 96:
                attn_splits = 4
            if feature_width > 192:
                attn_splits = 8
            features_list_mv = self.transformer(
                features_list,
                attn_num_splits=attn_splits,
                nn_matrix=nn_matrix,
            )

        features_mv = rearrange(
            torch.stack(features_list_mv, dim=1), "b v c h w -> (b v) c h w"
        )  # [BV, C, H, W]

        if self.num_scales > 1:
            # multi-scale mv features: resolution from low to high
            # list of [BV, C, H, W]
            with self._autocast():
                features_list_mv = self.mv_pyramid(features_mv)
        else:
            features_list_mv = [features_mv]

        results_dict["features_mv"] = features_list_mv

        # mono feature
        ori_h, ori_w = images.shape[-2:]
        # The ViT input sides are rounded down to multiples of 14
        # (presumably the ViT patch size, cf. the "dinov2_*14" hub names).
        patch = 14
        # TODO: support portrait images later
        assert ori_h <= ori_w
        if ori_w > self.max_mono_vit_input_size:
            resize_w = self.max_mono_vit_input_size // patch * patch
            resize_h = (
                int((ori_h / ori_w) * self.max_mono_vit_input_size) // patch * patch
            )
        else:
            resize_h, resize_w = ori_h // patch * patch, ori_w // patch * patch

        concat = rearrange(images, "b v c h w -> (b v) c h w")
        concat = F.interpolate(
            concat, (resize_h, resize_w), mode="bilinear", align_corners=True
        )

        # get intermediate features
        intermediate_layer_idx = {
            "vits": [2, 5, 8, 11],
            "vitb": [2, 5, 8, 11],
            "vitl": [4, 11, 17, 23],
        }

        with self._autocast():
            mono_intermediate_features = list(
                self.pretrained.get_intermediate_layers(
                    concat,
                    intermediate_layer_idx[self.vit_type],
                    return_class_token=False,
                )
            )

        if self.return_raw_mono_features:
            raw_mono_features = []

        for i, tokens in enumerate(mono_intermediate_features):
            # token sequence -> [BV, C, resize_h / 14, resize_w / 14] map
            curr_features = (
                tokens.reshape(
                    concat.shape[0], resize_h // patch, resize_w // patch, -1
                )
                .permute(0, 3, 1, 2)
                .contiguous()
            )

            if self.return_raw_mono_features:
                raw_mono_features.append(curr_features)

            # resize to 1/8 resolution
            curr_features = F.interpolate(
                curr_features,
                (ori_h // 8, ori_w // 8),
                mode="bilinear",
                align_corners=True,
            )
            mono_intermediate_features[i] = curr_features

        results_dict["features_mono_intermediate"] = mono_intermediate_features

        if self.return_raw_mono_features:
            results_dict["raw_mono_features"] = raw_mono_features

        # last mono feature
        # TODO: use all the intermediate features for depth estimation?
        mono_features = mono_intermediate_features[-1]

        if self.lowest_feature_resolution == 4:
            mono_features = F.interpolate(
                mono_features, scale_factor=2, mode="bilinear", align_corners=True
            )

        if self.num_scales > 1:
            # multi-scale mono features, resolution from low to high
            # list of [BV, C, H, W]
            with self._autocast():
                features_list_mono = self.mono_pyramid(mono_features)
        else:
            features_list_mono = [mono_features]

        results_dict["features_mono"] = features_list_mono

        if self.only_features:
            return results_dict

        if extrinsics is None:
            raise ValueError("extrinsics is required for depth prediction")

        depth = None

        for scale_idx in range(self.num_scales):
            downsample_factor = self.upsample_factor * (
                2 ** (self.num_scales - 1 - scale_idx)
            )

            # scale intrinsics
            intrinsics_curr = intrinsics.clone()  # [B, V, 3, 3]
            intrinsics_curr[:, :, :2] = intrinsics_curr[:, :, :2] / downsample_factor

            # build cost volume
            features_mv = features_list_mv[scale_idx]  # [BV, C, H, W]

            # list of [B, C, H, W]
            features_mv_curr = list(
                torch.unbind(
                    rearrange(features_mv, "(b v) c h w -> b v c h w", b=b, v=v),
                    dim=1,
                )
            )

            intrinsics_curr = list(
                torch.unbind(intrinsics_curr, dim=1)
            )  # list of [B, 3, 3]
            extrinsics_curr = list(torch.unbind(extrinsics, dim=1))  # list of [B, 4, 4]

            # ref: [BV, C, H, W], [BV, 3, 3], [BV, 4, 4]
            # tgt: [BV, V-1, C, H, W], [BV, V-1, 3, 3], [BV, V-1, 4, 4]
            # (the returned intrinsics are not used below)
            (
                ref_features,
                _ref_intrinsics,
                ref_extrinsics,
                tgt_features,
                _tgt_intrinsics,
                tgt_extrinsics,
            ) = batch_features_camera_parameters(
                features_mv_curr,
                intrinsics_curr,
                extrinsics_curr,
                nn_matrix=nn_matrix,
            )

            b_new, num_tgt_views, c, h, w = tgt_features.size()

            # relative pose
            # extrinsics: c2w
            pose_curr = torch.matmul(
                tgt_extrinsics.inverse(), ref_extrinsics.unsqueeze(1)
            )  # [BV, V-1, 4, 4]

            if scale_idx > 0:
                # 2x upsample depth
                assert depth is not None
                depth = F.interpolate(
                    depth, scale_factor=2, mode="bilinear", align_corners=True
                ).detach()

            # the `num_depth_candidates` argument of forward is not used;
            # the per-scale count derives from the constructor value
            curr_num_candidates = self.num_depth_candidates // (4**scale_idx)

            linear_space = (
                torch.linspace(0, 1, curr_num_candidates)
                .type_as(features_list_cnn[0])
                .view(1, curr_num_candidates, 1, 1)
            )  # [1, D, 1, 1]

            # generate depth candidates
            if scale_idx == 0:
                # min_depth, max_depth: [BV]
                depth_candidates = min_depth.view(-1, 1, 1, 1) + linear_space * (
                    max_depth - min_depth
                ).view(-1, 1, 1, 1)  # [BV, D, 1, 1]
            else:
                # half interval each scale
                depth_interval = (
                    (max_depth - min_depth)
                    / (self.num_depth_candidates - 1)
                    / (2**scale_idx)
                )  # [BV]
                # [BV, 1, 1, 1]
                depth_interval = depth_interval.view(-1, 1, 1, 1)

                # [BV, 1, H, W]: window around the upsampled previous
                # prediction, clamped to the global range
                depth_range_min = (
                    depth - depth_interval * (curr_num_candidates // 2)
                ).clamp(min=min_depth.view(-1, 1, 1, 1))
                depth_range_max = (
                    depth + depth_interval * (curr_num_candidates // 2 - 1)
                ).clamp(max=max_depth.view(-1, 1, 1, 1))

                depth_candidates = depth_range_min + linear_space * (
                    depth_range_max - depth_range_min
                )  # [BV, D, H, W]

            # At scale 0 the candidates are [BV, D, 1, 1] and must be tiled
            # spatially; later scales are already [BV, D, H, W].
            spatial_repeat = (h, w) if scale_idx == 0 else (1, 1)
            # [BV*(V-1), D, H, W]
            depth_candidates_curr = (
                depth_candidates.unsqueeze(1)
                .repeat(1, num_tgt_views, 1, *spatial_repeat)
                .view(-1, curr_num_candidates, h, w)
            )

            # NOTE(review): the per-view intrinsics of the reference view
            # are repeated for every target view here - presumably fine
            # when all views share intrinsics; confirm.
            intrinsics_input = torch.stack(intrinsics_curr, dim=1).view(
                -1, 3, 3
            )  # [BV, 3, 3]
            intrinsics_input = intrinsics_input.unsqueeze(1).repeat(
                1, num_tgt_views, 1, 1
            )  # [BV, V-1, 3, 3]

            # warping and correlation are done in float32
            ref_features = ref_features.float()
            tgt_features = tgt_features.float()
            depth_candidates_curr = depth_candidates_curr.float()

            warped_tgt_features = warp_with_pose_depth_candidates(
                rearrange(tgt_features, "b v ... -> (b v) ..."),
                rearrange(intrinsics_input, "b v ... -> (b v) ..."),
                rearrange(pose_curr, "b v ... -> (b v) ..."),
                (
                    torch.exp(depth_candidates_curr)
                    if self.sample_log_depth
                    else 1.0 / depth_candidates_curr
                ),  # convert inverse/log depth to depth
                grid_sample_disable_cudnn=self.grid_sample_disable_cudnn,
            )  # [BV*(V-1), C, D, H, W]

            # ref: [BV, C, H, W]
            # warped: [BV*(V-1), C, D, H, W] -> [BV, V-1, C, D, H, W]
            warped_tgt_features = rearrange(
                warped_tgt_features,
                "(b v) ... -> b v ...",
                b=b_new,
                v=num_tgt_views,
            )
            # [BV, V-1, D, H, W] -> [BV, D, H, W]
            # average cross other views
            cost_volume = (
                (ref_features.unsqueeze(-3).unsqueeze(1) * warped_tgt_features).sum(2)
                / (c**0.5)
            ).mean(1)

            # regressor
            features_cnn = features_list_cnn[scale_idx]  # [BV, C, H, W]
            features_mono = features_list_mono[scale_idx]  # [BV, C, H, W]

            concat = torch.cat(
                (cost_volume, features_cnn, features_mv, features_mono), dim=1
            )

            with self._autocast():
                out = self.regressor[scale_idx](concat) + self.regressor_residual[
                    scale_idx
                ](concat)

            out = out.float()

            # depth pred
            match_prob = F.softmax(
                self.depth_head[scale_idx](out), dim=1
            )  # [BV, D, H, W]

            match_probs.append(match_prob)

            if scale_idx == 0:
                # [BV, D, H, W]
                depth_candidates = depth_candidates.repeat(1, 1, h, w)
            # expectation over the candidates
            depth = (match_prob * depth_candidates).sum(
                dim=1, keepdim=True
            )  # [BV, 1, H, W]

            # upsample to the original resolution for supervison at training time only
            if self.training and scale_idx < self.num_scales - 1:
                depth_bilinear = F.interpolate(
                    depth,
                    scale_factor=downsample_factor,
                    mode="bilinear",
                    align_corners=True,
                )
                depth_preds.append(depth_bilinear)

            # final output, learned upsampler
            if scale_idx == self.num_scales - 1:
                if self.bilinear_upsample_depth or self.no_upsample_depth:
                    residual_depth = 0
                else:
                    with self._autocast():
                        residual_depth = self.upsampler(
                            mono_intermediate_features,
                            # resolution high to low
                            cnn_features=features_list_cnn_all_scales[::-1],
                            mv_features=(
                                features_mv
                                if self.num_scales == 1
                                else features_list_mv[::-1]
                            ),
                            depth=depth,
                        )

                if self.no_upsample_depth:
                    depth_preds.append(depth)
                else:
                    depth_bilinear = F.interpolate(
                        depth,
                        scale_factor=self.upsample_factor,
                        mode="bilinear",
                        align_corners=True,
                    )
                    depth = (depth_bilinear + residual_depth).clamp(
                        min=min_depth.view(-1, 1, 1, 1),
                        max=max_depth.view(-1, 1, 1, 1),
                    )

                    depth_preds.append(depth)

        # convert every prediction from the sampling space back to depth
        for i, pred in enumerate(depth_preds):
            pred = pred.squeeze(1)  # [BV, H, W]
            if self.sample_log_depth:
                # log depth to depth
                depth_pred = torch.exp(pred)
            else:
                # convert inverse depth to depth
                depth_pred = 1.0 / pred

            depth_preds[i] = rearrange(
                depth_pred, "(b v) ... -> b v ...", b=b, v=v
            )  # [B, V, H, W]

        results_dict["depth_preds"] = depth_preds
        results_dict["match_probs"] = match_probs

        return results_dict


def set_num_views(module, num_views):
    """Recursively set ``attention.n_frames`` on every ``AttentionBlock``.

    Args:
        module: root of the module tree to update; objects that are not
            ``nn.Module`` instances are ignored.
        num_views: value written to ``attention.n_frames`` of each
            ``AttentionBlock`` found. The search does not descend into an
            ``AttentionBlock`` itself.
    """
    if isinstance(module, AttentionBlock):
        module.attention.n_frames = num_views
    elif isinstance(module, nn.Module):
        # nn.ModuleList and nn.Sequential are nn.Module subclasses, so a
        # single check covers every container type.
        for submodule in module.children():
            set_num_views(submodule, num_views)