from dataclasses import dataclass from typing import Literal, Optional, List import torch from einops import rearrange, repeat from jaxtyping import Float from torch import Tensor, nn import MinkowskiEngine as ME import torch.nn.init as init from ...dataset.shims.patch_shim import apply_patch_shim from ...dataset.types import BatchedExample, DataShim from ...geometry.projection import sample_image_grid from ..types import Gaussians ########形成高斯点######### # from .common.gaussian_adapter_revise import GaussianAdapter_revise, GaussianAdapterCfg from .common.guassian_adapter_depth import GaussianAdapter_depth, GaussianAdapterCfg from .encoder import Encoder from .visualization.encoder_visualizer_depthsplat_cfg import EncoderVisualizerDepthSplatCfg import torchvision.transforms as T import torch.nn.functional as F from .unimatch.mv_unimatch import MultiViewUniMatch from .unimatch.dpt_head import DPTHead from .common.voxel_feature import project_features_to_3d, project_features_to_voxel, adapte_features_to_voxel, adapte_project_features_to_3d from .common.me_fea import project_features_to_me from ...geometry.projection import get_world_rays from .common.sparse_net import SparseGaussianHead, SparseUNetWithAttention from .common.mink_resnet import MultiScaleSparseHead from ...test.export_ply import save_point_cloud_to_ply # import debugpy # try: # # 5678 is the default attach port in the VS Code debug configurations. 
# Unless a host and port are specified, host defaults to 127.0.0.1
#     debugpy.listen(("localhost", 9326))
#     print("Waiting for debugger attach")
#     debugpy.wait_for_client()
# except Exception as e:
#     pass


def print_mem(tag: str = "") -> None:
    """Print current CUDA memory usage (allocated / reserved, in MB).

    Lightweight debugging helper. When CUDA is unavailable it prints a short
    notice instead.

    Args:
        tag: Free-form label included in the printed line.
    """
    if not torch.cuda.is_available():
        print(f"[MEM] {tag} - no CUDA")
        return
    allocated = torch.cuda.memory_allocated() / 1024**2
    reserved = torch.cuda.memory_reserved() / 1024**2
    print(f"[MEM] {tag} | allocated={allocated:.1f} MB reserved={reserved:.1f} MB")


@dataclass
class EncoderDepthSplatCfg:
    """Configuration for ``EncoderDepthSplat_test``."""

    name: Literal["depthsplat"]
    d_feature: int
    num_depth_candidates: int
    num_surfaces: int
    visualizer: EncoderVisualizerDepthSplatCfg
    gaussian_adapter: GaussianAdapterCfg
    gaussians_per_pixel: int
    unimatch_weights_path: str | None
    downscale_factor: int
    shim_patch_size: int
    multiview_trans_attn_split: int
    costvolume_unet_feat_dim: int
    costvolume_unet_channel_mult: List[int]
    costvolume_unet_attn_res: List[int]
    depth_unet_feat_dim: int
    depth_unet_attn_res: List[int]
    depth_unet_channel_mult: List[int]

    # mv_unimatch
    num_scales: int
    upsample_factor: int
    lowest_feature_resolution: int
    depth_unet_channels: int
    grid_sample_disable_cudnn: bool

    # depthsplat color branch
    large_gaussian_head: bool
    color_large_unet: bool
    init_sh_input_img: bool
    feature_upsampler_channels: int
    gaussian_regressor_channels: int

    # loss config
    supervise_intermediate_depth: bool
    return_depth: bool

    # only depth
    train_depth_only: bool

    # monodepth config
    monodepth_vit_type: str

    # multi-view matching
    local_mv_match: int


