# flake8: noqa: F722
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from types import SimpleNamespace
from typing import Optional

import numpy as np
import torch
import torch.nn.functional as F
from einops import einsum


def as_homogeneous(ext):
    """
    Accept (..., 3,4) or (..., 4,4) extrinsics, return (...,4,4) homogeneous matrix.
    Supports torch.Tensor or np.ndarray.

    A (..., 4, 4) input is returned as-is (same object, no copy). A (..., 3, 4)
    input gets a ``[0, 0, 0, 1]`` row appended along the second-to-last axis.

    Raises:
        ValueError: if the trailing two dimensions are neither (3, 4) nor (4, 4).
        TypeError: if ``ext`` is neither a torch.Tensor nor an np.ndarray.
    """
    if isinstance(ext, torch.Tensor):
        # If already in homogeneous form
        if ext.shape[-2:] == (4, 4):
            return ext
        elif ext.shape[-2:] == (3, 4):
            # Build the [0, 0, 0, 1] row with the dtype/device of the input.
            bottom_row = torch.zeros_like(ext[..., :1, :4])
            bottom_row[..., 0, 3] = 1.0
            return torch.cat([ext, bottom_row], dim=-2)
        else:
            raise ValueError(f"Invalid shape for torch.Tensor: {ext.shape}")
    elif isinstance(ext, np.ndarray):
        if ext.shape[-2:] == (4, 4):
            return ext
        elif ext.shape[-2:] == (3, 4):
            # Build the [0, 0, 0, 1] row with the dtype of the input.
            bottom_row = np.zeros_like(ext[..., :1, :4])
            bottom_row[..., 0, 3] = 1.0
            return np.concatenate([ext, bottom_row], axis=-2)
        else:
            raise ValueError(f"Invalid shape for np.ndarray: {ext.shape}")
    else:
        raise TypeError("Input must be a torch.Tensor or np.ndarray.")


@torch.jit.script
def affine_inverse(A: torch.Tensor):
    """
    Invert a batch of (..., 4, 4) transforms of the form [[R, T], [P]].

    The result is assembled as [[R^T, -R^T @ T], [P]]; the bottom row P is copied
    unchanged. This equals the true matrix inverse only when R is orthonormal
    (a rotation) and P is [0, 0, 0, 1].
    """
    R = A[..., :3, :3]  # ..., 3, 3
    T = A[..., :3, 3:]  # ..., 3, 1
    P = A[..., 3:, :]  # ..., 1, 4
    return torch.cat([torch.cat([R.mT, -R.mT @ T], dim=-1), P], dim=-2)


def transpose_last_two_axes(arr):
    """
    Swap the last two axes of a NumPy array (stand-in for ``.mT``, for np < 2).

    Arrays with fewer than two dimensions are returned unchanged.
    """
    if arr.ndim < 2:
        return arr
    return arr.swapaxes(-1, -2)


def affine_inverse_np(A: np.ndarray) -> np.ndarray:
    """
    NumPy counterpart of ``affine_inverse``.

    Builds [[R^T, -R^T @ T], [P]] from a (..., 4, 4) array, copying the bottom row
    P unchanged. This is the true inverse only when R is orthonormal and P is
    [0, 0, 0, 1].
    """
    R = A[..., :3, :3]
    T = A[..., :3, 3:]
    P = A[..., 3:, :]
    R_t = transpose_last_two_axes(R)
    return np.concatenate(
        [
            np.concatenate([R_t, -R_t @ T], axis=-1),
            P,
        ],
        axis=-2,
    )


def quat_to_mat(quaternions: torch.Tensor) -> torch.Tensor:
    """
    Quaternion Order: XYZW or say ijkr, scalar-last

    Convert rotations given as quaternions to rotation matrices.

    The quaternions need not be unit length: the squared norm is divided out via
    ``two_s``. An all-zero quaternion therefore produces non-finite values.

    Args:
        quaternions: quaternions with real part last,
            as tensor of shape (..., 4).

    Returns:
        Rotation matrices as tensor of shape (..., 3, 3).
    """
    i, j, k, r = torch.unbind(quaternions, -1)
    # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
    two_s = 2.0 / (quaternions * quaternions).sum(-1)

    o = torch.stack(
        (
            1 - two_s * (j * j + k * k),
            two_s * (i * j - k * r),
            two_s * (i * k + j * r),
            two_s * (i * j + k * r),
            1 - two_s * (i * i + k * k),
            two_s * (j * k - i * r),
            two_s * (i * k - j * r),
            two_s * (j * k + i * r),
            1 - two_s * (i * i + j * j),
        ),
        -1,
    )
    return o.reshape(quaternions.shape[:-1] + (3, 3))


