# NOTE(review): the three lines below were stray VCS residue (author name,
# commit message, commit hash) accidentally pasted above the license header;
# as bare expressions they would raise NameError at import, so they are
# commented out.
# linhaotong
# update
# b9f87ab
# flake8: noqa: F722
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from types import SimpleNamespace
from typing import Optional
import numpy as np
import torch
import torch.nn.functional as F
from einops import einsum
def as_homogeneous(ext):
    """
    Promote (..., 3, 4) extrinsics to (..., 4, 4) homogeneous form.

    Inputs already shaped (..., 4, 4) are returned unchanged (same object,
    no copy). Works on torch.Tensor and np.ndarray; any other type raises
    TypeError, and any other trailing shape raises ValueError.
    """
    if isinstance(ext, torch.Tensor):
        trailing = tuple(ext.shape[-2:])
        if trailing == (4, 4):
            return ext
        if trailing == (3, 4):
            # Build the [0, 0, 0, 1] bottom row with matching batch shape.
            bottom = torch.zeros_like(ext[..., :1, :4])
            bottom[..., 0, 3] = 1.0
            return torch.cat([ext, bottom], dim=-2)
        raise ValueError(f"Invalid shape for torch.Tensor: {ext.shape}")
    if isinstance(ext, np.ndarray):
        trailing = ext.shape[-2:]
        if trailing == (4, 4):
            return ext
        if trailing == (3, 4):
            bottom = np.zeros_like(ext[..., :1, :4])
            bottom[..., 0, 3] = 1.0
            return np.concatenate([ext, bottom], axis=-2)
        raise ValueError(f"Invalid shape for np.ndarray: {ext.shape}")
    raise TypeError("Input must be a torch.Tensor or np.ndarray.")
@torch.jit.script
def affine_inverse(A: torch.Tensor):
    """
    Invert a batch of (..., 4, 4) affine transforms as [R^T | -R^T t].

    Valid when the upper-left 3x3 block is orthonormal (a rigid transform),
    since it uses the transpose as the rotation inverse. The original bottom
    row is carried over unchanged.
    """
    rot = A[..., :3, :3]      # ..., 3, 3
    trans = A[..., :3, 3:]    # ..., 3, 1
    bottom = A[..., 3:, :]    # ..., 1, 4
    rot_t = rot.mT            # transpose computed once, reused twice
    top = torch.cat([rot_t, -rot_t @ trans], dim=-1)
    return torch.cat([top, bottom], dim=-2)
def transpose_last_two_axes(arr):
    """
    Swap the trailing two axes of *arr* (compatibility helper for NumPy < 2).

    Arrays with fewer than two dimensions are returned untouched; otherwise
    a transposed view is returned (no copy).
    """
    if arr.ndim < 2:
        return arr
    return np.swapaxes(arr, -2, -1)
def affine_inverse_np(A: np.ndarray) -> np.ndarray:
    """
    Invert a batch of (..., 4, 4) affine transforms; NumPy twin of affine_inverse.

    Valid when the upper-left 3x3 block is orthonormal (a rigid transform),
    since it uses the transpose as the rotation inverse: [R^T | -R^T t].
    The original bottom row is carried over unchanged.

    Fixes vs. the previous version: the annotation used `np.array` (a
    function, not a type) instead of `np.ndarray`, and the transpose was
    computed twice per call; `np.swapaxes` (available in every NumPy
    release) replaces the local NumPy<2 helper.
    """
    R = A[..., :3, :3]
    T = A[..., :3, 3:]
    P = A[..., 3:, :]
    # Transpose once and reuse it for both the rotation and translation parts.
    R_t = np.swapaxes(R, -2, -1)
    return np.concatenate([np.concatenate([R_t, -R_t @ T], axis=-1), P], axis=-2)
