Spaces:
Running
on
Zero
Running
on
Zero
| # Copyright (c) 2025 ByteDance Ltd. and/or its affiliates | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import torch | |
| import torch.nn.functional as F | |
def extri_intri_to_pose_encoding(
    extrinsics,
    intrinsics,
    image_size_hw=None,
):
    """Convert camera extrinsics and intrinsics to a compact pose encoding.

    The encoding concatenates, along the last dimension, the translation
    (3 values), the rotation as a scalar-last (XYZW) quaternion (4 values),
    and the vertical / horizontal field of view in radians (2 values), giving
    9 values per camera.

    Args:
        extrinsics: Tensor of shape (..., 3, 4) holding ``[R | T]`` per camera
            (typically B x S x 3 x 4).
        intrinsics: Tensor of shape (..., 3, 3); only the focal entries
            ``[0, 0]`` (fx) and ``[1, 1]`` (fy) are read. They must be in the
            same pixel units as ``image_size_hw``.
        image_size_hw: ``(H, W)`` image size used to turn focal lengths into
            fields of view. Required despite the ``None`` default.

    Returns:
        float32 tensor of shape (..., 9): ``[T, quat_xyzw, fov_h, fov_w]``.

    Raises:
        ValueError: if ``image_size_hw`` is not provided.
    """
    if image_size_hw is None:
        raise ValueError("image_size_hw=(H, W) is required to compute the field of view.")

    # Ellipsis indexing accepts any number of leading batch dims (B x S is the
    # usual case), consistent with how intrinsics are indexed below.
    R = extrinsics[..., :3, :3]
    T = extrinsics[..., :3, 3]
    quat = mat_to_quat(R)  # (..., 4), scalar-last

    # Note the order of h and w here: height first, then width.
    H, W = image_size_hw
    # fov = 2 * atan((size / 2) / focal); fy pairs with H and fx with W.
    fov_h = 2 * torch.atan((H / 2) / intrinsics[..., 1, 1])
    fov_w = 2 * torch.atan((W / 2) / intrinsics[..., 0, 0])

    pose_encoding = torch.cat([T, quat, fov_h[..., None], fov_w[..., None]], dim=-1).float()
    return pose_encoding
def pose_encoding_to_extri_intri(
    pose_encoding,
    image_size_hw=None,
):
    """Convert a pose encoding back to camera extrinsics and intrinsics.

    Args:
        pose_encoding: Tensor of shape (..., 9) laid out as
            ``[T (3), quat_xyzw (4), fov_h, fov_w]`` (typically B x S x 9).
        image_size_hw: ``(H, W)`` image size used to turn fields of view back
            into focal lengths. Required despite the ``None`` default.

    Returns:
        ``(extrinsics, intrinsics)`` where ``extrinsics`` has shape (..., 3, 4)
        as ``[R | T]``. ``intrinsics`` has shape (..., 3, 3), with focal
        lengths fx / fy and the principal point at ``(W / 2, H / 2)``.
        ``intrinsics`` is allocated in torch's default floating dtype on
        ``pose_encoding``'s device.

    Raises:
        ValueError: if ``image_size_hw`` is not provided.
    """
    if image_size_hw is None:
        raise ValueError("image_size_hw=(H, W) is required to build intrinsics.")

    T = pose_encoding[..., :3]
    quat = pose_encoding[..., 3:7]
    fov_h = pose_encoding[..., 7]
    fov_w = pose_encoding[..., 8]

    R = quat_to_mat(quat)
    extrinsics = torch.cat([R, T[..., None]], dim=-1)

    H, W = image_size_hw
    # tan(fov / 2) is clamped to at least 1e-6 (this also catches non-positive
    # values) so a degenerate FoV cannot cause a division by zero.
    fy = (H / 2.0) / torch.clamp(torch.tan(fov_h / 2.0), min=1e-6)
    fx = (W / 2.0) / torch.clamp(torch.tan(fov_w / 2.0), min=1e-6)

    # Use all leading dims (not just the first two) so the buffer always
    # matches fx / fy, whatever the rank of pose_encoding.
    intrinsics = torch.zeros(pose_encoding.shape[:-1] + (3, 3), device=pose_encoding.device)
    intrinsics[..., 0, 0] = fx
    intrinsics[..., 1, 1] = fy
    intrinsics[..., 0, 2] = W / 2
    intrinsics[..., 1, 2] = H / 2
    intrinsics[..., 2, 2] = 1.0  # Set the homogeneous coordinate to 1
    return extrinsics, intrinsics
def quat_to_mat(quaternions: torch.Tensor) -> torch.Tensor:
    """
    Build rotation matrices from quaternions stored scalar-last (XYZW / ijkr).

    The quaternions need not be unit length: the 2 / |q|^2 factor normalizes
    them on the fly. An all-zero quaternion yields non-finite entries.

    Args:
        quaternions: Tensor of shape (..., 4) laid out as (x, y, z, w), i.e.
            with the real part last.

    Returns:
        Rotation matrices as tensor of shape (..., 3, 3).
    """
    x, y, z, w = quaternions.unbind(dim=-1)
    scale = 2.0 / (quaternions * quaternions).sum(-1)

    # Pairwise component products shared by the matrix entries.
    xx, yy, zz = x * x, y * y, z * z
    xy, xz, yz = x * y, x * z, y * z
    xw, yw, zw = x * w, y * w, z * w

    row0 = torch.stack([1 - scale * (yy + zz), scale * (xy - zw), scale * (xz + yw)], dim=-1)
    row1 = torch.stack([scale * (xy + zw), 1 - scale * (xx + zz), scale * (yz - xw)], dim=-1)
    row2 = torch.stack([scale * (xz - yw), scale * (yz + xw), 1 - scale * (xx + yy)], dim=-1)

    # Stacking the rows along -2 yields the (..., 3, 3) layout directly.
    return torch.stack([row0, row1, row2], dim=-2)
