from dataclasses import dataclass from typing import Literal, Optional, List import torch from einops import rearrange, repeat from jaxtyping import Float from torch import Tensor, nn import MinkowskiEngine as ME import torch.nn.init as init from ...dataset.shims.patch_shim import apply_patch_shim from ...dataset.types import BatchedExample, DataShim from ...geometry.projection import sample_image_grid from ..types import Gaussians ########形成高斯点######### # from .common.gaussian_adapter_revise import GaussianAdapter_revise, GaussianAdapterCfg from .common.guassian_adapter_depth import GaussianAdapter_depth, GaussianAdapterCfg from .encoder import Encoder from .visualization.encoder_visualizer_depthsplat_cfg import EncoderVisualizerDepthSplatCfg import torchvision.transforms as T import torch.nn.functional as F from .unimatch.mv_unimatch import MultiViewUniMatch from .unimatch.dpt_head import DPTHead from .common.voxel_feature import project_features_to_3d, project_features_to_voxel, adapte_features_to_voxel, adapte_project_features_to_3d from .common.me_fea import project_features_to_me from ...geometry.projection import get_world_rays from .common.sparse_net import SparseGaussianHead, SparseUNetWithAttention from ...test.export_ply import save_point_cloud_to_ply ##加入depthanything## from ...test.try_depthanything import DepthAnythingWrapper from ...test.visual import save_depth_images, save_output_images import gc import time # import debugpy # try: # # 5678 is the default attach port in the VS Code debug configurations. 
# Unless a host and port are specified, host defaults to 127.0.0.1
#     debugpy.listen(("localhost", 9479))
#     print("Waiting for debugger attach")
#     debugpy.wait_for_client()
# except Exception as e:
#     pass