class EncoderDepthSplat_test(Encoder[EncoderDepthSplatCfg]):
    """DepthSplat-style encoder that decodes Gaussians from a sparse voxel grid.

    Pipeline, as implemented in ``forward``:
      1. ``MultiViewUniMatch`` predicts per-view depth (plus intermediate
         depths and features).
      2. ``DPTHead`` upsamples features; a small conv regressor produces
         per-pixel features.
      3. ``project_features_to_me`` lifts the per-pixel features into a
         MinkowskiEngine sparse tensor using the predicted depth.
      4. A sparse UNet refines the voxels (with a residual connection when
         the coordinates line up), ``SparseGaussianHead`` predicts per-voxel
         parameters, and ``GaussianAdapter_depth`` turns them into Gaussians.
    """

    # Voxel size passed to ``project_features_to_me``. It is in the same units
    # as the predicted depth: presumably metres, i.e. 2 cm (TODO confirm).
    # A learned per-scene resolution predictor used to exist as commented-out
    # code and is currently disabled.
    voxel_resolution: float = 0.02

    def __init__(self, cfg: EncoderDepthSplatCfg) -> None:
        super().__init__(cfg)

        self.depth_predictor = MultiViewUniMatch(
            num_scales=cfg.num_scales,
            upsample_factor=cfg.upsample_factor,
            lowest_feature_resolution=cfg.lowest_feature_resolution,
            vit_type=cfg.monodepth_vit_type,
            unet_channels=cfg.depth_unet_channels,
            grid_sample_disable_cudnn=cfg.grid_sample_disable_cudnn,
        )

        if self.cfg.train_depth_only:
            # Depth-only training never builds the Gaussian branch.
            return

        # Upsample features to the original resolution.
        model_configs = {
            'vits': {'in_channels': 384, 'features': 64, 'out_channels': [48, 96, 192, 384]},
            'vitb': {'in_channels': 768, 'features': 96, 'out_channels': [96, 192, 384, 768]},
            'vitl': {'in_channels': 1024, 'features': 128, 'out_channels': [128, 256, 512, 1024]},
        }
        self.feature_upsampler = DPTHead(
            **model_configs[cfg.monodepth_vit_type],
            downsample_factor=cfg.upsample_factor,
            return_feature=True,
            num_scales=cfg.num_scales,
        )
        feature_upsampler_channels = model_configs[cfg.monodepth_vit_type]["features"]

        # Converts raw per-voxel parameters into Gaussians.
        self.gaussian_adapter = GaussianAdapter_depth(cfg.gaussian_adapter)

        # 2D conv regressor over concat(image, depth, match_prob, features).
        regressor_in_channels = 3 + 1 + 1 + feature_upsampler_channels
        channels = self.cfg.gaussian_regressor_channels
        self.gaussian_regressor = nn.Sequential(
            nn.Conv2d(regressor_in_channels, channels, 3, 1, 1),
            nn.GELU(),
            nn.Conv2d(channels, channels, 3, 1, 1),
        )

        # Per-voxel head output: opacity (1) + xyz offset (3) + adapter inputs.
        num_gaussian_parameters = self.gaussian_adapter.d_in + 3 + 1

        # Per-point features fed to the sparse network:
        # concat(regressor_out, image, features, match_prob).
        sparse_in_channels = 3 + feature_upsampler_channels + channels + 1

        # 3D sparse UNet. The attribute name "spare_unet" is kept so existing
        # checkpoints (state_dict keys) still load.
        self.spare_unet = SparseUNetWithAttention(
            in_channels=sparse_in_channels,
            out_channels=sparse_in_channels,
            num_blocks=3,
            use_attention=False,
        )

        # Gaussian parameter head.
        self.gaussian_head = SparseGaussianHead(sparse_in_channels, num_gaussian_parameters)

        # Output scaling for the disabled learned voxel-resolution predictor:
        # resolution = shift + scale * sigmoid_output, i.e. within [0.01, 0.05].
        self.scale = 0.04
        self.shift = 0.01

    def _refine_sparse(self, sparse_input):
        """Run the sparse UNet and add a residual connection when possible.

        The residual is only added when the UNet output has exactly the same
        coordinates and feature width as its input. Otherwise a warning is
        printed and the raw UNet output is returned.
        """
        sparse_out = self.spare_unet(sparse_input)
        same_layout = (
            torch.equal(sparse_out.C, sparse_input.C)
            and sparse_out.F.shape[1] == sparse_input.F.shape[1]
        )
        if not same_layout:
            print("警告:输入和输出坐标不一致,跳过残差连接")
            return sparse_out
        return ME.SparseTensor(
            features=sparse_out.F + sparse_input.F,
            coordinate_map_key=sparse_out.coordinate_map_key,
            coordinate_manager=sparse_out.coordinate_manager,
        )

    def _decode_gaussians(self, context: dict, head_out, images_flat, depth, points, voxel_resolution: float):
        """Split head output into opacity / raw parameters and run the adapter."""
        # [N, C] -> [1, 1, N, C]
        gaussian_params = head_out.F.unsqueeze(0).unsqueeze(0)
        # The first channel is opacity: [1, 1, N, 1, 1].
        opacities = gaussian_params[..., :1].sigmoid().unsqueeze(-1)
        raw_gaussians = rearrange(
            gaussian_params[..., 1:],
            "... (srf c) -> ... srf c",
            srf=self.cfg.num_surfaces,
        )
        return self.gaussian_adapter.forward(
            extrinsics=context["extrinsics"],
            intrinsics=context["intrinsics"],
            opacities=opacities,
            raw_gaussians=rearrange(raw_gaussians, "b v r srf c -> b v r srf () c"),
            input_images=images_flat,
            depth=depth,
            coordidate=head_out.C,
            points=points,
            voxel_resolution=voxel_resolution,
        )

    @staticmethod
    def _flatten_gaussians(gaussians) -> Gaussians:
        """Merge the view/ray/surface/sample axes into one Gaussian axis."""
        return Gaussians(
            rearrange(gaussians.means, "b v r srf spp xyz -> b (v r srf spp) xyz"),
            rearrange(gaussians.covariances, "b v r srf spp i j -> b (v r srf spp) i j"),
            rearrange(gaussians.harmonics, "b v r srf spp c d_sh -> b (v r srf spp) c d_sh"),
            rearrange(gaussians.opacities, "b v r srf spp -> b (v r srf spp)"),
        )

    def _gaussians_from_depth(
        self,
        context: dict,
        point_features,
        depth,
        images_flat,
        voxel_resolution: float,
        b: int,
        v: int,
        report_count: bool = False,
    ) -> Gaussians:
        """Voxelize per-pixel features at ``depth`` and decode flattened Gaussians.

        Used for both the final depth and the intermediate-depth supervision
        path, so the two always run the same pipeline.
        """
        sparse_input, points, _counts = project_features_to_me(
            context["intrinsics"],
            context["extrinsics"],
            point_features,
            depth=depth,
            voxel_resolution=voxel_resolution,
            b=b,
            v=v,
        )
        refined = self._refine_sparse(sparse_input)
        head_out = self.gaussian_head(refined)
        # Drop references to the large sparse tensors before the adapter runs.
        del sparse_input, refined
        if report_count:
            print(f"输出稀疏张量: {head_out.F.shape[0]}个体素")
        raw = self._decode_gaussians(context, head_out, images_flat, depth, points, voxel_resolution)
        return self._flatten_gaussians(raw)

    def forward(
        self,
        context: dict,
        global_step: int,
        deterministic: bool = False,
        visualization_dump: Optional[dict] = None,
        scene_names: Optional[list] = None,
        ues_voxelnet: bool = True,
    ):
        """Predict depth and (unless depth-only) Gaussians for the context views.

        Args:
            context: Must provide "image" [B, V, 3, H, W], "intrinsics",
                "extrinsics", "near" and "far".
            global_step, deterministic, visualization_dump, scene_names:
                Accepted for interface compatibility; not used here.
            ues_voxelnet: Unused. The learned voxel-resolution predictor is
                disabled, so ``self.voxel_resolution`` is always used.

        Returns:
            In depth-only mode, ``{"gaussians": None, "depths": ...}``. The depths
            are the final depth, or all depth predictions concatenated along the
            batch dim when intermediate supervision is on.
            Otherwise, if ``cfg.return_depth`` is set, a dict with "gaussians",
            "depths" and optionally "intermediate_gaussians". If it is not set,
            just the Gaussians.
        """
        b, v, _, h, w = context["image"].shape

        # With many views, match each view only against its nearest cameras.
        if v > 3:
            with torch.no_grad():
                xyzs = context["extrinsics"][:, :, :3, -1].detach()
                cameras_dist_matrix = torch.cdist(xyzs, xyzs, p=2)
                cameras_dist_index = torch.argsort(cameras_dist_matrix)
                cameras_dist_index = cameras_dist_index[:, :, :(self.cfg.local_mv_match + 1)]
        else:
            cameras_dist_index = None

        results_dict = self.depth_predictor(
            context["image"],
            attn_splits_list=[2],
            min_depth=1. / context["far"],
            max_depth=1. / context["near"],
            intrinsics=context["intrinsics"],
            extrinsics=context["extrinsics"],
            nn_matrix=cameras_dist_index,
        )

        # List of [B, V, H, W] depths; the last one is the final prediction.
        depth_preds = results_dict['depth_preds']
        depth = depth_preds[-1]

        if self.cfg.train_depth_only:
            if self.cfg.supervise_intermediate_depth and len(depth_preds) > 1:
                # Supervise every depth prediction, stacked along the batch dim.
                depths = torch.cat(depth_preds, dim=0)
            else:
                depths = depth
            return {
                "gaussians": None,
                "depths": depths,
            }

        # Upsampled features [BV, C, H, W].
        features = self.feature_upsampler(
            results_dict["features_mono_intermediate"],
            cnn_features=results_dict["features_cnn_all_scales"][::-1],
            mv_features=(
                results_dict["features_mv"][0]
                if self.cfg.num_scales == 1
                else results_dict["features_mv"][::-1]
            ),
        )

        # Per-pixel matching confidence: max softmax probability over depth
        # candidates, resized to the depth resolution -> [BV, 1, H, W].
        match_prob = results_dict['match_probs'][-1]
        match_prob = torch.max(match_prob, dim=1, keepdim=True)[0]
        match_prob = F.interpolate(match_prob, size=depth.shape[-2:], mode='nearest')

        images_flat = rearrange(context["image"], "b v c h w -> (b v) c h w")

        regressor_in = torch.cat((
            images_flat,
            rearrange(depth, "b v h w -> (b v) () h w"),
            match_prob,
            features,
        ), dim=1)
        regressor_out = self.gaussian_regressor(regressor_in)

        # Per-pixel features that get voxelized: [BV, C, H, W].
        point_features = torch.cat((regressor_out, images_flat, features, match_prob), dim=1)

        voxel_resolution = self.voxel_resolution

        gaussians = self._gaussians_from_depth(
            context, point_features, depth, images_flat, voxel_resolution, b, v,
            report_count=True,
        )

        intermediate_gaussians = None
        if self.cfg.supervise_intermediate_depth and len(depth_preds) > 1:
            intermediate_gaussians = self._gaussians_from_depth(
                context, point_features, depth_preds[0], images_flat, voxel_resolution, b, v,
            )

        if not self.cfg.return_depth:
            return gaussians

        # Depth prediction for supervision: [B, V, H, W].
        result = {
            "gaussians": gaussians,
            "depths": depth,
        }
        if intermediate_gaussians is not None:
            result["intermediate_gaussians"] = intermediate_gaussians
        return result

    def get_data_shim(self) -> DataShim:
        """Return a shim that crops batches to a multiple of the patch size."""
        def data_shim(batch: BatchedExample) -> BatchedExample:
            batch = apply_patch_shim(
                batch,
                patch_size=self.cfg.shim_patch_size * self.cfg.downscale_factor,
            )
            return batch

        return data_shim

    @property
    def sampler(self):
        # This encoder has no sampler.
        return None