# depthsplat/src/model/encoder/encoder_depthsplat.py
# (uploaded by Yeqing0814 via huggingface_hub, commit a6dd040)
from dataclasses import dataclass
from typing import Literal, Optional, List
import torch
from einops import rearrange
from jaxtyping import Float
from torch import Tensor, nn
from ...dataset.shims.patch_shim import apply_patch_shim
from ...dataset.types import BatchedExample, DataShim
from ...geometry.projection import sample_image_grid
from ..types import Gaussians
from .common.gaussian_adapter import GaussianAdapter, GaussianAdapterCfg
from .encoder import Encoder
from .visualization.encoder_visualizer_depthsplat_cfg import EncoderVisualizerDepthSplatCfg
import torchvision.transforms as T
import torch.nn.functional as F
from .unimatch.mv_unimatch import MultiViewUniMatch
from .unimatch.dpt_head import DPTHead
## Added: Depth Anything ##
from ...test.try_depthanything import DepthAnythingWrapper
from ...test.visual import save_depth_images, save_output_images
from ...test.export_ply import save_point_cloud_to_ply
@dataclass
class EncoderDepthSplatCfg:
    """Configuration for ``EncoderDepthSplat``.

    Fields marked "unused in this file" are not read by ``EncoderDepthSplat``
    as defined in this module; they may still be consumed elsewhere (config
    loading, other components) -- verify before removing them.
    """

    name: Literal["depthsplat"]
    d_feature: int  # unused in this file
    num_depth_candidates: int  # unused in this file
    # Number of surfaces per pixel; the raw Gaussian channels are split as (srf c).
    num_surfaces: int
    visualizer: EncoderVisualizerDepthSplatCfg  # unused in this file
    gaussian_adapter: GaussianAdapterCfg  # passed to GaussianAdapter
    gaussians_per_pixel: int  # unused in this file
    unimatch_weights_path: str | None  # unused in this file
    # The data shim patch size is shim_patch_size * downscale_factor.
    downscale_factor: int
    shim_patch_size: int
    multiview_trans_attn_split: int  # unused in this file
    costvolume_unet_feat_dim: int  # unused in this file
    costvolume_unet_channel_mult: List[int]  # unused in this file
    costvolume_unet_attn_res: List[int]  # unused in this file
    depth_unet_feat_dim: int  # unused in this file
    depth_unet_attn_res: List[int]  # unused in this file
    depth_unet_channel_mult: List[int]  # unused in this file

    # mv_unimatch (arguments forwarded to MultiViewUniMatch)
    # num_scales is also passed to DPTHead and selects single- vs multi-scale
    # multi-view features in forward().
    num_scales: int
    # Passed to MultiViewUniMatch as upsample_factor and to DPTHead as
    # downsample_factor.
    upsample_factor: int
    lowest_feature_resolution: int
    depth_unet_channels: int  # passed to MultiViewUniMatch as unet_channels
    grid_sample_disable_cudnn: bool

    # depthsplat color branch
    large_gaussian_head: bool  # unused in this file
    color_large_unet: bool  # unused in this file
    # When True, the context images are passed to GaussianAdapter as input_images.
    init_sh_input_img: bool
    # Unused in this file: the DPTHead width is taken from monodepth_vit_type instead.
    feature_upsampler_channels: int
    gaussian_regressor_channels: int  # hidden width of the conv Gaussian regressor

    # loss config
    # When True and several depth predictions exist, the intermediate depths are
    # also returned, stacked along the batch dimension.
    supervise_intermediate_depth: bool
    # When True, forward() returns {"gaussians", "depths"} instead of only Gaussians.
    return_depth: bool

    # only depth
    # When True, only the depth network is built and forward() returns
    # {"gaussians": None, "depths": ...}.
    train_depth_only: bool

    # monodepth config
    # ViT variant key: one of "vits", "vitb", "vitl" (used by MultiViewUniMatch
    # and to pick the DPTHead configuration).
    monodepth_vit_type: str

    # multi-view matching
    # With more than 3 context views, each view is matched against its
    # local_mv_match + 1 nearest cameras.
    local_mv_match: int
