|
|
from dataclasses import dataclass |
|
|
from typing import Literal, Optional, List |
|
|
|
|
|
import torch |
|
|
from einops import rearrange, repeat |
|
|
from jaxtyping import Float |
|
|
from torch import Tensor, nn |
|
|
import MinkowskiEngine as ME |
|
|
import torch.nn.init as init |
|
|
|
|
|
from ...dataset.shims.patch_shim import apply_patch_shim |
|
|
from ...dataset.types import BatchedExample, DataShim |
|
|
from ...geometry.projection import sample_image_grid |
|
|
from ..types import Gaussians |
|
|
|
|
|
|
|
|
|
|
|
from .common.guassian_adapter_depth import GaussianAdapter_depth, GaussianAdapterCfg |
|
|
|
|
|
|
|
|
from .encoder import Encoder |
|
|
from .visualization.encoder_visualizer_depthsplat_cfg import EncoderVisualizerDepthSplatCfg |
|
|
|
|
|
import torchvision.transforms as T |
|
|
import torch.nn.functional as F |
|
|
|
|
|
from .unimatch.mv_unimatch import MultiViewUniMatch |
|
|
from .unimatch.dpt_head import DPTHead |
|
|
|
|
|
from .common.voxel_feature import project_features_to_3d, project_features_to_voxel, adapte_features_to_voxel, adapte_project_features_to_3d |
|
|
from .common.me_fea import project_features_to_me |
|
|
|
|
|
from ...geometry.projection import get_world_rays |
|
|
from .common.sparse_net import SparseGaussianHead, SparseUNetWithAttention |
|
|
from .common.mink_resnet import MultiScaleSparseHead |
|
|
|
|
|
from ...test.export_ply import save_point_cloud_to_ply |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def print_mem(tag: str = ""):
    """Report current CUDA memory usage for debugging.

    Prints allocated and reserved device memory in MB, prefixed with *tag*.
    Falls back to a "no CUDA" message when no GPU is available.
    """
    if torch.cuda.is_available():
        mib = 1024 ** 2
        alloc_mb = torch.cuda.memory_allocated() / mib
        resv_mb = torch.cuda.memory_reserved() / mib
        print(f"[MEM] {tag} | allocated={alloc_mb:.1f} MB reserved={resv_mb:.1f} MB")
    else:
        print(f"[MEM] {tag} - no CUDA")
|
|
|
|
|
@dataclass
class EncoderDepthSplatCfg:
    """Configuration for the DepthSplat encoder.

    Groups depth-network, cost-volume U-Net, and Gaussian-branch settings.
    NOTE(review): several fields (the costvolume_* group, d_feature,
    num_depth_candidates, gaussians_per_pixel, ...) are not read by the
    encoder in this file — presumably consumed elsewhere; verify before
    pruning.
    """

    name: Literal["depthsplat"]  # registry key selecting this encoder
    d_feature: int
    num_depth_candidates: int
    num_surfaces: int  # surfaces per ray ("srf" axis in the rearrange patterns)
    visualizer: EncoderVisualizerDepthSplatCfg
    gaussian_adapter: GaussianAdapterCfg
    gaussians_per_pixel: int
    unimatch_weights_path: str | None  # optional pretrained UniMatch checkpoint path
    downscale_factor: int
    shim_patch_size: int  # multiplied by downscale_factor in get_data_shim()
    multiview_trans_attn_split: int
    costvolume_unet_feat_dim: int
    costvolume_unet_channel_mult: List[int]
    costvolume_unet_attn_res: List[int]
    depth_unet_feat_dim: int
    depth_unet_attn_res: List[int]
    depth_unet_channel_mult: List[int]

    # Multi-view depth predictor (MultiViewUniMatch) settings.
    num_scales: int
    upsample_factor: int
    lowest_feature_resolution: int
    depth_unet_channels: int
    grid_sample_disable_cudnn: bool

    # Gaussian-branch settings.
    large_gaussian_head: bool
    color_large_unet: bool
    init_sh_input_img: bool
    # NOTE(review): __init__ derives the upsampler channel count from the
    # ViT model config instead of this field — confirm which is intended.
    feature_upsampler_channels: int
    gaussian_regressor_channels: int

    # Supervision / output behaviour.
    supervise_intermediate_depth: bool  # also supervise coarser depth predictions
    return_depth: bool  # include per-view depth maps in forward() output

    train_depth_only: bool  # build and run only the depth branch

    monodepth_vit_type: str  # one of 'vits' | 'vitb' | 'vitl'

    local_mv_match: int  # nearest-neighbour views used for matching when v > 3