def mat_to_quat(matrix: torch.Tensor) -> torch.Tensor:
    """
    Convert rotation matrices to quaternions in scalar-last (XYZW / ijkr) order.

    Each quaternion component can be recovered from a different combination
    of the matrix entries. The candidate built from the largest component is
    the best conditioned, so that one is selected for every matrix.

    Args:
        matrix: Rotation matrices as tensor of shape (..., 3, 3).

    Returns:
        Quaternions as tensor of shape (..., 4) with the real part last,
        standardized so that the real part is non-negative.

    Raises:
        ValueError: if the trailing dimensions of ``matrix`` are not 3 x 3.
    """
    if matrix.size(-1) != 3 or matrix.size(-2) != 3:
        raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")

    lead_shape = matrix.shape[:-2]
    m00, m01, m02, m10, m11, m12, m20, m21, m22 = matrix.reshape(lead_shape + (9,)).unbind(dim=-1)

    # For a rotation matrix these combinations equal 4*r^2, 4*i^2, 4*j^2 and
    # 4*k^2, so their (safe) square roots are 2|r|, 2|i|, 2|j|, 2|k|.
    diag_combos = torch.stack(
        [
            1.0 + m00 + m11 + m22,
            1.0 + m00 - m11 - m22,
            1.0 - m00 + m11 - m22,
            1.0 - m00 - m11 + m22,
        ],
        dim=-1,
    )
    twice_abs = _sqrt_positive_part(diag_combos)

    # Row t equals 4 * q_t * (r, i, j, k) for q_t in (r, i, j, k). Dividing it
    # by 2 * twice_abs[t] = 4|q_t| recovers the quaternion up to sign.
    rows_rijk = torch.stack(
        [
            torch.stack([twice_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
            torch.stack([m21 - m12, twice_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
            torch.stack([m02 - m20, m10 + m01, twice_abs[..., 2] ** 2, m12 + m21], dim=-1),
            torch.stack([m10 - m01, m20 + m02, m21 + m12, twice_abs[..., 3] ** 2], dim=-1),
        ],
        dim=-2,
    )

    # Floor the denominator at 0.1 so candidates from near-zero components
    # stay finite. For a valid rotation the largest component is at least 1/2,
    # so the floor never affects the candidate that is actually selected.
    floor = torch.tensor(0.1).to(dtype=twice_abs.dtype, device=twice_abs.device)
    candidates = rows_rijk / (2.0 * twice_abs[..., None].max(floor))

    # Keep, per matrix, only the row built from the largest component.
    best_row = F.one_hot(twice_abs.argmax(dim=-1), num_classes=4) > 0.5
    quat_rijk = candidates[best_row, :].reshape(lead_shape + (4,))

    # Rotate the real part from the front to the back: rijk -> ijkr.
    quat_ijkr = torch.roll(quat_rijk, shifts=-1, dims=-1)
    return standardize_quaternion(quat_ijkr)
| def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor: | |
| """ | |
| Returns torch.sqrt(torch.max(0, x)) | |
| but with a zero subgradient where x is 0. | |
| """ | |
| ret = torch.zeros_like(x) | |
| positive_mask = x > 0 | |
| if torch.is_grad_enabled(): | |
| ret[positive_mask] = torch.sqrt(x[positive_mask]) | |
| else: | |
| ret = torch.where(positive_mask, torch.sqrt(x), ret) | |
| return ret | |
def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
    """
    Pick a canonical sign for each quaternion.

    q and -q describe the same rotation, so any quaternion whose real part is
    negative is negated. Entries with a non-negative (or NaN) real part are
    returned unchanged.

    Args:
        quaternions: Quaternions with the real part last, tensor of shape (..., 4).

    Returns:
        Tensor of the same shape whose real parts are non-negative.
    """
    # The 3:4 slice keeps a trailing axis of size 1, so the mask broadcasts
    # across all four components.
    real_is_negative = quaternions[..., 3:4] < 0
    flipped = -quaternions
    return torch.where(real_is_negative, flipped, quaternions)
def cam_quat_xyzw_to_world_quat_wxyz(cam_quat_xyzw, c2w):
    """Rotate camera-frame quaternions into the world frame, returned as WXYZ.

    Args:
        cam_quat_xyzw: Tensor of shape (..., 4) holding quaternions in XYZW
            (scalar-last) order, typically (b, n, 4).
        c2w: Camera-to-world transforms of shape (..., 4, 4), typically
            (b, n, 4, 4). Only the upper-left 3x3 rotation block is used.
            Leading dims must broadcast against those of ``cam_quat_xyzw``.

    Returns:
        Tensor of shape (..., 4): the world rotation ``R_c2w @ R_cam`` as a
        quaternion in WXYZ (scalar-first) order, standardized so that w >= 0.
    """
    # quat_to_mat / mat_to_quat in this module both use scalar-last (XYZW)
    # order. The input must therefore be passed through unchanged, and only
    # the final result is reordered to WXYZ.
    rotmat_cam = quat_to_mat(cam_quat_xyzw)  # (..., 3, 3)

    # Transform to world space.
    rotmat_c2w = c2w[..., :3, :3]
    rotmat_world = torch.matmul(rotmat_c2w, rotmat_cam)

    # Matrix -> quaternion (XYZW, real part >= 0), then XYZW -> WXYZ.
    world_quat_xyzw = mat_to_quat(rotmat_world)
    world_quat_wxyz = world_quat_xyzw[..., [3, 0, 1, 2]]
    return world_quat_wxyz