def mat_to_quat(matrix: torch.Tensor) -> torch.Tensor:
    """
    Convert rotations given as rotation matrices to quaternions.

    Args:
        matrix: Rotation matrices as tensor of shape (..., 3, 3).

    Returns:
        quaternions with real part last, as tensor of shape (..., 4).
        Quaternion Order: XYZW or say ijkr, scalar-last

    Raises:
        ValueError: if the trailing two dimensions of ``matrix`` are not (3, 3).
    """
    if matrix.size(-1) != 3 or matrix.size(-2) != 3:
        raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")

    batch_dim = matrix.shape[:-2]
    m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(
        matrix.reshape(batch_dim + (9,)), dim=-1
    )

    # Magnitudes of 2*r, 2*i, 2*j, 2*k (in that order).
    q_abs = _sqrt_positive_part(
        torch.stack(
            [
                1.0 + m00 + m11 + m22,
                1.0 + m00 - m11 - m22,
                1.0 - m00 + m11 - m22,
                1.0 - m00 - m11 + m22,
            ],
            dim=-1,
        )
    )

    # we produce the desired quaternion multiplied by each of r, i, j, k
    quat_by_rijk = torch.stack(
        [
            # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
            #  `int`.
            torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
            # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
            #  `int`.
            torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
            # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
            #  `int`.
            torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
            # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
            #  `int`.
            torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
        ],
        dim=-2,
    )

    # We floor here at 0.1 but the exact level is not important; if q_abs is small,
    # the candidate won't be picked.
    flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
    quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))

    # if not for numerical problems, quat_candidates[i] should be same (up to a sign),
    # forall i; we pick the best-conditioned one (with the largest denominator)
    out = quat_candidates[F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :].reshape(
        batch_dim + (4,)
    )

    # Convert from rijk to ijkr
    out = out[..., [1, 2, 3, 0]]

    out = standardize_quaternion(out)

    return out


def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
    """
    Returns torch.sqrt(torch.max(0, x))
    but with a zero subgradient where x is 0.
    """
    ret = torch.zeros_like(x)
    positive_mask = x > 0
    if torch.is_grad_enabled():
        # Masked assignment keeps sqrt (and its gradient) away from x <= 0 entries.
        ret[positive_mask] = torch.sqrt(x[positive_mask])
    else:
        ret = torch.where(positive_mask, torch.sqrt(x), ret)
    return ret


def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
    """
    Convert a unit quaternion to a standard form: one in which the real
    part is non negative.

    Args:
        quaternions: Quaternions with real part last,
            as tensor of shape (..., 4).

    Returns:
        Standardized quaternions as tensor of shape (..., 4).
    """
    return torch.where(quaternions[..., 3:4] < 0, -quaternions, quaternions)


def sample_image_grid(
    shape: tuple[int, ...],
    device: torch.device = torch.device("cpu"),
) -> tuple[
    torch.Tensor,  # float coordinates (xy indexing), "*shape dim"
    torch.Tensor,  # integer indices (ij indexing), "*shape dim"
]:
    """Get normalized (range 0 to 1) coordinates and integer indices for an image."""

    # Each entry is a pixel-wise integer coordinate. In the 2D case, each entry is a
    # (row, col) coordinate.
    indices = [torch.arange(length, device=device) for length in shape]
    stacked_indices = torch.stack(torch.meshgrid(*indices, indexing="ij"), dim=-1)

    # Each entry is a floating-point coordinate in the range (0, 1). In the 2D case,
    # each entry is an (x, y) coordinate. The +0.5 samples pixel centers.
    coordinates = [(idx + 0.5) / length for idx, length in zip(indices, shape)]
    coordinates = reversed(coordinates)
    coordinates = torch.stack(torch.meshgrid(*coordinates, indexing="xy"), dim=-1)

    return coordinates, stacked_indices


def homogenize_points(
    points: torch.Tensor,  # "*batch dim"
) -> torch.Tensor:  # "*batch dim+1"
    """Convert batched points (xyz) to (xyz1)."""
    return torch.cat([points, torch.ones_like(points[..., :1])], dim=-1)


def homogenize_vectors(
    vectors: torch.Tensor,  # "*batch dim"
) -> torch.Tensor:  # "*batch dim+1"
    """Convert batched vectors (xyz) to (xyz0)."""
    return torch.cat([vectors, torch.zeros_like(vectors[..., :1])], dim=-1)


def transform_rigid(
    homogeneous_coordinates: torch.Tensor,  # "*#batch dim"
    transformation: torch.Tensor,  # "*#batch dim dim"
) -> torch.Tensor:  # "*batch dim"
    """Apply a rigid-body transformation to points or vectors.

    The coordinates are cast to the dtype of ``transformation`` before the
    batched matrix-vector product, so the result has that dtype.
    """
    return einsum(
        transformation,
        homogeneous_coordinates.to(transformation.dtype),
        "... i j, ... j -> ... i",
    )


def transform_cam2world(
    homogeneous_coordinates: torch.Tensor,  # "*#batch dim"
    extrinsics: torch.Tensor,  # "*#batch dim dim"
) -> torch.Tensor:  # "*batch dim"
    """Transform points from 3D camera coordinates to 3D world coordinates.

    NOTE(review): ``extrinsics`` is applied as-is, so it is presumably expected to
    be a camera-to-world matrix -- confirm against callers.
    """
    return transform_rigid(homogeneous_coordinates, extrinsics)