class EncoderDepthSplat(Encoder[EncoderDepthSplatCfg]):
    """DepthSplat encoder: multi-view depth prediction followed by per-pixel Gaussians.

    A ``MultiViewUniMatch`` network predicts per-view depth from the context
    images. Unless ``cfg.train_depth_only`` is set, its features are upsampled
    by a ``DPTHead``. They are then combined with the image, the predicted depth
    and the peak matching probability, and decoded by two convolutional stacks
    into raw Gaussian parameters. ``GaussianAdapter`` converts those into 3D
    Gaussians.
    """

    def __init__(self, cfg: EncoderDepthSplatCfg) -> None:
        super().__init__(cfg)

        self.depth_predictor = MultiViewUniMatch(
            num_scales=cfg.num_scales,
            upsample_factor=cfg.upsample_factor,
            lowest_feature_resolution=cfg.lowest_feature_resolution,
            vit_type=cfg.monodepth_vit_type,
            unet_channels=cfg.depth_unet_channels,
            grid_sample_disable_cudnn=cfg.grid_sample_disable_cudnn,
        )

        if self.cfg.train_depth_only:
            # Only the depth network is needed; skip building the Gaussian heads.
            return

        # Upsample features to the original resolution. The DPT head width is
        # tied to the monodepth ViT variant (cfg.feature_upsampler_channels is
        # not used for this).
        model_configs = {
            'vits': {'in_channels': 384, 'features': 64, 'out_channels': [48, 96, 192, 384]},
            'vitb': {'in_channels': 768, 'features': 96, 'out_channels': [96, 192, 384, 768]},
            'vitl': {'in_channels': 1024, 'features': 128, 'out_channels': [128, 256, 512, 1024]},
        }
        self.feature_upsampler = DPTHead(
            **model_configs[cfg.monodepth_vit_type],
            downsample_factor=cfg.upsample_factor,
            return_feature=True,
            num_scales=cfg.num_scales,
        )
        feature_upsampler_channels = model_configs[cfg.monodepth_vit_type]["features"]

        # gaussians adapter
        self.gaussian_adapter = GaussianAdapter(cfg.gaussian_adapter)

        # Regressor input: concat(img, depth, match_prob, features).
        in_channels = 3 + 1 + 1 + feature_upsampler_channels
        channels = self.cfg.gaussian_regressor_channels
        self.gaussian_regressor = nn.Sequential(
            nn.Conv2d(in_channels, channels, 3, 1, 1),
            nn.GELU(),
            nn.Conv2d(channels, channels, 3, 1, 1),
        )

        # Predicted per pixel: opacity (1) + offset_xy (2) + the adapter's own
        # parameters (scale, rotation, SH, ...).
        num_gaussian_parameters = self.gaussian_adapter.d_in + 2 + 1

        # Head input: concat(regressor_out, img, features, match_prob).
        in_channels = 3 + feature_upsampler_channels + channels + 1
        self.gaussian_head = nn.Sequential(
            nn.Conv2d(in_channels, num_gaussian_parameters,
                      3, 1, 1, padding_mode='replicate'),
            nn.GELU(),
            nn.Conv2d(num_gaussian_parameters,
                      num_gaussian_parameters, 3, 1, 1, padding_mode='replicate'),
        )

        # NOTE(review): this fork has two features disabled in the original code.
        # (1) A DepthAnythingWrapper branch (imported at file level, not
        #     instantiated).
        # (2) The zero-initialisation of gaussian_head[-1]: channels 10: when
        #     init_sh_input_img, and channels 3:6 (scale).
        # As a result, gaussian_head uses PyTorch's default Conv2d
        # initialisation. Re-enable deliberately if needed.

    @staticmethod
    def _stack_intermediate_depths(depth_preds, final_depths: Tensor) -> Tensor:
        """Prepend all intermediate depth predictions to ``final_depths`` along dim 0.

        ``depth_preds`` is the list of [B, V, H, W] depth predictions; all but the
        last are flattened to [B, V, H*W, 1, 1] and concatenated (in list order)
        in front of ``final_depths``, giving [len(depth_preds) * B, V, H*W, 1, 1].
        """
        intermediate_depths = torch.cat(depth_preds[:-1], dim=0)
        intermediate_depths = rearrange(
            intermediate_depths, "b v h w -> b v (h w) () ()")
        return torch.cat((intermediate_depths, final_depths), dim=0)

    @staticmethod
    def _unflatten_depths(depths: Tensor, h: int, w: int) -> Tensor:
        """Reshape [B, V, H*W, 1, 1] depths back to [B, V, H, W]."""
        return rearrange(
            depths, "b v (h w) srf s -> b v h w srf s", h=h, w=w
        ).squeeze(-1).squeeze(-1)

    def forward(
        self,
        context: dict,
        global_step: int,
        deterministic: bool = False,
        visualization_dump: Optional[dict] = None,
        scene_names: Optional[list] = None,
    ):
        """Predict depth and, unless ``cfg.train_depth_only``, pixel-aligned Gaussians.

        Args:
            context: reads ``"image"`` ([B, V, 3, H, W]), ``"extrinsics"`` and
                ``"intrinsics"`` ([B, V, i, j]), ``"near"`` and ``"far"`` (used as
                ``1 / far`` and ``1 / near`` for the depth predictor's range).
            global_step, deterministic, scene_names: accepted for interface
                compatibility; not used by this implementation.
            visualization_dump: if given, filled with ``"depth"``, ``"scales"``
                and ``"rotations"``.

        Returns:
            * ``cfg.train_depth_only``: ``{"gaussians": None, "depths": [B', V, H, W]}``.
            * ``cfg.return_depth``: ``{"gaussians": Gaussians, "depths": [B', V, H, W]}``.
            * otherwise: ``Gaussians``.

            ``B' = len(depth_preds) * B`` when ``cfg.supervise_intermediate_depth``
            is set and several depth predictions exist (intermediate depths first,
            all sharing the same Gaussian parameters), else ``B' = B``.
        """
        device = context["image"].device
        b, v, _, h, w = context["image"].shape

        # With more than 3 views, restrict multi-view matching to each view's
        # local_mv_match + 1 closest cameras. Distances come from the translation
        # column of the extrinsics. Ascending argsort, so the view itself
        # (distance 0) is presumably included.
        if v > 3:
            with torch.no_grad():
                xyzs = context["extrinsics"][:, :, :3, -1].detach()
                cameras_dist_matrix = torch.cdist(xyzs, xyzs, p=2)
                cameras_dist_index = torch.argsort(cameras_dist_matrix)
                cameras_dist_index = cameras_dist_index[:, :, :(self.cfg.local_mv_match + 1)]
        else:
            cameras_dist_index = None

        # depth prediction (the predictor takes an inverse-depth range)
        results_dict = self.depth_predictor(
            context["image"],
            attn_splits_list=[2],
            min_depth=1. / context["far"],
            max_depth=1. / context["near"],
            intrinsics=context["intrinsics"],
            extrinsics=context["extrinsics"],
            nn_matrix=cameras_dist_index,
        )

        # list of [B, V, H, W], with all the intermediate depths
        depth_preds = results_dict['depth_preds']
        # [B, V, H, W]
        depth = depth_preds[-1]

        # [B, V, H*W, 1, 1]; optionally with intermediate depths stacked in front.
        depths = rearrange(depth, "b v h w -> b v (h w) () ()")
        stack_intermediate = (
            self.cfg.supervise_intermediate_depth and len(depth_preds) > 1
        )
        num_depths = len(depth_preds) if stack_intermediate else 1
        if stack_intermediate:
            depths = self._stack_intermediate_depths(depth_preds, depths)

        if self.cfg.train_depth_only:
            # return depth prediction for supervision, [B', V, H, W]
            return {
                "gaussians": None,
                "depths": self._unflatten_depths(depths, h, w),
            }

        # features [BV, C, H, W]
        mv_features = (
            results_dict["features_mv"][0]
            if self.cfg.num_scales == 1
            else results_dict["features_mv"][::-1]
        )
        features = self.feature_upsampler(
            results_dict["features_mono_intermediate"],
            cnn_features=results_dict["features_cnn_all_scales"][::-1],
            mv_features=mv_features,
        )

        # Peak matching probability over the depth candidates: [BV, D, h', w'] at
        # feature resolution -> [BV, 1, H, W] at depth resolution.
        match_prob = results_dict['match_probs'][-1]
        match_prob = torch.max(match_prob, dim=1, keepdim=True)[0]
        match_prob = F.interpolate(
            match_prob, size=depth.shape[-2:], mode='nearest')

        # [BV, 3, H, W]
        images = rearrange(context["image"], "b v c h w -> (b v) c h w")

        # Channel order must match the in_channels computed in __init__.
        regressor_in = torch.cat((
            images,
            rearrange(depth, "b v h w -> (b v) () h w"),
            match_prob,
            features,
        ), dim=1)
        regressor_out = self.gaussian_regressor(regressor_in)  # [BV, C, H, W]

        head_in = torch.cat((regressor_out, images, features, match_prob), dim=1)
        head_out = self.gaussian_head(head_in)  # [BV, C, H, W]

        # [B, V, H*W, C]
        raw_gaussians = rearrange(
            head_out, "(b v) c h w -> b v (h w) c", b=b, v=v)
        if stack_intermediate:
            # Every depth prediction shares the same Gaussian parameters.
            raw_gaussians = torch.cat([raw_gaussians] * num_depths, dim=0)

        # Channel 0 is opacity -> [B', V, H*W, 1, 1]
        opacities = raw_gaussians[..., :1].sigmoid().unsqueeze(-1)
        raw_gaussians = raw_gaussians[..., 1:]

        # Convert the features and depths into Gaussians.
        xy_ray, _ = sample_image_grid((h, w), device)  # (x, y)
        xy_ray = rearrange(xy_ray, "h w xy -> (h w) () xy")
        gaussian_params = rearrange(
            raw_gaussians,
            "... (srf c) -> ... srf c",
            srf=self.cfg.num_surfaces,
        )
        # The next two channels shift each ray by up to half a pixel in x and y.
        offset_xy = gaussian_params[..., :2].sigmoid()
        pixel_size = 1 / torch.tensor((w, h), dtype=torch.float32, device=device)
        xy_ray = xy_ray + (offset_xy - 0.5) * pixel_size

        extrinsics = context["extrinsics"]
        intrinsics = context["intrinsics"]
        input_images = context["image"] if self.cfg.init_sh_input_img else None
        if stack_intermediate:
            # Tile the camera inputs to match the stacked batch dimension.
            extrinsics = torch.cat([extrinsics] * num_depths, dim=0)
            intrinsics = torch.cat([intrinsics] * num_depths, dim=0)
            if input_images is not None:
                input_images = input_images.repeat(num_depths, 1, 1, 1, 1)

        adapted = self.gaussian_adapter.forward(
            rearrange(extrinsics, "b v i j -> b v () () () i j"),
            rearrange(intrinsics, "b v i j -> b v () () () i j"),
            rearrange(xy_ray, "b v r srf xy -> b v r srf () xy"),
            depths,
            opacities,
            rearrange(gaussian_params[..., 2:], "b v r srf c -> b v r srf () c"),
            (h, w),
            input_images=input_images,
        )

        # Dump visualizations if needed.
        if visualization_dump is not None:
            visualization_dump["depth"] = rearrange(
                depths, "b v (h w) srf s -> b v h w srf s", h=h, w=w
            )
            visualization_dump["scales"] = rearrange(
                adapted.scales, "b v r srf spp xyz -> b (v r srf spp) xyz"
            )
            visualization_dump["rotations"] = rearrange(
                adapted.rotations, "b v r srf spp xyzw -> b (v r srf spp) xyzw"
            )

        # Flatten all per-view / per-pixel Gaussians into one list per batch item.
        gaussians = Gaussians(
            rearrange(
                adapted.means,
                "b v r srf spp xyz -> b (v r srf spp) xyz",
            ),
            rearrange(
                adapted.covariances,
                "b v r srf spp i j -> b (v r srf spp) i j",
            ),
            rearrange(
                adapted.harmonics,
                "b v r srf spp c d_sh -> b (v r srf spp) c d_sh",
            ),
            rearrange(
                adapted.opacities,
                "b v r srf spp -> b (v r srf spp)",
            ),
        )

        if self.cfg.return_depth:
            # return depth prediction for supervision, [B', V, H, W]
            return {
                "gaussians": gaussians,
                "depths": self._unflatten_depths(depths, h, w),
            }

        return gaussians

    def get_data_shim(self) -> DataShim:
        """Return a shim that patch-aligns batches to shim_patch_size * downscale_factor."""
        def data_shim(batch: BatchedExample) -> BatchedExample:
            batch = apply_patch_shim(
                batch,
                patch_size=self.cfg.shim_patch_size
                * self.cfg.downscale_factor,
            )
            return batch

        return data_shim

    @property
    def sampler(self):
        # This encoder has no depth sampler.
        return None