|
|
|
|
|
|
|
|
class EncoderDepthSplat_test(Encoder[EncoderDepthSplatCfg]):
    """DepthSplat encoder: multi-view depth prediction + sparse 3D Gaussian regression.

    Pipeline (forward):
      1. Predict per-view depth with MultiViewUniMatch (inverse-depth range).
      2. Upsample backbone features with a DPT head and refine them with a
         small 2D conv regressor.
      3. Unproject pixels to 3D, voxelize into a MinkowskiEngine sparse
         tensor, run a sparse U-Net with a residual connection, and decode
         per-voxel raw Gaussian parameters.
      4. Convert raw parameters to world-space Gaussians via the adapter.

    When ``cfg.train_depth_only`` is set, only step 1 is built and run.
    """

    def __init__(self, cfg: EncoderDepthSplatCfg) -> None:
        super().__init__(cfg)

        # Multi-view depth network (always constructed).
        self.depth_predictor = MultiViewUniMatch(
            num_scales=cfg.num_scales,
            upsample_factor=cfg.upsample_factor,
            lowest_feature_resolution=cfg.lowest_feature_resolution,
            vit_type=cfg.monodepth_vit_type,
            unet_channels=cfg.depth_unet_channels,
            grid_sample_disable_cudnn=cfg.grid_sample_disable_cudnn,
        )

        # In depth-only training the Gaussian branch is never exercised,
        # so skip building it entirely (saves parameters and memory).
        # NOTE: this also leaves self.scale / self.shift unset, as before.
        if self.cfg.train_depth_only:
            return

        # DPT head hyper-parameters per monodepth ViT backbone.
        model_configs = {
            'vits': {'in_channels': 384, 'features': 64, 'out_channels': [48, 96, 192, 384]},
            'vitb': {'in_channels': 768, 'features': 96, 'out_channels': [96, 192, 384, 768]},
            'vitl': {'in_channels': 1024, 'features': 128, 'out_channels': [128, 256, 512, 1024]},
        }

        self.feature_upsampler = DPTHead(
            **model_configs[cfg.monodepth_vit_type],
            downsample_factor=cfg.upsample_factor,
            return_feature=True,
            num_scales=cfg.num_scales,
        )
        # NOTE(review): derived from the ViT config; cfg.feature_upsampler_channels
        # is intentionally(?) not used here — confirm.
        feature_upsampler_channels = model_configs[cfg.monodepth_vit_type]["features"]

        self.gaussian_adapter = GaussianAdapter_depth(cfg.gaussian_adapter)

        # 2D regressor input: RGB (3) + depth (1) + match prob (1) + DPT features.
        in_channels = 3 + 1 + 1 + feature_upsampler_channels
        channels = self.cfg.gaussian_regressor_channels
        self.gaussian_regressor = nn.Sequential(
            nn.Conv2d(in_channels, channels, 3, 1, 1),
            nn.GELU(),
            nn.Conv2d(channels, channels, 3, 1, 1),
        )

        # Per-voxel head output: adapter inputs + 3 + 1 — presumably a 3D
        # position term plus opacity; confirm against SparseGaussianHead.
        num_gaussian_parameters = self.gaussian_adapter.d_in + 3 + 1

        # Sparse U-Net input: regressor output + RGB (3) + DPT features + match prob (1).
        in_channels = 3 + feature_upsampler_channels + channels + 1

        # Attribute name "spare_unet" (sic) is kept: renaming would change
        # state_dict keys and break existing checkpoints.
        self.spare_unet = SparseUNetWithAttention(
            in_channels=in_channels,
            out_channels=in_channels,
            num_blocks=3,
            use_attention=False,
        )

        self.gaussian_head = SparseGaussianHead(in_channels, num_gaussian_parameters)

        # Fixed constants; not referenced in forward() as written — TODO confirm use.
        self.scale = 0.04
        self.shift = 0.01

    @staticmethod
    def _flatten_gaussians(g: Gaussians) -> Gaussians:
        """Collapse the (view, ray, surface, sample) axes into one Gaussian-set axis."""
        return Gaussians(
            rearrange(g.means, "b v r srf spp xyz -> b (v r srf spp) xyz"),
            rearrange(g.covariances, "b v r srf spp i j -> b (v r srf spp) i j"),
            rearrange(g.harmonics, "b v r srf spp c d_sh -> b (v r srf spp) c d_sh"),
            rearrange(g.opacities, "b v r srf spp -> b (v r srf spp)"),
        )

    @staticmethod
    def _add_sparse_residual(sparse_in, sparse_out):
        """Return sparse_out with sparse_in's features added when coordinates match.

        Falls back to the plain U-Net output (with a warning) when the
        coordinate set or channel count changed. This replaces the original
        inline code, whose cleanup path dereferenced an undefined variable
        (`new_features`) in the fallback branch, raising NameError.
        NOTE(review): both call sites now anchor the residual tensor to
        sparse_out's coordinate map; the original intermediate branch used
        the input's map — equivalent here since coordinates are verified equal.
        """
        if torch.equal(sparse_out.C, sparse_in.C) and sparse_out.F.shape[1] == sparse_in.F.shape[1]:
            return ME.SparseTensor(
                features=sparse_out.F + sparse_in.F,
                coordinate_map_key=sparse_out.coordinate_map_key,
                coordinate_manager=sparse_out.coordinate_manager,
            )
        print("警告:输入和输出坐标不一致,跳过残差连接")
        return sparse_out

    def _split_head_output(self, voxel_features: Tensor):
        """Split raw head features (N, C) into opacities and per-surface params.

        Returns:
            opacities: (1, 1, N, 1, 1) after sigmoid.
            raw: (1, 1, N, srf, C') with C' = (C - 1) / num_surfaces.
        """
        params = voxel_features.unsqueeze(0).unsqueeze(0)
        opacities = params[..., :1].sigmoid().unsqueeze(-1)
        raw = rearrange(
            params[..., 1:],
            "... (srf c) -> ... srf c",
            srf=self.cfg.num_surfaces,
        )
        return opacities, raw

    @staticmethod
    def _stack_depth_predictions(depth, depth_preds, supervise_intermediate, h, w):
        """Stack the final (and optionally all intermediate) depth maps.

        Intermediate predictions are concatenated along the batch axis so a
        single loss can supervise every scale. Returns (b', v, h, w).
        """
        depths = rearrange(depth, "b v h w -> b v (h w) () ()")
        if supervise_intermediate and len(depth_preds) > 1:
            intermediate = torch.cat(depth_preds[:-1], dim=0)
            intermediate = rearrange(intermediate, "b v h w -> b v (h w) () ()")
            depths = torch.cat((intermediate, depths), dim=0)
        return rearrange(
            depths, "b v (h w) srf s -> b v h w srf s", h=h, w=w
        ).squeeze(-1).squeeze(-1)

    def _predict_intermediate_gaussians(self, context, pixel_features,
                                        intermediate_depth, rgb,
                                        voxel_resolution, b, v) -> Gaussians:
        """Run the sparse Gaussian branch on the coarsest depth prediction.

        Mirrors the main branch of forward() so intermediate depths can be
        supervised with the same Gaussian losses.
        """
        voxel_feature, median_points, counts = project_features_to_me(
            context["intrinsics"],
            context["extrinsics"],
            pixel_features,
            depth=intermediate_depth,
            voxel_resolution=voxel_resolution,
            b=b, v=v,
        )
        unet_out = self.spare_unet(voxel_feature)
        residual = self._add_sparse_residual(voxel_feature, unet_out)
        head_out = self.gaussian_head(residual)
        # Free the sparse tensors before the (memory-heavy) adapter call.
        del voxel_feature, unet_out, residual

        opacities, raw = self._split_head_output(head_out.F)
        adapted = self.gaussian_adapter.forward(
            extrinsics=context["extrinsics"],
            intrinsics=context["intrinsics"],
            opacities=opacities,
            raw_gaussians=rearrange(raw, "b v r srf c -> b v r srf () c"),
            input_images=rgb,
            depth=intermediate_depth,
            coordidate=head_out.C,  # (sic) keyword matches the adapter's signature
            points=median_points,
            voxel_resolution=voxel_resolution,
        )
        return self._flatten_gaussians(adapted)

    def forward(
        self,
        context: dict,
        global_step: int,
        deterministic: bool = False,
        visualization_dump: Optional[dict] = None,
        scene_names: Optional[list] = None,
        ues_voxelnet: bool = True,  # (sic) kept for caller compatibility; unused here
    ):
        """Encode context views into 3D Gaussians (and optionally depths).

        Args:
            context: dict with "image" (b, v, 3, h, w), "intrinsics",
                "extrinsics", "near", "far".
            global_step: training step (unused in this implementation).
            deterministic / visualization_dump / scene_names / ues_voxelnet:
                kept for interface compatibility; unused here.

        Returns:
            A dict with "gaussians"/"depths" (and "intermediate_gaussians"
            when intermediate supervision is active), or a bare Gaussians
            object when cfg.return_depth is False.
        """
        b, v, _, h, w = context["image"].shape

        # With many input views, restrict multi-view matching to each view's
        # nearest neighbours by camera-centre distance.
        if v > 3:
            with torch.no_grad():
                xyzs = context["extrinsics"][:, :, :3, -1].detach()
                cameras_dist_matrix = torch.cdist(xyzs, xyzs, p=2)
                cameras_dist_index = torch.argsort(cameras_dist_matrix)
                # +1 because each view is its own nearest neighbour.
                cameras_dist_index = cameras_dist_index[:, :, :(self.cfg.local_mv_match + 1)]
        else:
            cameras_dist_index = None

        # Depth is estimated in inverse-depth space, hence the near/far swap.
        results_dict = self.depth_predictor(
            context["image"],
            attn_splits_list=[2],
            min_depth=1. / context["far"],
            max_depth=1. / context["near"],
            intrinsics=context["intrinsics"],
            extrinsics=context["extrinsics"],
            nn_matrix=cameras_dist_index,
        )

        depth_preds = results_dict['depth_preds']  # coarse -> fine predictions
        depth = depth_preds[-1]  # finest prediction, (b, v, h, w)

        voxel_resolution = 0.02  # voxel edge length — TODO: consider making configurable

        if self.cfg.train_depth_only:
            depths = self._stack_depth_predictions(
                depth, depth_preds, self.cfg.supervise_intermediate_depth, h, w)
            return {
                "gaussians": None,
                "depths": depths,
            }

        # Upsample mono/CNN/multi-view backbone features to full resolution.
        features = self.feature_upsampler(
            results_dict["features_mono_intermediate"],
            cnn_features=results_dict["features_cnn_all_scales"][::-1],
            mv_features=results_dict["features_mv"][0]
            if self.cfg.num_scales == 1
            else results_dict["features_mv"][::-1],
        )

        # Per-pixel matching confidence: max over candidates, resized to depth.
        match_prob = results_dict['match_probs'][-1]
        match_prob = torch.max(match_prob, dim=1, keepdim=True)[0]
        match_prob = F.interpolate(match_prob, size=depth.shape[-2:], mode='nearest')

        rgb = rearrange(context["image"], "b v c h w -> (b v) c h w")
        regressor_in = torch.cat((
            rgb,
            rearrange(depth, "b v h w -> (b v) () h w"),
            match_prob,
            features,
        ), dim=1)
        out = self.gaussian_regressor(regressor_in)
        out = torch.cat([out, rgb, features, match_prob], dim=1)

        # Lift per-pixel features into a sparse voxel grid.
        sparse_input, aggregated_points, counts = project_features_to_me(
            context["intrinsics"],
            context["extrinsics"],
            out,
            depth=depth,
            voxel_resolution=voxel_resolution,
            b=b, v=v,
        )

        sparse_out = self.spare_unet(sparse_input)
        sparse_out_with_residual = self._add_sparse_residual(sparse_input, sparse_out)
        gaussians = self.gaussian_head(sparse_out_with_residual)
        # Free sparse tensors before the memory-heavy adapter call.
        del sparse_out_with_residual, sparse_out, sparse_input

        depths = rearrange(depth, "b v h w -> b v (h w) () ()")

        print(f"输出稀疏张量: {gaussians.F.shape[0]}个体素")

        opacities, raw_gaussians = self._split_head_output(gaussians.F)

        try:
            gaussians = self.gaussian_adapter.forward(
                extrinsics=context["extrinsics"],
                intrinsics=context["intrinsics"],
                opacities=opacities,
                raw_gaussians=rearrange(raw_gaussians, "b v r srf c -> b v r srf () c"),
                input_images=rgb,
                depth=depth,
                coordidate=gaussians.C,  # (sic) keyword matches the adapter's signature
                points=aggregated_points,
                voxel_resolution=voxel_resolution,
            )
        except Exception:
            # Surface the full traceback before propagating; adapter failures
            # are otherwise hard to localize.
            import traceback
            traceback.print_exc()
            raise

        if self.cfg.supervise_intermediate_depth and len(depth_preds) > 1:
            intermediate_gaussians = self._predict_intermediate_gaussians(
                context, out, depth_preds[0], rgb, voxel_resolution, b, v)
        else:
            intermediate_gaussians = None

        gaussians = self._flatten_gaussians(gaussians)

        if self.cfg.return_depth:
            depths = rearrange(
                depths, "b v (h w) srf s -> b v h w srf s", h=h, w=w
            ).squeeze(-1).squeeze(-1)
            result = {
                "gaussians": gaussians,
                "depths": depths,
            }
            if intermediate_gaussians is not None:
                result["intermediate_gaussians"] = intermediate_gaussians
            return result

        return gaussians

    def get_data_shim(self) -> DataShim:
        """Return a shim that crops batches to a multiple of the patch size."""
        def data_shim(batch: BatchedExample) -> BatchedExample:
            return apply_patch_shim(
                batch,
                patch_size=self.cfg.shim_patch_size * self.cfg.downscale_factor,
            )
        return data_shim

    @property
    def sampler(self):
        # No custom sampler for this encoder.
        return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|