def quat_to_mat(quaternions: torch.Tensor) -> torch.Tensor:
    """
    Convert scalar-last (XYZW / ijkr) quaternions to rotation matrices.

    Args:
        quaternions: (..., 4) tensor with the real part last.

    Returns:
        (..., 3, 3) rotation matrices.
    """
    i, j, k, r = torch.unbind(quaternions, -1)
    # 2 / |q|^2: folding the normalization in here means non-unit
    # quaternions still produce proper rotation matrices.
    scale = 2.0 / (quaternions * quaternions).sum(-1)
    entries = (
        1 - scale * (j * j + k * k),
        scale * (i * j - k * r),
        scale * (i * k + j * r),
        scale * (i * j + k * r),
        1 - scale * (i * i + k * k),
        scale * (j * k - i * r),
        scale * (i * k - j * r),
        scale * (j * k + i * r),
        1 - scale * (i * i + j * j),
    )
    # Nine row-major entries -> (..., 9) -> (..., 3, 3).
    flat = torch.stack(entries, -1)
    return flat.reshape(quaternions.shape[:-1] + (3, 3))
def mat_to_quat(matrix: torch.Tensor) -> torch.Tensor:
    """
    Convert rotations given as rotation matrices to quaternions.

    NOTE(review): this appears adapted from PyTorch3D's
    ``matrix_to_quaternion``: four candidate quaternions are formed (one per
    potentially-largest component) and the best-conditioned candidate is
    selected per batch element, which is numerically robust for all
    rotations.

    Args:
        matrix: Rotation matrices as tensor of shape (..., 3, 3).
    Returns:
        quaternions with real part last, as tensor of shape (..., 4).
        Quaternion Order: XYZW or say ijkr, scalar-last
    """
    if matrix.size(-1) != 3 or matrix.size(-2) != 3:
        raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
    batch_dim = matrix.shape[:-2]
    # Unpack the nine matrix entries into named scalars (row-major).
    m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(
        matrix.reshape(batch_dim + (9,)), dim=-1
    )
    # |2r|, |2i|, |2j|, |2k| (up to sign): square roots of the four trace
    # combinations, clamped at zero with a safe subgradient.
    q_abs = _sqrt_positive_part(
        torch.stack(
            [
                1.0 + m00 + m11 + m22,
                1.0 + m00 - m11 - m22,
                1.0 - m00 + m11 - m22,
                1.0 - m00 - m11 + m22,
            ],
            dim=-1,
        )
    )
    # we produce the desired quaternion multiplied by each of r, i, j, k
    quat_by_rijk = torch.stack(
        [
            # Row c holds the quaternion scaled by component c; the diagonal
            # entry is q_abs[..., c] ** 2, off-diagonals come from (anti)symmetric
            # matrix-entry combinations.
            torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
            torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
            torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
            torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
        ],
        dim=-2,
    )
    # We floor here at 0.1 but the exact level is not important; if q_abs is small,
    # the candidate won't be picked.
    flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
    quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))
    # if not for numerical problems, quat_candidates[i] should be same (up to a sign),
    # forall i; we pick the best-conditioned one (with the largest denominator)
    out = quat_candidates[F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :].reshape(
        batch_dim + (4,)
    )
    # Convert from rijk (scalar-first) to ijkr (scalar-last) ordering.
    out = out[..., [1, 2, 3, 0]]
    # Canonicalize the sign so the real part is non-negative (q and -q encode
    # the same rotation).
    out = standardize_quaternion(out)
    return out
def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
"""
Returns torch.sqrt(torch.max(0, x))
but with a zero subgradient where x is 0.
"""
ret = torch.zeros_like(x)
positive_mask = x > 0
if torch.is_grad_enabled():
ret[positive_mask] = torch.sqrt(x[positive_mask])
else:
ret = torch.where(positive_mask, torch.sqrt(x), ret)
return ret
def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
    """
    Convert a unit quaternion to a standard form: one in which the real
    part is non negative.

    q and -q encode the same rotation; this picks the representative whose
    scalar-last component is >= 0.

    Args:
        quaternions: Quaternions with real part last,
            as tensor of shape (..., 4).
    Returns:
        Standardized quaternions as tensor of shape (..., 4).
    """
    negative_real = quaternions[..., 3:4] < 0
    return torch.where(negative_real, -quaternions, quaternions)
