from dataclasses import dataclass
from typing import Optional, Tuple

import torch
import torch.nn.functional as F
from einops import einsum, rearrange
from jaxtyping import Float
from torch import Tensor, nn

from ....geometry.projection import get_world_rays, sample_voxel_grid
from ....misc.sh_rotation import rotate_sh
from .gaussians import build_covariance
from .voxel_feature import (
    adapte_project_features_to_3d,
    project_features_to_3d,
    project_features_to_voxel,
)
from .me_fea import project_features_to_me
from ....test.export_ply import save_point_cloud_to_ply


@dataclass
class Gaussians:
    """Container for a set of 3D Gaussians produced by the adapter."""

    # Gaussian centers.
    means: Float[Tensor, "*batch 3"]
    # Covariance matrices, as returned by ``build_covariance``.
    covariances: Float[Tensor, "*batch 3 3"]
    # Per-axis scales (after clamping in the adapter).
    scales: Float[Tensor, "*batch 3"]
    # Normalized quaternions.
    rotations: Float[Tensor, "*batch 4"]
    # Spherical-harmonics coefficients, one set per color channel.
    harmonics: Float[Tensor, "*batch 3 _"]
    # Per-Gaussian opacity.
    opacities: Float[Tensor, " *batch"]


@dataclass
class GaussianAdapterCfg:
    """Configuration for ``GaussianAdapter_depth``."""

    # Lower clamp bound applied to the predicted scales.
    gaussian_scale_min: float
    # Upper clamp bound applied to the predicted scales.
    gaussian_scale_max: float
    # Maximum spherical-harmonics degree; yields (sh_degree + 1) ** 2 coefficients.
    sh_degree: int


