from dataclasses import dataclass
from typing import Literal, Optional, List

import torch
from einops import rearrange
from jaxtyping import Float
from torch import Tensor, nn

from ...dataset.shims.patch_shim import apply_patch_shim
from ...dataset.types import BatchedExample, DataShim
from ...geometry.projection import sample_image_grid
from ..types import Gaussians
from .common.gaussian_adapter import GaussianAdapter, GaussianAdapterCfg
from .encoder import Encoder
from .visualization.encoder_visualizer_depthsplat_cfg import EncoderVisualizerDepthSplatCfg

import torchvision.transforms as T
import torch.nn.functional as F

from .unimatch.mv_unimatch import MultiViewUniMatch
from .unimatch.dpt_head import DPTHead

# Depth Anything wrapper and debug exporters. Nothing in this module calls
# them at the moment; the imports are kept so the debugging hooks can be
# re-enabled without touching the import block.
from ...test.try_depthanything import DepthAnythingWrapper
from ...test.visual import save_depth_images, save_output_images
from ...test.export_ply import save_point_cloud_to_ply


@dataclass
class EncoderDepthSplatCfg:
    """Configuration for :class:`EncoderDepthSplat`.

    Only a subset of these fields is read inside this module (see the
    encoder's ``__init__`` and ``forward``); the remaining fields are
    presumably consumed elsewhere or kept for config compatibility.
    """

    name: Literal["depthsplat"]
    d_feature: int
    num_depth_candidates: int
    num_surfaces: int
    visualizer: EncoderVisualizerDepthSplatCfg
    gaussian_adapter: GaussianAdapterCfg
    gaussians_per_pixel: int
    unimatch_weights_path: str | None
    downscale_factor: int
    shim_patch_size: int
    multiview_trans_attn_split: int
    costvolume_unet_feat_dim: int
    costvolume_unet_channel_mult: List[int]
    costvolume_unet_attn_res: List[int]
    depth_unet_feat_dim: int
    depth_unet_attn_res: List[int]
    depth_unet_channel_mult: List[int]

    # mv_unimatch
    num_scales: int
    upsample_factor: int
    lowest_feature_resolution: int
    depth_unet_channels: int
    grid_sample_disable_cudnn: bool

    # depthsplat color branch
    large_gaussian_head: bool
    color_large_unet: bool
    init_sh_input_img: bool
    feature_upsampler_channels: int
    gaussian_regressor_channels: int

    # loss config
    supervise_intermediate_depth: bool
    return_depth: bool

    # only depth
    train_depth_only: bool

    # monodepth config
    monodepth_vit_type: str

    # multi-view matching
    local_mv_match: int