def sample_image_grid(
    shape: tuple[int, ...],
    device: torch.device = torch.device("cpu"),
) -> tuple[
    torch.Tensor,  # float coordinates (xy indexing), "*shape dim"
    torch.Tensor,  # integer indices (ij indexing), "*shape dim"
]:
    """Get normalized (range 0 to 1) coordinates and integer indices for an image."""
    # One 1-D integer range per axis; stacked with "ij" indexing so each
    # entry is a (row, col, ...) index.
    axes = [torch.arange(n, device=device) for n in shape]
    pixel_indices = torch.stack(torch.meshgrid(*axes, indexing="ij"), dim=-1)
    # Pixel-center coordinates normalized to (0, 1). The per-axis list is
    # reversed so that, combined with "xy" indexing, each stacked entry is
    # (x, y) rather than (row, col).
    centers = [(axis + 0.5) / n for axis, n in zip(axes, shape)]
    xy_coordinates = torch.stack(
        torch.meshgrid(*reversed(centers), indexing="xy"), dim=-1
    )
    return xy_coordinates, pixel_indices
def homogenize_points(points: torch.Tensor) -> torch.Tensor:  # "*batch dim" # "*batch dim+1"
    """Convert batched points (xyz) to (xyz1) by appending a trailing one."""
    pad = torch.ones_like(points[..., :1])
    return torch.cat((points, pad), dim=-1)
def homogenize_vectors(vectors: torch.Tensor) -> torch.Tensor:  # "*batch dim" # "*batch dim+1"
    """Convert batched vectors (xyz) to (xyz0) by appending a trailing zero."""
    pad = torch.zeros_like(vectors[..., :1])
    return torch.cat((vectors, pad), dim=-1)
def transform_rigid(
    homogeneous_coordinates: torch.Tensor,  # "*#batch dim"
    transformation: torch.Tensor,  # "*#batch dim dim"
) -> torch.Tensor:  # "*batch dim"
    """Apply a rigid-body transformation to points or vectors.

    Improvement: replaces the einops `einsum("... i j, ... j -> ... i")`
    with a native torch batched matmul, dropping the einops dependency for
    this block; broadcasting over the batch dimensions is identical. The
    coordinates are cast to the transformation's dtype, as before.
    """
    coords = homogeneous_coordinates.to(transformation.dtype)
    # Treat the coordinates as column vectors so matmul broadcasts the
    # batch dimensions, then squeeze the column axis back off.
    return (transformation @ coords.unsqueeze(-1)).squeeze(-1)
def transform_cam2world(
    homogeneous_coordinates: torch.Tensor,  # "*#batch dim"
    extrinsics: torch.Tensor,  # "*#batch dim dim"
) -> torch.Tensor:  # "*batch dim"
    """Transform points from 3D camera coordinates to 3D world coordinates.

    Thin wrapper over transform_rigid; the extrinsics are applied directly,
    so they are presumably camera-to-world matrices — confirm against callers.
    """
    return transform_rigid(homogeneous_coordinates, extrinsics)
def unproject(
    coordinates: torch.Tensor,  # "*#batch dim"
    z: torch.Tensor,  # "*#batch"
    intrinsics: torch.Tensor,  # "*#batch dim+1 dim+1"
) -> torch.Tensor:  # "*batch dim+1"
    """Unproject 2D camera coordinates with the given Z values.

    Improvement: replaces the einops `einsum("... i j, ... j -> ... i")`
    with a native torch batched matmul (identical broadcasting), dropping
    the einops dependency for this block.
    """
    # Lift pixel coordinates to homogeneous form, then map them through
    # K^-1 to get camera-space directions at unit depth.
    homog = homogenize_points(coordinates)
    # Invert in float32 for numerical stability, then cast back to the
    # intrinsics' dtype/device (same pattern as before).
    k_inv = intrinsics.float().inverse().to(intrinsics)
    ray_directions = (k_inv @ homog.to(intrinsics.dtype).unsqueeze(-1)).squeeze(-1)
    # Scale each unit-depth direction by its supplied depth.
    return ray_directions * z[..., None]