class GaussianAdapter_depth(nn.Module):
    """Convert raw per-point network outputs into a ``Gaussians`` object.

    The raw channel layout consumed by ``forward`` is
    ``[offset(3) | scale(3) | rotation(4) | sh(3 * d_sh)]``.
    """

    cfg: GaussianAdapterCfg

    def __init__(self, cfg: GaussianAdapterCfg):
        super().__init__()
        self.cfg = cfg

        # Create a mask for the spherical harmonics coefficients. This ensures that at
        # initialization, the coefficients are biased towards having a large DC
        # component and small view-dependent components.
        self.register_buffer(
            "sh_mask",
            torch.ones((self.d_sh,), dtype=torch.float32),
            persistent=False,
        )
        for degree in range(1, self.cfg.sh_degree + 1):
            # Each SH degree gets its own weight; higher degrees get smaller weights.
            self.sh_mask[degree**2 : (degree + 1) ** 2] = 0.1 * 0.25**degree

    def forward(
        self,
        extrinsics: Tensor,
        intrinsics: Tensor | None,
        opacities: Tensor,
        raw_gaussians: Tensor,  # [1, 1, N, 37] per the original note -- TODO confirm
        input_images: Tensor | None = None,
        depth: Tensor | None = None,
        coordidate: Optional[Tensor] = None,
        points: Optional[Tensor] = None,
        voxel_resolution: float = 0.01,
        eps: float = 1e-8,
        apply_offset: bool = False,
    ) -> Gaussians:
        """Build Gaussians from raw network outputs and voxel-center points.

        Args:
            extrinsics: Camera extrinsics; the two leading dimensions (everything
                before the trailing matrix dims) are unpacked as ``(b, v)``.
            intrinsics: Camera intrinsics, forwarded to ``project_features_to_me``.
            opacities: Per-Gaussian opacities; their shape defines the broadcast
                shape of the SH coefficients.
            raw_gaussians: Raw parameters whose last dimension is split into
                ``(3, 3, 4, 3 * d_sh)`` = offset, scale, rotation, SH.
            input_images: Optional images. When given together with
                ``coordidate``, colors are projected to voxels and written into
                the degree-0 SH coefficient.
            depth: Depth forwarded to ``project_features_to_me``.
            coordidate: Optional voxel coordinates. The color override is applied
                only when their shape equals that of the projected voxel
                coordinates.
            points: Point positions of shape ``(n, c)`` used as Gaussian centers.
                Required, despite the ``None`` default.
            voxel_resolution: Voxel edge length; also scales the optional offset.
            eps: Added to the quaternion norm to avoid division by zero.
            apply_offset: When ``True``, shift each center by the predicted
                offset, bounded to ``+-1.5 * voxel_resolution`` per axis. The
                default ``False`` keeps the historical behavior, where the offset
                was computed but then discarded.

        Returns:
            A ``Gaussians`` instance.

        Raises:
            ValueError: If ``points`` is ``None``.
        """
        if points is None:
            # Fail early with a clear message instead of an opaque einops error.
            raise ValueError("`points` is required to place the Gaussian centers.")

        # Leading batch dimensions: batch size and number of views.
        b, v = extrinsics.shape[:-2]

        # Split the raw channels into the individual Gaussian parameters.
        # Presumably each piece is [1, 1, N, 1, 1, C] -- TODO confirm against caller.
        offset_xyz, scales, rotations, sh = raw_gaussians.split(
            (3, 3, 4, 3 * self.d_sh), dim=-1
        )

        # Constrain the scales to the configured range.
        scales = torch.clamp(
            F.softplus(scales - 4.0),
            min=self.cfg.gaussian_scale_min,
            max=self.cfg.gaussian_scale_max,
        )

        # Normalize the quaternion features to yield a valid quaternion.
        rotations = rotations / (rotations.norm(dim=-1, keepdim=True) + eps)

        # Rearrange the SH channels into (color channel, coefficient) and apply
        # the per-degree mask. The multiplication yields a fresh tensor, so the
        # in-place write below does not touch a broadcast view.
        sh = rearrange(sh, "... (xyz d_sh) -> ... xyz d_sh", xyz=3)
        sh = sh.broadcast_to((*opacities.shape, 3, self.d_sh)) * self.sh_mask

        # Without reference coordinates there is nothing to validate the
        # projected voxels against, so the color override is skipped entirely.
        if input_images is not None and coordidate is not None:
            voxel_color, _, _ = project_features_to_me(
                intrinsics=intrinsics,
                extrinsics=extrinsics,
                out=input_images,
                depth=depth,
                voxel_resolution=voxel_resolution,
                b=b,
                v=v,
            )
            # NOTE(review): only the shapes are compared, not the coordinate
            # values (the original had a torch.equal check commented out).
            if coordidate.shape == voxel_color.C.shape:
                colors = voxel_color.F  # presumably [B*V*N, C] -- TODO confirm
                # Convert RGB to the degree-0 SH coefficient.
                sh0 = RGB2SH(colors)
                # NOTE(review): this view hard-codes leading dims of 1 (b = v = 1).
                sh[..., 0] = sh0.view(1, 1, -1, 1, 1, 3)

        # Create world-space covariance matrices.
        covariances = build_covariance(scales, rotations)

        # Use the supplied points as Gaussian centers, expanded to the same
        # rank as the other parameters.
        xyz = rearrange(points, "n c -> 1 1 n () () c")

        if apply_offset:
            # Squash the offset to (0, 1), center it, and scale it to
            # +-1.5 voxel lengths per axis.
            offset_world = (offset_xyz.sigmoid() - 0.5) * voxel_resolution * 3
            means = xyz + offset_world
        else:
            # Historical behavior: the predicted offset is ignored.
            means = xyz

        return Gaussians(
            means=means,
            covariances=covariances,
            harmonics=sh,
            opacities=opacities,
            # NOTE: These aren't yet rotated into world space, but they're only used for
            # exporting Gaussians to ply files. This needs to be fixed...
            scales=scales,
            rotations=rotations.broadcast_to((*scales.shape[:-1], 4)),
        )

    def get_scale_multiplier(
        self,
        intrinsics: Float[Tensor, "*#batch 3 3"],
        pixel_size: Float[Tensor, "*#batch 2"],
        multiplier: float = 0.1,
    ) -> Float[Tensor, " *batch"]:
        """Return a per-element scale multiplier derived from the intrinsics.

        Applies the inverse of the upper-left 2x2 block of ``intrinsics`` to
        ``pixel_size``, scales the result by ``multiplier``, and sums the two
        components.
        """
        xy_multipliers = multiplier * einsum(
            intrinsics[..., :2, :2].inverse(),
            pixel_size,
            "... i j, j -> ... i",
        )
        return xy_multipliers.sum(dim=-1)

    @property
    def d_sh(self) -> int:
        """Number of SH coefficients per color channel."""
        return (self.cfg.sh_degree + 1) ** 2

    @property
    def d_in(self) -> int:
        """Reported raw channel count: 7 + 3 * d_sh.

        NOTE(review): ``forward`` splits 10 + 3 * d_sh channels (it also
        consumes a 3-channel offset), so this value does not include the
        offset. Left unchanged because callers may depend on it -- confirm.
        """
        return 7 + 3 * self.d_sh


def RGB2SH(rgb):
    """Convert RGB values to degree-0 spherical-harmonics coefficients.

    Computes ``(rgb - 0.5) / C0`` with ``C0 = 1 / (2 * sqrt(pi))``.
    """
    C0 = 0.28209479177387814
    return (rgb - 0.5) / C0