def unproject(
    coordinates: torch.Tensor,  # "*#batch dim"
    z: torch.Tensor,  # "*#batch"
    intrinsics: torch.Tensor,  # "*#batch dim+1 dim+1"
) -> torch.Tensor:  # "*batch dim+1"
    """Unproject 2D camera coordinates with the given Z values."""

    # Apply the inverse intrinsics to the coordinates. The inverse is computed in
    # float32 and then cast back to the dtype/device of `intrinsics`.
    coordinates = homogenize_points(coordinates)
    ray_directions = einsum(
        intrinsics.float().inverse().to(intrinsics),
        coordinates.to(intrinsics.dtype),
        "... i j, ... j -> ... i",
    )

    # Apply the supplied depth values.
    return ray_directions * z[..., None]


def get_world_rays(
    coordinates: torch.Tensor,  # "*#batch dim"
    extrinsics: torch.Tensor,  # "*#batch dim+2 dim+2"
    intrinsics: torch.Tensor,  # "*#batch dim+1 dim+1"
) -> tuple[
    torch.Tensor,  # origins, "*batch dim+1"
    torch.Tensor,  # directions, "*batch dim+1"
]:
    """Compute per-coordinate ray origins and unit directions.

    Directions are obtained by unprojecting ``coordinates`` at depth 1, normalizing
    them, and transforming them by ``extrinsics`` as vectors (w = 0). Origins are
    the translation column of ``extrinsics`` broadcast to the direction shape.
    """
    # Get camera-space ray directions.
    directions = unproject(
        coordinates,
        torch.ones_like(coordinates[..., 0]),
        intrinsics,
    )
    directions = directions / directions.norm(dim=-1, keepdim=True)

    # Transform ray directions to world coordinates.
    directions = homogenize_vectors(directions)
    directions = transform_cam2world(directions, extrinsics)[..., :-1]

    # Tile the ray origins to have the same shape as the ray directions.
    origins = extrinsics[..., :-1, -1].broadcast_to(directions.shape)

    return origins, directions


def get_fov(intrinsics: torch.Tensor) -> torch.Tensor:  # "batch 3 3" -> "batch 2"
    """Return the (fov_x, fov_y) angles, in radians, for a batch of intrinsics.

    Each angle is measured between the unit rays through the midpoints of opposite
    image edges, taken at image coordinates 0 and 1.

    NOTE(review): this assumes the intrinsics are expressed in normalized image
    coordinates (image spanning [0, 1] on both axes) -- confirm against callers.
    """
    intrinsics_inv = intrinsics.float().inverse().to(intrinsics)

    def process_vector(vector):
        # Back-project a homogeneous image point and normalize it to a unit ray.
        vector = torch.tensor(vector, dtype=intrinsics.dtype, device=intrinsics.device)
        vector = einsum(intrinsics_inv, vector, "b i j, j -> b i")
        return vector / vector.norm(dim=-1, keepdim=True)

    left = process_vector([0, 0.5, 1])
    right = process_vector([1, 0.5, 1])
    top = process_vector([0.5, 0, 1])
    bottom = process_vector([0.5, 1, 1])
    fov_x = (left * right).sum(dim=-1).acos()
    fov_y = (top * bottom).sum(dim=-1).acos()
    return torch.stack((fov_x, fov_y), dim=-1)


def map_pdf_to_opacity(
    pdf: torch.Tensor,  # " *batch"
    global_step: int = 0,
    opacity_mapping: Optional[dict] = None,
) -> torch.Tensor:  # " *batch"
    """Map a probability density to an opacity with a step-dependent exponent.

    Args:
        pdf: probability values to map.
        global_step: current training step, used to interpolate the exponent.
        opacity_mapping: optional dict with keys ``initial``, ``final`` and
            ``warm_up``. The exponent is ``2 ** x`` where ``x`` moves linearly from
            ``initial`` to ``final`` over ``warm_up`` steps. When ``None``, ``x`` is
            0 and the exponent is 1.

    Returns:
        ``0.5 * (1 - (1 - pdf) ** e + pdf ** (1 / e))`` with ``e`` the exponent.
    """
    # https://www.desmos.com/calculator/opvwti3ba9

    # Figure out the exponent.
    if opacity_mapping is not None:
        cfg = SimpleNamespace(**opacity_mapping)
        if cfg.warm_up == 0:
            # No warm-up period: use the final value directly rather than dividing
            # by zero.
            progress = 1
        else:
            progress = min(global_step / cfg.warm_up, 1)
        x = cfg.initial + progress * (cfg.final - cfg.initial)
    else:
        x = 0.0
    exponent = 2**x

    # Map the probability density to an opacity.
    return 0.5 * (1 - (1 - pdf) ** exponent + pdf ** (1 / exponent))