# Lightweight GPU-memory printing helper (used by the commented-out
# print_mem(...) probes sprinkled through forward()).
def print_mem(tag: str = ""):
    """Print current CUDA allocated/reserved memory in MB, tagged with `tag`.

    Prints a "no CUDA" marker and returns early when CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        print(f"[MEM] {tag} - no CUDA")
        return
    allocated = torch.cuda.memory_allocated() / 1024**2
    reserved = torch.cuda.memory_reserved() / 1024**2
    print(f"[MEM] {tag} | allocated={allocated:.1f} MB reserved={reserved:.1f} MB")


@dataclass
class EncoderDepthSplatCfg:
    """Configuration for the DepthSplat encoder (populated from config files)."""

    name: Literal["depthsplat"]
    d_feature: int
    num_depth_candidates: int
    num_surfaces: int
    visualizer: EncoderVisualizerDepthSplatCfg
    gaussian_adapter: GaussianAdapterCfg
    gaussians_per_pixel: int
    unimatch_weights_path: str | None
    downscale_factor: int
    shim_patch_size: int
    multiview_trans_attn_split: int
    costvolume_unet_feat_dim: int
    costvolume_unet_channel_mult: List[int]
    costvolume_unet_attn_res: List[int]
    depth_unet_feat_dim: int
    depth_unet_attn_res: List[int]
    depth_unet_channel_mult: List[int]
    # mv_unimatch
    num_scales: int
    upsample_factor: int
    lowest_feature_resolution: int
    depth_unet_channels: int
    grid_sample_disable_cudnn: bool
    # depthsplat color branch
    large_gaussian_head: bool
    color_large_unet: bool
    init_sh_input_img: bool
    feature_upsampler_channels: int
    gaussian_regressor_channels: int
    # loss config
    supervise_intermediate_depth: bool
    return_depth: bool
    # only depth
    train_depth_only: bool
    # monodepth config
    monodepth_vit_type: str
    # multi-view matching
    local_mv_match: int


class EncoderDepthSplat_test(Encoder[EncoderDepthSplatCfg]):
    """Encoder that predicts multi-view depth with MultiViewUniMatch and then
    regresses 3D Gaussian parameters on a MinkowskiEngine sparse voxel grid.

    Pipeline (see forward()): depth prediction -> DPT feature upsampling ->
    per-pixel conv regressor -> unprojection/voxelization via
    project_features_to_me -> SparseGaussianHead -> GaussianAdapter_depth.
    """

    def __init__(self, cfg: EncoderDepthSplatCfg) -> None:
        super().__init__(cfg)
        # Multi-view depth predictor (UniMatch-style matching + mono ViT branch).
        self.depth_predictor = MultiViewUniMatch(
            num_scales=cfg.num_scales,
            upsample_factor=cfg.upsample_factor,
            lowest_feature_resolution=cfg.lowest_feature_resolution,
            vit_type=cfg.monodepth_vit_type,
            unet_channels=cfg.depth_unet_channels,
            grid_sample_disable_cudnn=cfg.grid_sample_disable_cudnn,
        )

        # Depth-only training needs none of the Gaussian-regression modules.
        if self.cfg.train_depth_only:
            return

        # upsample features to the original resolution
        model_configs = {
            'vits': {'in_channels': 384, 'features': 64, 'out_channels': [48, 96, 192, 384]},
            'vitb': {'in_channels': 768, 'features': 96, 'out_channels': [96, 192, 384, 768]},
            'vitl': {'in_channels': 1024, 'features': 128, 'out_channels': [128, 256, 512, 1024]},
        }
        self.feature_upsampler = DPTHead(
            **model_configs[cfg.monodepth_vit_type],
            downsample_factor=cfg.upsample_factor,
            return_feature=True,
            num_scales=cfg.num_scales,
        )

        feature_upsampler_channels = model_configs[cfg.monodepth_vit_type]["features"]

        # gaussians adapter
        self.gaussian_adapter = GaussianAdapter_depth(cfg.gaussian_adapter)

        # concat(img, depth, match_prob, features)
        in_channels = 3 + 1 + 1 + feature_upsampler_channels
        channels = self.cfg.gaussian_regressor_channels

        # conv regressor
        modules = [
            nn.Conv2d(in_channels, channels, 3, 1, 1),
            nn.GELU(),
            nn.Conv2d(channels, channels, 3, 1, 1),
        ]
        self.gaussian_regressor = nn.Sequential(*modules)

        # predict gaussian parameters: scale, q, sh, offset, opacity
        # num_gaussian_parameters = self.gaussian_adapter.d_in + 2 + 1  # 34 + 2(x,y) + 1(o) = 37
        num_gaussian_parameters = self.gaussian_adapter.d_in + 3 + 1  # 34 + 3(x,y,z) + 1(o) = 38
        # num_gaussian_parameters = self.gaussian_adapter.d_in + 1  # 34 + + 1(o) = 35

        # concat(img, features, regressor_out, match_prob)
        in_channels = 3 + feature_upsampler_channels + channels + 1

        # Build the Gaussian head (sparse conv head over the voxel features).
        self.gaussian_head = SparseGaussianHead(in_channels, num_gaussian_parameters)
        # self.gaussian_head = SparseUNetWithAttention(in_channels=in_channels, out_channels=num_gaussian_parameters, num_blocks=3, use_attention=True)

        ########## depthanything ##########
        # NOTE(review): `encoder` and `checkpoint_path` are currently unused
        # (the DepthAnything branch below is commented out); kept for the
        # experimental monodepth-fusion path.
        encoder = 'vitb'
        checkpoint_path = '/mnt/pfs/users/wangweijie/yeqing/BEV-Splat/pretrained/depth_anything_vitb14.pth'
        # self.depth_anything = DepthAnythingWrapper(encoder,checkpoint_path)

        # self.depth_fuse = nn.Sequential(nn.Conv2d(2, 4, kernel_size=1, padding=0),
        #                                 nn.ReLU(),
        #                                 nn.Conv2d(4, 1, kernel_size=1, padding=0)
        #                                 )
        # # -- initialize weights and biases of every Conv2d in depth_fuse --
        # for m in self.depth_fuse.modules():
        #     if isinstance(m, nn.Conv2d):
        #         init.constant_(m.weight, 0.5)
        #         if m.bias is not None:
        #             init.constant_(m.bias, 0.5)

        ####### voxel-resolution prediction ######
        # self.feature_extractor = nn.Sequential(
        #     nn.Conv2d(1, 2, 3, padding=1),
        #     nn.ReLU(),
        #     nn.MaxPool2d(2),  # 1/2
        #     nn.Conv2d(2, 4, 3, padding=1),
        #     nn.ReLU(),
        #     nn.MaxPool2d(2),  # 1/4
        #     nn.Conv2d(4, 8, 3, padding=1),
        #     nn.ReLU(),
        #     nn.AdaptiveAvgPool2d((1, 1))  # global feature aggregation
        # )
        # # regression prediction head
        # self.regressor = nn.Sequential(
        #     nn.Flatten(),
        #     nn.Linear(8, 4),
        #     nn.ReLU(),
        #     nn.Linear(4, 1),
        #     nn.Sigmoid()  # output in range 0-1
        # )
        # # -- initialize weights and biases of every Conv2d in feature_extractor --
        # for m in self.feature_extractor.modules():
        #     if isinstance(m, nn.Conv2d):
        #         init.constant_(m.weight, 0.5)
        #         if m.bias is not None:
        #             init.constant_(m.bias, 0.5)

        # Output scaling parameters (0.01 + 0.04 * sigmoid_output) for the
        # commented-out voxel-resolution regressor above.
        self.scale = 0.04
        self.shift = 0.01

        # if self.cfg.init_sh_input_img:
        #     nn.init.zeros_(self.gaussian_head[-1].weight[10:])
        #     nn.init.zeros_(self.gaussian_head[-1].bias[10:])
        #     # init scale
        #     # first 3: opacity, offset_xy
        #     nn.init.zeros_(self.gaussian_head[-1].weight[3:6])
        #     nn.init.zeros_(self.gaussian_head[-1].bias[3:6])

    def forward(
        self,
        context: dict,
        global_step: int,
        deterministic: bool = False,
        visualization_dump: Optional[dict] = None,
        scene_names: Optional[list] = None,
        ues_voxelnet: bool = True,
    ):
        """Predict per-view depth and regress sparse voxelized 3D Gaussians.

        Args:
            context: batch dict; reads "image" [B, V, 3, H, W], "intrinsics",
                "extrinsics", and the "near"/"far" depth bounds.
            global_step: training step (only referenced by commented-out
                point-cloud dumping code).
            deterministic: unused here.
            visualization_dump: unused here.
            scene_names: unused here.
            ues_voxelnet: (sic, likely "use_voxelnet") flag for the
                commented-out voxel-resolution predictor; currently no effect.

        Returns:
            If cfg.train_depth_only: {"gaussians": None, "depths": [B, V, H, W]}.
            Else if cfg.return_depth: {"gaussians", "depths",
            optionally "intermediate_gaussians"}. Otherwise the Gaussians alone.
        """
        # NOTE(review): `device` is computed but never used below.
        device = context["image"].device
        b, v, _, h, w = context["image"].shape

        # With many views, restrict multi-view matching to each view's
        # (local_mv_match) nearest cameras by camera-center distance.
        if v > 3:
            with torch.no_grad():
                xyzs = context["extrinsics"][:, :, :3, -1].detach()
                cameras_dist_matrix = torch.cdist(xyzs, xyzs, p=2)
                cameras_dist_index = torch.argsort(cameras_dist_matrix)
                # +1 keeps the view itself plus its local_mv_match neighbors.
                cameras_dist_index = cameras_dist_index[:, :, :(self.cfg.local_mv_match + 1)]
        else:
            cameras_dist_index = None

        # print_mem("forward start")

        # depth prediction (depth bounds come from inverse near/far)
        results_dict = self.depth_predictor(
            context["image"],
            attn_splits_list=[2],
            min_depth=1. / context["far"],
            max_depth=1. / context["near"],
            intrinsics=context["intrinsics"],
            extrinsics=context["extrinsics"],
            nn_matrix=cameras_dist_index,
        )

        # print_mem("after depth_predictor")

        # ################ depth prediction via depth_anything #################
        # depth_anything = self.depth_anything(context["image"])  # [V, B, H, W]: [6, 1, 256, 448]
        # depth_anything = depth_anything.permute(1, 0, 2, 3)  # [B, V, H, W]
        # depth = depth_anything

        # Save RGB images (debug; the save call itself is disabled).
        rgb_image_path = "/mnt/pfs/users/wangweijie/yeqing/BEV-Splat/outputs/rgb_image"
        # save_output_images(context["image"], rgb_image_path)

        # list of [B, V, H, W], with all the intermediate depths
        depth_preds = results_dict['depth_preds']
        depth = depth_preds[-1]  # [B, V, H, W]

        # depth_pre = depth_preds[-1]  # depth map
        # B, V, H, W = depth_pre.shape
        # depth_fused = depth_pre.reshape(-1, 1, H, W)
        # depth = depth_pre

        ######## fused depth prediction ###########
        # depth_concat = torch.stack([depth_pre, depth_anything], dim=2)  # [B, V, 2, H, W]
        # depth_concat = depth_concat.view(B * V, 2, H, W)  # [B*V, 2, H, W]
        # depth_fused = self.depth_fuse(depth_concat)  # [B*V, 1, H, W]
        # depth_fused_cla = depth_fused.clamp(min=0.5, max=200)
        # depth = depth_fused_cla.view(B, V, H, W)

        ######## voxel-resolution prediction ###########
        # if ues_voxelnet:
        #     # [B*V, 1, H, W] -> [B*V, 16, 1, 1]
        #     depth_features = self.feature_extractor(depth_fused)
        #     # regression prediction [B*V, 1]
        #     resolution_raw = self.regressor(depth_features)
        #     resolution_mean = resolution_raw.mean().unsqueeze(0)
        #     # scale to target range [0.01, 0.05]
        #     voxel_resolution = self.shift + self.scale * resolution_mean
        #     voxel_resolution = voxel_resolution.item()
        #     print(f"预测体素分辨率: {voxel_resolution} m")
        # else:
        voxel_resolution = 0.001  # 1 mm voxels
        # voxel_resolution = 0.01  # 1 cm voxels

        # Save depth maps (debug; the save call itself is disabled).
        depth_image_path = "/mnt/pfs/users/wangweijie/yeqing/BEV-Splat/outputs/depth_image"
        # save_depth_images(depth, depth_image_path)

        if self.cfg.train_depth_only:
            # convert format
            # [B, V, H*W, 1, 1]
            depths = rearrange(depth, "b v h w -> b v (h w) () ()")

            if self.cfg.supervise_intermediate_depth and len(depth_preds) > 1:
                # supervise all the intermediate depth predictions
                num_depths = len(depth_preds)
                # [B, V, H*W, 1, 1]
                intermediate_depths = torch.cat(
                    depth_preds[:(num_depths - 1)], dim=0)
                intermediate_depths = rearrange(
                    intermediate_depths, "b v h w -> b v (h w) () ()")
                # concat in the batch dim
                depths = torch.cat((intermediate_depths, depths), dim=0)
                b *= num_depths

            # return depth prediction for supervision
            depths = rearrange(
                depths, "b v (h w) srf s -> b v h w srf s", h=h, w=w
            ).squeeze(-1).squeeze(-1)
            # print(depths.shape)  # [B, V, H, W]
            return {
                "gaussians": None,
                "depths": depths
            }

        # features [BV, C, H, W]
        features = self.feature_upsampler(
            results_dict["features_mono_intermediate"],
            cnn_features=results_dict["features_cnn_all_scales"][::-1],
            mv_features=results_dict["features_mv"][0] if self.cfg.num_scales == 1 else results_dict["features_mv"][::-1]
        )

        # print_mem("after feature_upsampler")

        # match prob from softmax
        # [BV, D, H, W] in feature resolution
        match_prob = results_dict['match_probs'][-1]
        # Max over the depth-candidate dim -> per-pixel matching confidence.
        match_prob = torch.max(match_prob, dim=1, keepdim=True)[0]  # [BV, 1, H, W]
        match_prob = F.interpolate(
            match_prob, size=depth.shape[-2:], mode='nearest')

        # unet input [BV, C, H, W] [6, 101, 256, 448]
        concat = torch.cat((
            rearrange(context["image"], "b v c h w -> (b v) c h w"),
            rearrange(depth, "b v h w -> (b v) () h w"),
            match_prob,
            features,
        ), dim=1)

        # [BV, C, H, W]
        out = self.gaussian_regressor(concat)

        concat = [out,
                  rearrange(context["image"], "b v c h w -> (b v) c h w"),
                  features,
                  match_prob]
        # [BV, C, H, W] [6, 164, 256, 448]
        out = torch.cat(concat, dim=1)

        # print_mem("before project_features_to_me")
        # # print(f"输入稀疏张量: {len(coordinates)}个体素")

        # Unproject per-pixel features along the predicted depth and voxelize
        # them into a MinkowskiEngine sparse tensor.
        sparse_input, aggregated_points, counts = project_features_to_me(
            context["intrinsics"],
            context["extrinsics"],
            out,
            depth=depth,
            voxel_resolution=voxel_resolution,
            b=b,
            v=v
        )

        # print_mem("after project_features_to_me")

        # ([1, 128, 80, 80, 80]) -> [N, 38]
        gaussians = self.gaussian_head(sparse_input)

        # print_mem("after gaussian_head")

        # [B, V, H*W, 1, 1]
        depths = rearrange(depth, "b v h w -> b v (h w) () ()")

        # The output is also a sparse tensor.
        print(f"输出稀疏张量: {gaussians.F.shape[0]}个体素")
        gaussian_params = gaussians.F.unsqueeze(0).unsqueeze(0)  # [N, 38] -> [1, 1, N, 38]

        # Split opacity off from the remaining Gaussian parameters.
        opacities = gaussian_params[..., :1].sigmoid().unsqueeze(-1)  # [1, 1, 256000, 1, 1]
        raw_gaussians = gaussian_params[..., 1:]  # [1, 1, 256000, 37]
        raw_gaussians = rearrange(
            raw_gaussians,
            "... (srf c) -> ... srf c",
            srf=self.cfg.num_surfaces,
        )

        try:
            # Convert raw_gaussians into Gaussian parameters. Note that
            # `gaussians.C` (sparse voxel coordinates) is read here before
            # `gaussians` is rebound to the adapter's output.
            gaussians = self.gaussian_adapter.forward(
                extrinsics=context["extrinsics"],
                intrinsics=context["intrinsics"],
                opacities=opacities,
                raw_gaussians=rearrange(raw_gaussians, "b v r srf c -> b v r srf () c"),
                input_images=rearrange(context["image"], "b v c h w -> (b v) c h w"),  # [6, 3, 256, 448]
                depth=depth,
                coordidate=gaussians.C,
                points=aggregated_points,
                voxel_resolution=voxel_resolution
            )
        except Exception as e:
            import traceback; traceback.print_exc()
            raise

        # print_mem("after gaussian_adapter")

        if self.cfg.supervise_intermediate_depth and len(depth_preds) > 1:
            # Run the same sparse-Gaussian pipeline on the first (coarsest)
            # intermediate depth prediction for extra supervision.
            intermediate_depth = depth_preds[0]
            intermediate_voxel_feature, median_points, counts = project_features_to_me(
                context["intrinsics"],
                context["extrinsics"],
                out,
                depth=intermediate_depth,
                voxel_resolution=voxel_resolution,
                b=b,
                v=v
            )
            # print_mem("after media_depth project_features_to_me")
            intermediate_gaussians = self.gaussian_head(intermediate_voxel_feature)  # [N, 38]
            # print_mem("after media_depth gaussian_head")
            gaussian_params = intermediate_gaussians.F.unsqueeze(0).unsqueeze(0)  # [N, 38] -> [1, 1, N, 38]

            # Split opacity off from the remaining Gaussian parameters.
            intermediate_opacities = gaussian_params[..., :1].sigmoid().unsqueeze(-1)  # [1, 1, 256000, 1, 1]
            intermediate_raw_gaussians = gaussian_params[..., 1:]  # [1, 1, 256000, 37]
            intermediate_raw_gaussians = rearrange(
                intermediate_raw_gaussians,
                "... (srf c) -> ... srf c",
                srf=self.cfg.num_surfaces,
            )

            # print_mem("before media_depth gaussian_adapter")
            # Convert raw_gaussians into Gaussian parameters.
            intermediate_gaussians = self.gaussian_adapter.forward(
                extrinsics=context["extrinsics"],
                intrinsics=context["intrinsics"],
                opacities=intermediate_opacities,
                raw_gaussians=rearrange(intermediate_raw_gaussians, "b v r srf c -> b v r srf () c"),
                input_images=rearrange(context["image"], "b v c h w -> (b v) c h w"),  # [6, 3, 256, 448]
                depth=intermediate_depth,
                coordidate=intermediate_gaussians.C,
                points=median_points,
                voxel_resolution=voxel_resolution
            )

            # print_mem("after media_depth gaussian_adapter")

            intermediate_gaussians = Gaussians(
                rearrange(
                    intermediate_gaussians.means,  # [2, 1, 256000, 1, 1, 3]
                    "b v r srf spp xyz -> b (v r srf spp) xyz",  # [2, 256000, 3]
                ),
                rearrange(
                    intermediate_gaussians.covariances,  # [2, 1, 256000, 1, 1, 3, 3]
                    "b v r srf spp i j -> b (v r srf spp) i j",  # [2, 256000, 3, 3]
                ),
                rearrange(
                    intermediate_gaussians.harmonics,  # [2, 1, 256000, 1, 1, 3, 9]
                    "b v r srf spp c d_sh -> b (v r srf spp) c d_sh",  # [2, 256000, 3, 9]
                ),
                rearrange(
                    intermediate_gaussians.opacities,  # [2, 1, 256000, 1, 1]
                    "b v r srf spp -> b (v r srf spp)",  # [2, 256000]
                ),
            )
        else:
            intermediate_gaussians = None

        # NOTE(review): `points` is only consumed by the commented-out PLY
        # dump below.
        points = rearrange(
            gaussians.means,  # [2, 1, 256000, 1, 1, 3]
            "b v r srf spp xyz -> (b v r srf spp) xyz"  # [N, 3]
        )
        from pathlib import Path
        # if global_step % 1000 == 0:
        #     save_point_cloud_to_ply(points, Path("/mnt/pfs/users/chaojun.ni/wangweijie_mnt/yeqing/BEV-Splat/outputs/project_point_cloud"), "gs_points")
        #     # force-release PyTorch's unused cached memory
        #     torch.cuda.empty_cache()

        gaussians = Gaussians(
            rearrange(
                gaussians.means,  # [2, 1, 256000, 1, 1, 3]
                "b v r srf spp xyz -> b (v r srf spp) xyz",  # [2, 256000, 3]
            ),
            rearrange(
                gaussians.covariances,  # [2, 1, 256000, 1, 1, 3, 3]
                "b v r srf spp i j -> b (v r srf spp) i j",  # [2, 256000, 3, 3]
            ),
            rearrange(
                gaussians.harmonics,  # [2, 1, 256000, 1, 1, 3, 9]
                "b v r srf spp c d_sh -> b (v r srf spp) c d_sh",  # [2, 256000, 3, 9]
            ),
            rearrange(
                gaussians.opacities,  # [2, 1, 256000, 1, 1]
                "b v r srf spp -> b (v r srf spp)",  # [2, 256000]
            ),
        )

        # print_mem("end forward")

        if self.cfg.return_depth:
            # return depth prediction for supervision
            # depths = torch.cat(depth_preds, dim=0)
            depths = rearrange(
                depths, "b v (h w) srf s -> b v h w srf s", h=h, w=w
            ).squeeze(-1).squeeze(-1)
            # print(depths.shape)  # [B, V, H, W] [2, 6, 256, 448]

            if intermediate_gaussians is not None:
                return {
                    "gaussians": gaussians,
                    "depths": depths,
                    "intermediate_gaussians": intermediate_gaussians
                }
            else:
                return {
                    "gaussians": gaussians,
                    "depths": depths,
                }

        return gaussians

    def get_data_shim(self) -> DataShim:
        """Return a shim that patch-crops batches to a size divisible by
        shim_patch_size * downscale_factor."""
        def data_shim(batch: BatchedExample) -> BatchedExample:
            batch = apply_patch_shim(
                batch,
                patch_size=self.cfg.shim_patch_size * self.cfg.downscale_factor,
            )
            return batch

        return data_shim

    @property
    def sampler(self):
        # No custom sampler is used by this encoder.
        return None