class EncoderDepthSplat(Encoder[EncoderDepthSplatCfg]):
    """Encoder that predicts per-view depth and per-pixel Gaussian parameters.

    Depth comes from a ``MultiViewUniMatch`` predictor. Unless
    ``cfg.train_depth_only`` is set, a DPT head upsamples the predictor's
    features and two small convolutional stacks regress the raw Gaussian
    parameters, which ``GaussianAdapter`` then turns into Gaussians.
    """

    def __init__(self, cfg: EncoderDepthSplatCfg) -> None:
        """Build the depth predictor and, unless depth-only, the Gaussian branch.

        Args:
            cfg: Encoder configuration.
        """
        super().__init__(cfg)

        self.depth_predictor = MultiViewUniMatch(
            num_scales=cfg.num_scales,
            upsample_factor=cfg.upsample_factor,
            lowest_feature_resolution=cfg.lowest_feature_resolution,
            vit_type=cfg.monodepth_vit_type,
            unet_channels=cfg.depth_unet_channels,
            grid_sample_disable_cudnn=cfg.grid_sample_disable_cudnn,
        )

        # Depth-only training never touches the Gaussian branch, so none of
        # the modules below are created in that mode.
        if self.cfg.train_depth_only:
            return

        # upsample features to the original resolution
        model_configs = {
            'vits': {'in_channels': 384, 'features': 64, 'out_channels': [48, 96, 192, 384]},
            'vitb': {'in_channels': 768, 'features': 96, 'out_channels': [96, 192, 384, 768]},
            'vitl': {'in_channels': 1024, 'features': 128, 'out_channels': [128, 256, 512, 1024]},
        }
        vit_config = model_configs[cfg.monodepth_vit_type]

        self.feature_upsampler = DPTHead(
            **vit_config,
            downsample_factor=cfg.upsample_factor,
            return_feature=True,
            num_scales=cfg.num_scales,
        )
        # The channel count is taken from the ViT table above, not from
        # cfg.feature_upsampler_channels.
        feature_upsampler_channels = vit_config["features"]

        # gaussians adapter
        self.gaussian_adapter = GaussianAdapter(cfg.gaussian_adapter)

        # concat(img, depth, match_prob, features)
        in_channels = 3 + 1 + 1 + feature_upsampler_channels
        channels = self.cfg.gaussian_regressor_channels

        # conv regressor
        self.gaussian_regressor = nn.Sequential(
            nn.Conv2d(in_channels, channels, 3, 1, 1),
            nn.GELU(),
            nn.Conv2d(channels, channels, 3, 1, 1),
        )

        # predict gaussian parameters: scale, q, sh, offset, opacity
        num_gaussian_parameters = self.gaussian_adapter.d_in + 2 + 1

        # concat(img, features, regressor_out, match_prob)
        in_channels = 3 + feature_upsampler_channels + channels + 1
        self.gaussian_head = nn.Sequential(
            nn.Conv2d(in_channels, num_gaussian_parameters, 3, 1, 1,
                      padding_mode='replicate'),
            nn.GELU(),
            nn.Conv2d(num_gaussian_parameters, num_gaussian_parameters, 3, 1, 1,
                      padding_mode='replicate'),
        )

        # NOTE: the Depth Anything wrapper (DepthAnythingWrapper) is not
        # instantiated here; its construction was already disabled.

        # NOTE(review): the zero-initialisation of the last gaussian_head
        # layer is disabled in this file. It is kept below for reference;
        # confirm whether disabling it is intentional before re-enabling.
        # if self.cfg.init_sh_input_img:
        #     nn.init.zeros_(self.gaussian_head[-1].weight[10:])
        #     nn.init.zeros_(self.gaussian_head[-1].bias[10:])
        # # init scale
        # # first 3: opacity, offset_xy
        # nn.init.zeros_(self.gaussian_head[-1].weight[3:6])
        # nn.init.zeros_(self.gaussian_head[-1].bias[3:6])

    @staticmethod
    def _flatten_depths(depth_preds, include_intermediate):
        """Return depths as [B', V, H*W, 1, 1], optionally stacking all predictions.

        With ``include_intermediate`` every entry of ``depth_preds`` is
        concatenated along the batch dimension in list order (intermediate
        predictions first, final prediction last); otherwise only the final
        prediction is used.
        """
        if include_intermediate:
            stacked = torch.cat(depth_preds, dim=0)
        else:
            stacked = depth_preds[-1]
        return rearrange(stacked, "b v h w -> b v (h w) () ()")

    @staticmethod
    def _unflatten_depths(depths, h, w):
        """Reshape [B', V, H*W, 1, 1] depths back to [B', V, H, W]."""
        return rearrange(
            depths, "b v (h w) srf s -> b v h w srf s", h=h, w=w
        ).squeeze(-1).squeeze(-1)

    def forward(
        self,
        context: dict,
        global_step: int,
        deterministic: bool = False,
        visualization_dump: Optional[dict] = None,
        scene_names: Optional[list] = None,
    ):
        """Predict depth and, unless depth-only, Gaussians for the context views.

        Args:
            context: Dict from which this method reads ``"image"`` (unpacked
                as ``b, v, _, h, w``), ``"extrinsics"``, ``"intrinsics"``,
                ``"near"`` and ``"far"``.
            global_step: Not read by this method.
            deterministic: Not read by this method.
            visualization_dump: If given, filled in place with ``"depth"``,
                ``"scales"`` and ``"rotations"`` entries.
            scene_names: Not read by this method.

        Returns:
            * ``{"gaussians": None, "depths": depths}`` when
              ``cfg.train_depth_only`` is set;
            * ``{"gaussians": Gaussians, "depths": depths}`` when
              ``cfg.return_depth`` is set;
            * otherwise the ``Gaussians`` object alone.

            ``depths`` is reshaped to [B', V, H, W]; B' is the batch size
            multiplied by the number of depth predictions when intermediate
            predictions are supervised.
        """
        image = context["image"]
        device = image.device
        b, v, _, h, w = image.shape

        # With more than three views, restrict matching to each camera's
        # nearest neighbours, ranked by distance between the translation
        # columns of the extrinsics. Column 0 of the argsort is the camera
        # itself, hence the "+ 1".
        if v > 3:
            with torch.no_grad():
                xyzs = context["extrinsics"][:, :, :3, -1].detach()
                cameras_dist_matrix = torch.cdist(xyzs, xyzs, p=2)
                cameras_dist_index = torch.argsort(cameras_dist_matrix)
                cameras_dist_index = cameras_dist_index[
                    :, :, :(self.cfg.local_mv_match + 1)]
        else:
            cameras_dist_index = None

        # depth prediction
        # The predictor receives reciprocals of far/near as its min/max
        # bounds (presumably it works in inverse-depth space; confirm
        # against MultiViewUniMatch).
        results_dict = self.depth_predictor(
            image,
            attn_splits_list=[2],
            min_depth=1. / context["far"],
            max_depth=1. / context["near"],
            intrinsics=context["intrinsics"],
            extrinsics=context["extrinsics"],
            nn_matrix=cameras_dist_index,
        )

        # list of [B, V, H, W], with all the intermediate depths
        depth_preds = results_dict['depth_preds']

        # [B, V, H, W]
        depth = depth_preds[-1]

        # supervise all the intermediate depth predictions, if requested
        supervise_intermediate = (
            self.cfg.supervise_intermediate_depth and len(depth_preds) > 1
        )
        num_depths = len(depth_preds) if supervise_intermediate else 1

        # [B', V, H*W, 1, 1]
        depths = self._flatten_depths(depth_preds, supervise_intermediate)

        if self.cfg.train_depth_only:
            # return depth prediction for supervision, [B', V, H, W]
            return {
                "gaussians": None,
                "depths": self._unflatten_depths(depths, h, w),
            }

        # features [BV, C, H, W]
        features_mv = results_dict["features_mv"]
        features = self.feature_upsampler(
            results_dict["features_mono_intermediate"],
            cnn_features=results_dict["features_cnn_all_scales"][::-1],
            mv_features=(
                features_mv[0] if self.cfg.num_scales == 1
                else features_mv[::-1]
            ),
        )

        # match prob from softmax
        # [BV, D, H, W] in feature resolution
        match_prob = results_dict['match_probs'][-1]
        # [BV, 1, H, W]
        match_prob = torch.max(match_prob, dim=1, keepdim=True)[0]
        match_prob = F.interpolate(
            match_prob, size=depth.shape[-2:], mode='nearest')

        # [BV, 3, H, W]
        flat_image = rearrange(image, "b v c h w -> (b v) c h w")

        # regressor input [BV, C, H, W]
        regressor_in = torch.cat((
            flat_image,
            rearrange(depth, "b v h w -> (b v) () h w"),
            match_prob,
            features,
        ), dim=1)

        # [BV, C, H, W]
        regressor_out = self.gaussian_regressor(regressor_in)

        # head input [BV, C, H, W]
        head_in = torch.cat(
            [regressor_out, flat_image, features, match_prob], dim=1)

        # [BV, C, H, W]
        gaussians = self.gaussian_head(head_in)

        # [B, V, C, H, W]
        gaussians = rearrange(
            gaussians, "(b v) c h w -> b v c h w", b=b, v=v)

        # [B, V, H*W, C]
        raw_gaussians = rearrange(gaussians, "b v c h w -> b v (h w) c")

        if supervise_intermediate:
            # The Gaussian head is shared: the same raw parameters are
            # paired with every depth prediction along the batch dim.
            raw_gaussians = torch.cat([raw_gaussians] * num_depths, dim=0)

        # [B', V, H*W, 1, 1]
        opacities = raw_gaussians[..., :1].sigmoid().unsqueeze(-1)
        raw_gaussians = raw_gaussians[..., 1:]

        # Convert the features and depths into Gaussians.
        xy_ray, _ = sample_image_grid((h, w), device)  # (x, y)
        xy_ray = rearrange(xy_ray, "h w xy -> (h w) () xy")
        gaussians = rearrange(
            raw_gaussians,
            "... (srf c) -> ... srf c",
            srf=self.cfg.num_surfaces,
        )
        # Shift each ray by up to half a pixel in x and y.
        offset_xy = gaussians[..., :2].sigmoid()
        pixel_size = 1 / torch.tensor(
            (w, h), dtype=torch.float32, device=device)
        xy_ray = xy_ray + (offset_xy - 0.5) * pixel_size

        # Camera parameters (and optional images) must be tiled along the
        # batch dim to line up with the stacked depth predictions.
        extrinsics = context["extrinsics"]
        intrinsics = context["intrinsics"]
        input_images = image if self.cfg.init_sh_input_img else None
        if supervise_intermediate:
            extrinsics = torch.cat([extrinsics] * num_depths, dim=0)
            intrinsics = torch.cat([intrinsics] * num_depths, dim=0)
            if input_images is not None:
                input_images = input_images.repeat(num_depths, 1, 1, 1, 1)

        gaussians = self.gaussian_adapter.forward(
            rearrange(extrinsics, "b v i j -> b v () () () i j"),
            rearrange(intrinsics, "b v i j -> b v () () () i j"),
            rearrange(xy_ray, "b v r srf xy -> b v r srf () xy"),
            depths,
            opacities,
            rearrange(
                gaussians[..., 2:],
                "b v r srf c -> b v r srf () c",
            ),
            (h, w),
            input_images=input_images,
        )

        # Dump visualizations if needed.
        if visualization_dump is not None:
            visualization_dump["depth"] = rearrange(
                depths, "b v (h w) srf s -> b v h w srf s", h=h, w=w
            )
            visualization_dump["scales"] = rearrange(
                gaussians.scales, "b v r srf spp xyz -> b (v r srf spp) xyz"
            )
            visualization_dump["rotations"] = rearrange(
                gaussians.rotations,
                "b v r srf spp xyzw -> b (v r srf spp) xyzw",
            )

        # Merge the view / ray / surface / sample axes into one Gaussian axis.
        gaussians = Gaussians(
            rearrange(
                gaussians.means,
                "b v r srf spp xyz -> b (v r srf spp) xyz",
            ),
            rearrange(
                gaussians.covariances,
                "b v r srf spp i j -> b (v r srf spp) i j",
            ),
            rearrange(
                gaussians.harmonics,
                "b v r srf spp c d_sh -> b (v r srf spp) c d_sh",
            ),
            rearrange(
                gaussians.opacities,
                "b v r srf spp -> b (v r srf spp)",
            ),
        )

        if self.cfg.return_depth:
            # return depth prediction for supervision, [B', V, H, W]
            return {
                "gaussians": gaussians,
                "depths": self._unflatten_depths(depths, h, w),
            }

        return gaussians

    def get_data_shim(self) -> DataShim:
        """Return a shim that applies ``apply_patch_shim`` to each batch.

        The patch size passed to the shim is
        ``cfg.shim_patch_size * cfg.downscale_factor``.
        """

        def data_shim(batch: BatchedExample) -> BatchedExample:
            batch = apply_patch_shim(
                batch,
                patch_size=self.cfg.shim_patch_size
                * self.cfg.downscale_factor,
            )

            return batch

        return data_shim

    @property
    def sampler(self):
        """This encoder exposes no sampler; always ``None``."""
        return None