| from dataclasses import dataclass |
| from typing import Literal, Optional, List |
|
|
| import torch |
| from einops import rearrange, repeat |
| from jaxtyping import Float |
| from torch import Tensor, nn |
| import MinkowskiEngine as ME |
|
|
| from ...dataset.shims.patch_shim import apply_patch_shim |
| from ...dataset.types import BatchedExample, DataShim |
| from ...geometry.projection import sample_image_grid |
| from ..types import Gaussians |
|
|
| |
| from .common.gaussian_adapter_revise import GaussianAdapter_revise, GaussianAdapterCfg |
| |
|
|
|
|
| from .encoder import Encoder |
| from .visualization.encoder_visualizer_depthsplat_cfg import EncoderVisualizerDepthSplatCfg |
|
|
| import torchvision.transforms as T |
| import torch.nn.functional as F |
|
|
| from .unimatch.mv_unimatch import MultiViewUniMatch |
| from .unimatch.dpt_head import DPTHead |
|
|
| from .common.voxel_feature import project_features_to_3d, project_features_to_voxel, adapte_features_to_voxel, adapte_project_features_to_3d |
| from .common.me_fea import project_features_to_me |
|
|
| from ...geometry.projection import get_world_rays |
| from .common.sparse_net import SparseGaussianHead |
|
|
| from ...test.export_ply import save_point_cloud_to_ply |
|
|
|
|
| |
| from ...test.try_depthanything import DepthAnythingWrapper |
| from ...test.visual import save_depth_images, save_output_images |
|
|
|
|
@dataclass
class EncoderDepthSplatCfg:
    """Configuration for the DepthSplat encoder.

    Grouped roughly by the sub-module that consumes each field; see
    `EncoderDepthSplat_test.__init__` / `.forward` for usage.
    """

    # Registry key selecting this encoder implementation.
    name: Literal["depthsplat"]
    d_feature: int
    num_depth_candidates: int
    # Number of surfaces per ray; used to reshape raw Gaussian parameters.
    num_surfaces: int
    visualizer: EncoderVisualizerDepthSplatCfg
    # Config forwarded to GaussianAdapter_revise.
    gaussian_adapter: GaussianAdapterCfg
    gaussians_per_pixel: int
    # Optional pretrained weights for the UniMatch depth backbone.
    unimatch_weights_path: str | None
    downscale_factor: int
    # Together with downscale_factor, defines the patch multiple enforced by the data shim.
    shim_patch_size: int
    multiview_trans_attn_split: int

    # Cost-volume UNet settings (legacy; not referenced in this file's __init__).
    costvolume_unet_feat_dim: int
    costvolume_unet_channel_mult: List[int]
    costvolume_unet_attn_res: List[int]
    depth_unet_feat_dim: int
    depth_unet_attn_res: List[int]
    depth_unet_channel_mult: List[int]

    # MultiViewUniMatch depth-predictor settings.
    num_scales: int
    upsample_factor: int
    lowest_feature_resolution: int
    depth_unet_channels: int
    grid_sample_disable_cudnn: bool

    # Gaussian-branch architecture knobs.
    large_gaussian_head: bool
    color_large_unet: bool
    init_sh_input_img: bool
    feature_upsampler_channels: int
    gaussian_regressor_channels: int

    # Training behaviour.
    supervise_intermediate_depth: bool
    return_depth: bool

    # When True, only the depth predictor is built/used; forward returns depths only.
    train_depth_only: bool

    # DepthAnything ViT variant: 'vits' | 'vitb' | 'vitl' (keys of model_configs).
    monodepth_vit_type: str

    # Number of nearest neighbour views used for local multi-view matching when v > 3.
    local_mv_match: int
