linhaotong
update
b9f87ab
raw
history blame
3.24 kB
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from math import isqrt
import torch
from einops import einsum
try:
from e3nn.o3 import matrix_to_angles, wigner_D
except ImportError:
from depth_anything_3.utils.logger import logger
logger.warn("Dependency 'e3nn' not found. Required for rotating the camera space SH coeff")
def project_to_so3_strict(M: torch.Tensor) -> torch.Tensor:
    """Map a batch of 3x3 matrices onto rotation matrices via SVD.

    The input is factored with ``torch.linalg.svd``; the two orthogonal
    factors are recombined (dropping the singular values), with the last
    column of the left factor negated whenever the product of the factors'
    determinants is negative. The result is finally rescaled by
    ``det ** (-1 / 3)`` so that its determinant is pushed to 1.

    Args:
        M: Tensor whose trailing two dimensions are ``(3, 3)``.

    Returns:
        Tensor with the same shape as ``M``.

    Raises:
        ValueError: If the trailing two dimensions of ``M`` are not ``(3, 3)``.
    """
    if tuple(M.shape[-2:]) != (3, 3):
        raise ValueError("Input must be a batch of 3x3 matrices (i.e., shape [..., 3, 3]).")

    # SVD factors; the singular values themselves are not needed.
    left, _, right_h = torch.linalg.svd(M)
    right = right_h.mH

    # A negative determinant product means left @ right^T would be a
    # reflection, so the last column of the left factor has to be flipped.
    needs_flip = (torch.det(left) * torch.det(right)) < 0
    flip_last = torch.tensor([1, 1, -1.0], device=M.device, dtype=M.dtype)
    keep_all = torch.tensor([1, 1, 1.0], device=M.device, dtype=M.dtype)
    column_signs = torch.where(needs_flip[..., None], flip_last, keep_all)

    rotation = (left @ torch.diag_embed(column_signs)) @ right.transpose(-2, -1)

    # Rescale so the determinant is 1 (a 3x3 determinant scales with the cube
    # of a uniform factor, hence the -1/3 exponent).
    det_scale = torch.pow(torch.det(rotation), -1 / 3)
    return rotation * det_scale[..., None, None]
def rotate_sh(
    sh_coefficients: torch.Tensor,  # "*#batch n"
    rotations: torch.Tensor,  # "*#batch 3 3"
) -> torch.Tensor:  # "*batch n"
    """Rotate spherical-harmonic (SH) coefficients by the given rotation matrices.

    The last axis of ``sh_coefficients`` is read as consecutive per-degree
    blocks: degree ``l`` occupies indices ``l**2`` up to (excluding)
    ``(l + 1)**2``. Each block is multiplied by the Wigner-D matrix of that
    degree, built from the angles e3nn extracts from ``rotations``.

    Args:
        sh_coefficients: Coefficients of shape ``[..., n]`` where ``n`` is a
            perfect square, i.e. ``(max_degree + 1) ** 2``.
        rotations: Rotation matrices of shape ``[..., 3, 3]``. Per the shape
            annotations, leading dimensions are expected to broadcast against
            those of ``sh_coefficients``.

    Returns:
        The rotated coefficients with last dimension ``n``, in the dtype of
        ``sh_coefficients``.

    Raises:
        ValueError: If ``n`` is not a perfect square. Without this check the
            trailing ``n - isqrt(n) ** 2`` coefficients would be silently
            dropped from the output.
        ImportError: If the optional dependency ``e3nn`` is not available.
    """
    # https://github.com/graphdeco-inria/gaussian-splatting/issues/176#issuecomment-2452412653
    device = sh_coefficients.device
    dtype = sh_coefficients.dtype
    *_, n = sh_coefficients.shape

    num_degrees = isqrt(n)
    if num_degrees * num_degrees != n:
        raise ValueError(
            "Last dimension of sh_coefficients must be a perfect square "
            f"((max_degree + 1) ** 2), got {n}."
        )

    # The module-level e3nn import is optional; fail with an explicit message
    # instead of a bare NameError when it was not importable.
    try:
        to_angles, wigner = matrix_to_angles, wigner_D
    except NameError as err:
        raise ImportError(
            "Dependency 'e3nn' not found. Required for rotating the camera space SH coeff"
        ) from err

    # Angle extraction is done in float32 with autocast disabled so that the
    # SO(3) projection is not computed in reduced precision.
    with torch.autocast(device_type=rotations.device.type, enabled=False):
        rotations_float32 = rotations.to(torch.float32)
        # switch axes: yzx -> xyz
        P = torch.tensor([[0, 0, 1], [1, 0, 0], [0, 1, 0]]).unsqueeze(0).to(rotations_float32)
        # P is a permutation matrix, so its transpose is its exact inverse.
        permuted_rotations = P.mT @ rotations_float32 @ P

        # ensure rotation has det == 1 in float32 type
        permuted_rotations_so3 = project_to_so3_strict(permuted_rotations)

        alpha, beta, gamma = to_angles(permuted_rotations_so3)

    result = []
    for degree in range(num_degrees):
        # Build the Wigner-D matrix with the coefficients' device as default.
        with torch.device(device):
            sh_rotations = wigner(degree, alpha, -beta, gamma).type(dtype)
        # Apply the (2 * degree + 1)-sized rotation to this degree's block.
        sh_rotated = einsum(
            sh_rotations,
            sh_coefficients[..., degree**2 : (degree + 1) ** 2],
            "... i j, ... j -> ... i",
        )
        result.append(sh_rotated)

    return torch.cat(result, dim=-1)