def get_world_rays(
    coordinates: torch.Tensor,  # "*#batch dim"
    extrinsics: torch.Tensor,  # "*#batch dim+2 dim+2"
    intrinsics: torch.Tensor,  # "*#batch dim+1 dim+1"
) -> tuple[
    torch.Tensor,  # origins, "*batch dim+1"
    torch.Tensor,  # directions, "*batch dim+1"
]:
    """Compute world-space ray origins and unit directions for the given pixels."""
    # Camera-space directions at unit depth, normalized to unit length.
    cam_dirs = unproject(coordinates, torch.ones_like(coordinates[..., 0]), intrinsics)
    cam_dirs = cam_dirs / cam_dirs.norm(dim=-1, keepdim=True)
    # Append w=0 so the transform rotates without translating, then drop
    # the homogeneous component again.
    world_dirs = transform_cam2world(homogenize_vectors(cam_dirs), extrinsics)[..., :-1]
    # Every ray starts at the camera center: the translation column of the
    # extrinsics, broadcast to one origin per direction.
    origins = extrinsics[..., :-1, -1].broadcast_to(world_dirs.shape)
    return origins, world_dirs
def get_fov(intrinsics: torch.Tensor) -> torch.Tensor:  # "batch 3 3" -> "batch 2"
    """Compute horizontal and vertical fields of view, in radians.

    The pixel probes (0..1 range) imply the intrinsics are expressed in
    normalized image coordinates — confirm against callers.

    Improvement: replaces the einops `einsum("b i j, j -> b i")` with a
    native torch matmul (a (b, 3, 3) @ (3,) product), dropping the einops
    dependency for this block.

    Args:
        intrinsics: (batch, 3, 3) camera matrices.

    Returns:
        (batch, 2) tensor of (fov_x, fov_y).
    """
    # Invert in float32 for numerical stability, then cast back.
    k_inv = intrinsics.float().inverse().to(intrinsics)

    def unit_ray(pixel):
        # Back-project a homogeneous pixel location to a unit-norm
        # camera-space ray, batched over the leading dimension.
        hom = torch.tensor(pixel, dtype=intrinsics.dtype, device=intrinsics.device)
        ray = k_inv @ hom
        return ray / ray.norm(dim=-1, keepdim=True)

    # Rays through the midpoints of the four image edges.
    left = unit_ray([0, 0.5, 1])
    right = unit_ray([1, 0.5, 1])
    top = unit_ray([0.5, 0, 1])
    bottom = unit_ray([0.5, 1, 1])
    # The angle between opposite edge rays is the full field of view.
    fov_x = (left * right).sum(dim=-1).acos()
    fov_y = (top * bottom).sum(dim=-1).acos()
    return torch.stack((fov_x, fov_y), dim=-1)
def map_pdf_to_opacity(
    pdf: torch.Tensor,  # " *batch"
    global_step: int = 0,
    opacity_mapping: Optional[dict] = None,
) -> torch.Tensor:  # " *batch"
    """Map a probability density to an opacity.

    https://www.desmos.com/calculator/opvwti3ba9

    Args:
        pdf: probability densities to map.
        global_step: training step used to anneal the exponent.
        opacity_mapping: dict with "initial", "final" and "warm_up" keys;
            when None the exponent is 2**0 = 1 and the mapping reduces to
            the identity.

    Returns:
        Opacities with the same shape as ``pdf``.
    """
    # Anneal x linearly from `initial` to `final` over `warm_up` steps.
    if opacity_mapping is None:
        x = 0.0
    else:
        cfg = SimpleNamespace(**opacity_mapping)
        progress = min(global_step / cfg.warm_up, 1)
        x = cfg.initial + progress * (cfg.final - cfg.initial)
    exponent = 2**x
    # Symmetric mapping of pdf through the exponent and its reciprocal.
    return 0.5 * (1 - (1 - pdf) ** exponent + pdf ** (1 / exponent))