|
|
|
|
class EncoderDepthSplat_test(Encoder[EncoderDepthSplatCfg]):
    """Test/experimental variant of the DepthSplat encoder.

    Pipeline: a multi-view UniMatch backbone predicts per-view depth; a DPT
    head upsamples backbone features; RGB + depth + matching probability +
    features are regressed into per-pixel features, lifted to a sparse 3D
    representation (MinkowskiEngine), and decoded into Gaussian primitives by
    a sparse head plus `GaussianAdapter_revise`.
    """

    def __init__(self, cfg: EncoderDepthSplatCfg) -> None:
        """Build the depth predictor and, unless `cfg.train_depth_only`, the Gaussian branch."""
        super().__init__(cfg)

        # Multi-view depth estimation backbone (matching + monocular ViT features).
        self.depth_predictor = MultiViewUniMatch(
            num_scales=cfg.num_scales,
            upsample_factor=cfg.upsample_factor,
            lowest_feature_resolution=cfg.lowest_feature_resolution,
            vit_type=cfg.monodepth_vit_type,
            unet_channels=cfg.depth_unet_channels,
            grid_sample_disable_cudnn=cfg.grid_sample_disable_cudnn,
        )

        # Depth-only mode: skip construction of the entire Gaussian branch.
        if self.cfg.train_depth_only:
            return

        # DPT decoder channel layouts for the supported ViT variants
        # (keys must match cfg.monodepth_vit_type).
        model_configs = {
            'vits': {'in_channels': 384, 'features': 64, 'out_channels': [48, 96, 192, 384]},
            'vitb': {'in_channels': 768, 'features': 96, 'out_channels': [96, 192, 384, 768]},
            'vitl': {'in_channels': 1024, 'features': 128, 'out_channels': [128, 256, 512, 1024]},
        }

        # DPT head reused as a feature decoder (return_feature=True) to bring
        # backbone features to full image resolution.
        self.feature_upsampler = DPTHead(
            **model_configs[cfg.monodepth_vit_type],
            downsample_factor=cfg.upsample_factor,
            return_feature=True,
            num_scales=cfg.num_scales,
        )
        feature_upsampler_channels = model_configs[cfg.monodepth_vit_type]["features"]

        # Maps raw regressed parameters to Gaussian primitives (means, covariances, SH, opacity).
        self.gaussian_adapter = GaussianAdapter_revise(cfg.gaussian_adapter)

        # Regressor input: RGB (3) + depth (1) + match probability (1) + upsampled features.
        in_channels = 3 + 1 + 1 + feature_upsampler_channels
        channels = self.cfg.gaussian_regressor_channels

        modules = [
            nn.Conv2d(in_channels, channels, 3, 1, 1),
            nn.GELU(),
            nn.Conv2d(channels, channels, 3, 1, 1),
        ]

        self.gaussian_regressor = nn.Sequential(*modules)

        # Adapter raw dims + 3 + 1; the extra 3 + 1 are presumably an xyz offset
        # and an opacity logit — TODO confirm against GaussianAdapter_revise / SparseGaussianHead.
        num_gaussian_parameters = self.gaussian_adapter.d_in + 3 + 1

        # Head input: regressor output + RGB + upsampled features + match probability
        # (matches the `concat` list built in forward()).
        in_channels = 3 + feature_upsampler_channels + channels + 1

        self.gaussian_head = SparseGaussianHead(in_channels, num_gaussian_parameters)

        # HACK: hard-coded absolute checkpoint path — breaks on any other machine;
        # should be promoted to a cfg field. NOTE(review): self.depth_anything is
        # never used elsewhere in this file — possibly dead weight at test time.
        encoder = 'vitb'
        checkpoint_path = '/mnt/pfs/users/wangweijie/yeqing/BEV-Splat/pretrained/depth_anything_vitb14.pth'
        self.depth_anything = DepthAnythingWrapper(encoder, checkpoint_path)

    def forward(
        self,
        context: dict,
        global_step: int,
        deterministic: bool = False,
        visualization_dump: Optional[dict] = None,
        scene_names: Optional[list] = None,
    ):
        """Predict per-view depth and (unless depth-only) 3D Gaussians.

        Args:
            context: batch dict; reads "image" (b, v, 3, h, w), "intrinsics",
                "extrinsics" (…, 4, 4; translation taken from [..., :3, -1]),
                "near" and "far".
            global_step: training step (unused in this body).
            deterministic: unused in this body.
            visualization_dump: unused in this body.
            scene_names: unused in this body.

        Returns:
            In depth-only mode: {"gaussians": None, "depths": …}.
            Otherwise: {"gaussians": Gaussians, "depths": …} if cfg.return_depth,
            else the bare Gaussians object. NOTE(review): the inconsistent return
            type (dict vs Gaussians) pushes a branch onto every caller.
        """
        device = context["image"].device
        b, v, _, h, w = context["image"].shape

        # With more than 3 views, restrict multi-view matching to each view's
        # (local_mv_match) nearest neighbours by camera-centre distance.
        if v > 3:
            with torch.no_grad():
                xyzs = context["extrinsics"][:, :, :3, -1].detach()
                cameras_dist_matrix = torch.cdist(xyzs, xyzs, p=2)
                cameras_dist_index = torch.argsort(cameras_dist_matrix)
                # +1 because each view's nearest neighbour is itself (distance 0).
                cameras_dist_index = cameras_dist_index[:, :, :(self.cfg.local_mv_match + 1)]
        else:
            cameras_dist_index = None

        # Depth backbone works in inverse depth, hence the 1/far..1/near bounds.
        results_dict = self.depth_predictor(
            context["image"],
            attn_splits_list=[2],
            min_depth=1. / context["far"],
            max_depth=1. / context["near"],
            intrinsics=context["intrinsics"],
            extrinsics=context["extrinsics"],
            nn_matrix=cameras_dist_index,
        )

        # Coarse-to-fine depth predictions; the last entry is the finest.
        depth_preds = results_dict['depth_preds']
        depth = depth_preds[-1]

        if self.cfg.train_depth_only:
            # Depth-only path: pack depths as (b v (h w) srf s) to mirror the
            # shape convention used elsewhere, then squeeze back to (b v h w).
            depths = rearrange(depth, "b v h w -> b v (h w) () ()")

            if self.cfg.supervise_intermediate_depth and len(depth_preds) > 1:
                num_depths = len(depth_preds)

                # Stack all intermediate predictions along the batch dim so the
                # loss supervises every refinement stage at once.
                intermediate_depths = torch.cat(
                    depth_preds[:(num_depths - 1)], dim=0)
                intermediate_depths = rearrange(
                    intermediate_depths, "b v h w -> b v (h w) () ()")

                depths = torch.cat((intermediate_depths, depths), dim=0)

                # NOTE(review): b is rescaled to reflect the enlarged batch but is
                # not used again on this path — presumably kept for symmetry.
                b *= num_depths

            depths = rearrange(
                depths, "b v (h w) srf s -> b v h w srf s", h=h, w=w
            ).squeeze(-1).squeeze(-1)

            return {
                "gaussians": None,
                "depths": depths
            }

        # Decode full-resolution per-pixel features; CNN/MV feature lists are
        # reversed to match the DPT head's fine-to-coarse ordering.
        features = self.feature_upsampler(results_dict["features_mono_intermediate"],
                                          cnn_features=results_dict["features_cnn_all_scales"][::-1],
                                          mv_features=results_dict["features_mv"][
                                              0] if self.cfg.num_scales == 1 else results_dict["features_mv"][::-1]
                                          )

        # Per-pixel matching confidence: max over depth candidates, upsampled
        # to the depth map's resolution.
        match_prob = results_dict['match_probs'][-1]
        match_prob = torch.max(match_prob, dim=1, keepdim=True)[
            0]
        match_prob = F.interpolate(
            match_prob, size=depth.shape[-2:], mode='nearest')

        # Regressor input: RGB + depth + confidence + features (see __init__'s
        # first in_channels computation).
        concat = torch.cat((
            rearrange(context["image"], "b v c h w -> (b v) c h w"),
            rearrange(depth, "b v h w -> (b v) () h w"),
            match_prob,
            features,
        ), dim=1)

        out = self.gaussian_regressor(concat)
        # Head input: regressor output + RGB + features + confidence (matches
        # the second in_channels computation in __init__).
        concat = [out,
                  rearrange(context["image"],
                            "b v c h w -> (b v) c h w"),
                  features,
                  match_prob]

        out = torch.cat(concat, dim=1)

        # Lift features to a 3D voxel grid; also returns the grid bounds and
        # scaling metadata consumed by the adapter below.
        # NOTE(review): voxel_feature itself is unused on this path — only the
        # metadata (media_val, scale_factor, *bounds) is consumed.
        voxel_feature, media_val, scale_factor, xbound, ybound, zbound = adapte_project_features_to_3d(
            context["intrinsics"],
            context["extrinsics"],
            out,
            depth=depth,
            b=b, v=v,
        )

        # Build a MinkowskiEngine sparse tensor from the per-pixel features
        # projected to 3D along the predicted depth.
        sparse_input = project_features_to_me(
            context["intrinsics"],
            context["extrinsics"],
            out,
            depth=depth,
            b=b, v=v
        )

        gaussians = self.gaussian_head(sparse_input)

        # ME sparse tensor: .F are per-point features, .C their coordinates.
        # Add leading (b, v)-like singleton dims to match the adapter's layout.
        gaussian_params = gaussians.F.unsqueeze(0).unsqueeze(0)

        # Channel 0 is the opacity logit; the rest are raw Gaussian parameters
        # split per surface.
        opacities = gaussian_params[..., :1].sigmoid().unsqueeze(-1)
        raw_gaussians = gaussian_params[..., 1:]
        raw_gaussians = rearrange(
            raw_gaussians,
            "... (srf c) -> ... srf c",
            srf=self.cfg.num_surfaces,
        )

        try:
            gaussians = self.gaussian_adapter.forward(
                extrinsics = context["extrinsics"],
                intrinsics = context["intrinsics"],
                opacities = opacities,
                raw_gaussians = rearrange(raw_gaussians,"b v r srf c -> b v r srf () c"),
                xbound = xbound,
                ybound = ybound,
                zbound = zbound,
                input_images =rearrange(context["image"], "b v c h w -> (b v) c h w"),
                depth_prob = results_dict["pdf"],
                depth_candidates = results_dict["depth_candidates"],
                coordidate = gaussians.C,
                # BUG(review): `coordinates` is not defined anywhere in this
                # method — this line raises NameError whenever it runs (the
                # except below re-raises). Possibly meant `sparse_input.C`
                # (the input sparse tensor's coordinates) — confirm and fix.
                input_coordidate = coordinates,
                media_val = media_val,
                scale_factor = scale_factor,
            )
        except Exception as e:
            # Print the traceback for debugging, then propagate unchanged.
            import traceback; traceback.print_exc()
            raise

        # Optionally build a second set of Gaussians from the coarsest depth
        # prediction and stack it batch-wise for intermediate supervision.
        if self.cfg.supervise_intermediate_depth and len(depth_preds) > 1:
            intermediate_depth = depth_preds[0]
            intermediate_voxel_feature = project_features_to_3d(
                context["intrinsics"],
                context["extrinsics"],
                out,
                depth=intermediate_depth,
                xbound=xbound,
                ybound=ybound,
                zbound=zbound, b=b,v=v
            )

            intermediate_gaussians = self.gaussian_head(intermediate_voxel_feature)

            # NOTE(review): here the head output is treated as a dense
            # (b, k, d, h, w) tensor, unlike the sparse .F path above —
            # presumably project_features_to_3d yields a dense grid; verify.
            gaussian_params = rearrange(intermediate_gaussians, "b k d h w -> b () k (d h w)")
            gaussian_params = rearrange(gaussian_params, "b srf k n -> b srf n k")

            intermediate_opacities = gaussian_params[..., :1].sigmoid().unsqueeze(-1)
            intermediate_raw_gaussians = gaussian_params[..., 1:]
            intermediate_raw_gaussians = rearrange(
                intermediate_raw_gaussians,
                "... (srf c) -> ... srf c",
                srf=self.cfg.num_surfaces,
            )

            intermediate_gaussians = self.gaussian_adapter.forward(
                extrinsics = context["extrinsics"],
                intrinsics = context["intrinsics"],
                opacities = intermediate_opacities,
                raw_gaussians = rearrange(intermediate_raw_gaussians,"b v r srf c -> b v r srf () c"),
                depth = intermediate_depth,
                xbound = xbound,
                ybound = ybound,
                zbound = zbound,
                input_images =rearrange(context["image"], "b v c h w -> (b v) c h w"),
            )

            # Concatenate along the batch dim (mutates the adapter's output
            # object in place — assumes its fields are plain tensors).
            gaussians.means=torch.cat([intermediate_gaussians.means, gaussians.means], dim=0)
            gaussians.covariances=torch.cat([intermediate_gaussians.covariances, gaussians.covariances], dim=0)
            gaussians.harmonics=torch.cat([intermediate_gaussians.harmonics, gaussians.harmonics], dim=0)
            gaussians.opacities=torch.cat([intermediate_gaussians.opacities, gaussians.opacities], dim=0)

        # NOTE(review): `points` is unused — likely a leftover from a
        # save_point_cloud_to_ply debug dump (imported above); safe to remove.
        points = rearrange(
            gaussians.means,
            "b v r srf spp xyz -> (b v r srf spp) xyz"
        )

        # Flatten (view, ray, surface, sample) dims into a single Gaussian axis.
        gaussians = Gaussians(
            rearrange(
                gaussians.means,
                "b v r srf spp xyz -> b (v r srf spp) xyz",
            ),
            rearrange(
                gaussians.covariances,
                "b v r srf spp i j -> b (v r srf spp) i j",
            ),
            rearrange(
                gaussians.harmonics,
                "b v r srf spp c d_sh -> b (v r srf spp) c d_sh",
            ),
            rearrange(
                gaussians.opacities,
                "b v r srf spp -> b (v r srf spp)",
            ),
        )

        if self.cfg.return_depth:
            # Stack every refinement stage batch-wise, matching the depth-only path.
            depths = torch.cat(depth_preds, dim=0)

            return {
                "gaussians": gaussians,
                "depths": depths
            }

        return gaussians

    def get_data_shim(self) -> DataShim:
        """Return a shim that crops/pads batches to a multiple of the patch size."""
        def data_shim(batch: BatchedExample) -> BatchedExample:
            batch = apply_patch_shim(
                batch,
                patch_size=self.cfg.shim_patch_size
                * self.cfg.downscale_factor,
            )

            return batch

        return data_shim

    @property
    def sampler(self):
        # No custom sampler for this encoder.
        return None
|
|
|
|
|
|
|
|
| |
| |
| |
| |
|
|