# physctrl/src/utils/interpolate.py
# (source page metadata: author chenwang, commit 4724018 "update")
import torch
import torch
import torch.nn.functional as F
def get_rigid_transform(A, B):
    """Best-fit rigid motion (rotation R, translation t) taking points A onto B.

    Solves the least-squares (Kabsch) problem:
    1. Centre both point sets.
    2. Factorise their 3x3 cross-covariance with an SVD.
    3. Assemble the rotation from the singular vectors.

    If that product would be a reflection (negative determinant), the last
    right-singular vector is negated, so R is always a proper rotation.

    Works on a single set ([N, 3]) or on batches ([batch, N, 3], or any
    leading dimensions). Row i of A is paired with row i of B.

    Args:
        A: source points, shape [..., N, 3].
        B: target points, same shape as A.

    Returns:
        R: [..., 3, 3] rotation matrix.
        t: [..., 3, 1] translation, so that target = R @ source + t for
           column vectors of shape [3, 1].

    Shape requirements are checked with ``assert`` (AssertionError).
    """
    assert A.shape == B.shape, "Input matrices must have the same shape"
    assert A.shape[-1] == 3, "Input matrices must have 3 columns (x, y, z coordinates)"

    # Centroids keep the point axis so they broadcast: [..., 1, 3].
    mu_a = A.mean(dim=-2, keepdim=True)
    mu_b = B.mean(dim=-2, keepdim=True)

    # Cross-covariance of the centred clouds: [..., 3, 3].
    cov = (A - mu_a).transpose(-2, -1) @ (B - mu_b)
    U, _, Vh = torch.linalg.svd(cov)
    Ut = U.transpose(-2, -1)

    # det < 0 means V @ U^T is a reflection. Scaling the last row of Vh by -1
    # (and the other rows by 1) yields the closest proper rotation.
    det = torch.det(Vh.transpose(-2, -1) @ Ut)
    sign = 1.0 - 2.0 * (det < 0)
    ones = torch.ones_like(sign)
    row_scale = torch.stack((ones, ones, sign), dim=-1).unsqueeze(-1)  # [..., 3, 1]
    Vh = Vh * row_scale

    R = Vh.transpose(-2, -1) @ Ut
    # t = mu_B - R @ mu_A, expressed with column vectors: [..., 3, 1].
    t = mu_b.transpose(-2, -1) - R @ mu_a.transpose(-2, -1)
    return R, t
def _test_rigid_transform():
    """Smoke-test get_rigid_transform on a known rotation plus offset.

    Four points are rotated by diag(1, -1, -1) (180 degrees about the x-axis)
    and shifted by +2 on every axis. The recovered R and t are printed for
    inspection.
    """
    source = torch.tensor([[1, 2, 3], [4, 5, 6], [9, 8, 10], [10, -5, 1]]) * 1.0
    rotation_gt = torch.tensor([[1, 0, 0], [0, -1, 0], [0, 0, -1]]) * 1.0
    target = (rotation_gt @ source.T).T + 2.0  # known offset of 2 on each axis

    R, t = get_rigid_transform(source.unsqueeze(0), target.unsqueeze(0))
    for title, value in (("Rotation matrix R:", R), ("\nTranslation vector t:", t)):
        print(title)
        print(value)
def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
"""
Returns torch.sqrt(torch.max(0, x))
but with a zero subgradient where x is 0.
"""
ret = torch.zeros_like(x)
positive_mask = x > 0
ret[positive_mask] = torch.sqrt(x[positive_mask])
return ret
def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
    """Convert rotation matrices to quaternions (real part first).

    Adapted from pytorch3d. It uses the trace-based construction, cf.
    https://github.com/KieranWynn/pyquaternion/blob/master/pyquaternion/quaternion.py#L205

    Four candidate quaternions are formed. Each one is obtained by dividing
    by one of 2|w|, 2|x|, 2|y|, 2|z|. The candidate with the largest divisor
    is returned because it is the best conditioned.

    Args:
        matrix: rotation matrices of shape (..., 3, 3).

    Returns:
        Quaternions of shape (..., 4), ordered (w, x, y, z). The sign is not
        standardised, so the real part may be negative.

    Raises:
        ValueError: if the trailing two dimensions are not 3 x 3.
    """
    if matrix.size(-1) != 3 or matrix.size(-2) != 3:
        raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")

    lead = matrix.shape[:-2]
    flat = matrix.reshape(lead + (9,))
    m00, m01, m02, m10, m11, m12, m20, m21, m22 = flat.unbind(dim=-1)

    # 2|w|, 2|x|, 2|y|, 2|z| come from signed combinations of the diagonal.
    diag_terms = torch.stack(
        (
            1.0 + m00 + m11 + m22,
            1.0 + m00 - m11 - m22,
            1.0 - m00 + m11 - m22,
            1.0 - m00 - m11 + m22,
        ),
        dim=-1,
    )
    q_abs = _sqrt_positive_part(diag_terms)

    # Row k is the desired quaternion multiplied by its k-th component
    # (w, x, y or z), up to a constant factor.
    scaled_rows = (
        (q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01),
        (m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20),
        (m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21),
        (m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2),
    )
    quat_by_rijk = torch.stack(
        [torch.stack(row, dim=-1) for row in scaled_rows], dim=-2
    )  # (..., 4, 4)

    # Floor the divisor so a tiny |q_k| cannot blow up. Such rows are never
    # selected below, so the exact floor value is unimportant.
    floor = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
    candidates = quat_by_rijk / (2.0 * torch.maximum(q_abs[..., None], floor))

    # Per matrix, keep the candidate whose divisor |q_k| is largest.
    best = q_abs.argmax(dim=-1)  # (...)
    index = best[..., None, None].expand(*best.shape, 1, 4)
    return candidates.gather(-2, index).squeeze(-2)
def quternion_to_matrix(r):
    """Convert quaternions (real part first) to rotation matrices.

    The input is normalised first, so any non-zero quaternion is accepted.
    A zero quaternion produces NaNs (division by a zero norm). The function
    name keeps its historical spelling so existing callers keep working.

    Args:
        r: quaternions of shape (..., 4), ordered (w, x, y, z). The common
           case is (N, 4).

    Returns:
        Rotation matrices of shape (..., 3, 3), on the same device and with
        the same dtype as ``r``.
    """
    # Normalise so that non-unit inputs still yield orthonormal matrices.
    q = r / torch.sqrt((r * r).sum(dim=-1, keepdim=True))
    w, x, y, z = q.unbind(dim=-1)

    # Assemble the matrix out-of-place from its nine entries. This keeps the
    # result on r's device and dtype. The previous version wrote into a
    # hard-coded CUDA float32 buffer, which failed on CPU-only machines.
    entries = (
        1 - 2 * (y * y + z * z), 2 * (x * y - w * z), 2 * (x * z + w * y),
        2 * (x * y + w * z), 1 - 2 * (x * x + z * z), 2 * (y * z - w * x),
        2 * (x * z - w * y), 2 * (y * z + w * x), 1 - 2 * (x * x + y * y),
    )
    return torch.stack(entries, dim=-1).reshape(q.shape[:-1] + (3, 3))
def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
    """Map each quaternion to the representative with a non-negative real part.

    Adapted from PyTorch3D. q and -q encode the same rotation; this picks the
    one whose real (first) component is >= 0.

    Args:
        quaternions: tensor of shape (..., 4), real part first.

    Returns:
        Tensor of the same shape. Every quaternion with a negative real part
        is negated; all others are returned unchanged.
    """
    has_negative_real = quaternions[..., :1] < 0
    return torch.where(has_negative_real, quaternions.neg(), quaternions)
def quaternion_multiply(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Hamilton product a * b of quaternions, returned in standardized form.

    Adapted from PyTorch3D. Standard torch broadcasting applies between a
    and b.

    Args:
        a: quaternions of shape (..., 4), real part first.
        b: quaternions of shape (..., 4), real part first.

    Returns:
        The product, shape (broadcast leading dims, 4). It is passed through
        ``standardize_quaternion``, so its real part is non-negative.
    """
    a_w, a_x, a_y, a_z = a.unbind(-1)
    b_w, b_x, b_y, b_z = b.unbind(-1)
    product = torch.stack(
        (
            a_w * b_w - a_x * b_x - a_y * b_y - a_z * b_z,
            a_w * b_x + a_x * b_w + a_y * b_z - a_z * b_y,
            a_w * b_y - a_x * b_z + a_y * b_w + a_z * b_x,
            a_w * b_z + a_x * b_y - a_y * b_x + a_z * b_w,
        ),
        dim=-1,
    )
    return standardize_quaternion(product)
def _test_matrix_to_quaternion():
    """Round-trip check: quaternion -> matrix -> quaternion.

    Samples 10 random unit quaternions with a non-negative real part and
    converts them to rotation matrices. It then prints:
    - the maximum deviation of R @ R^T from the identity, and
    - the maximum difference between the recovered (normalised,
      standardized) quaternions and the originals.

    Runs on CUDA when available and falls back to CPU otherwise, instead of
    requiring a GPU.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Sample on CPU then move, matching the previous random stream on CUDA.
    r = torch.randn((10, 4)).to(device)
    q = standardize_quaternion(r / torch.linalg.norm(r, dim=-1, keepdim=True))

    R = quternion_to_matrix(q)
    # Compare on R's own device, wherever quternion_to_matrix allocated it.
    eye = torch.eye(3, device=R.device)
    I_rec_error = (R @ R.transpose(-2, -1) - eye).abs().max()

    q_recovered = matrix_to_quaternion(R)
    q_recovered = q_recovered / torch.linalg.norm(q_recovered, dim=-1, keepdim=True)
    q_recovered = standardize_quaternion(q_recovered)
    print(q_recovered.shape, q.shape, R.shape)
    rec = (q.to(R.device) - q_recovered).abs().max()
    print("rotation to I error:", I_rec_error, "quant rec error: ", rec)
def _test_matrix_to_quaternion_2():
    """Round-trip two fixed rotations through quaternions and back.

    The transposed matrices are converted with matrix_to_quaternion and then
    back with quternion_to_matrix. The reconstructed matrices are printed
    for visual inspection.
    """
    rotations = torch.tensor(
        [
            [[1, 0, 0], [0, -1, 0], [0, 0, -1]],  # 180 degrees about x
            [[1, 0, 0], [0, 0, 1], [0, -1, 0]],  # -90 degrees about x
        ]
    ) * 1.0
    quats = matrix_to_quaternion(rotations.transpose(-2, -1))
    print(quternion_to_matrix(quats))
def interpolate_points_w_R(
    query_points, query_rotation, drive_origin_pts, drive_displacement, top_k_index
):
    """Move query points by their neighbours' mean displacement and rotate them.

    For each query point, the top-k driving points selected by
    ``top_k_index`` define a local rigid fit between their rest positions
    and their displaced positions. The fitted rotation is composed onto
    ``query_rotation``. The query position itself is translated by the plain
    average of the neighbours' displacements; the fitted translation is not
    used.

    Args:
        query_points: [n, 3] positions to move.
        query_rotation: quaternions with last dimension 4 (real part first),
            presumably [n, 4]; confirm against callers.
        drive_origin_pts: [m, 3] rest positions of the driving points.
        drive_displacement: [m, 3] displacement of each driving point.
        top_k_index: [n, top_k] indices (< m) into the driving points.

    Returns:
        (ret_points, new_rotation):
        - ret_points: [n, 3] moved positions.
        - new_rotation: quaternions R_fit * query_rotation, standardized to
          a non-negative real part.

    Note carried over from the original docstring: rendering code may instead
    call ``apply_discrete_offset_filds_with_R(self, origin_points, offsets,
    topk=6)``, with origin_points and offsets of shape (N_r, 3). That method
    is not defined in this function.
    """
    neighbour_rest = drive_origin_pts[top_k_index]  # [n, top_k, 3]
    neighbour_disp = drive_displacement[top_k_index]  # [n, top_k, 3]

    rotation, _ = get_rigid_transform(neighbour_rest, neighbour_rest + neighbour_disp)
    moved_points = query_points + neighbour_disp.mean(dim=1)
    new_rotation = quaternion_multiply(matrix_to_quaternion(rotation), query_rotation)
    return moved_points, new_rotation
def interpolate_points(
    query_points, query_rotation, drive_origin_pts, drive_current_points, top_k_index
):
    """Move and rotate query points from the motion of their k nearest drivers.

    Same as ``interpolate_points_w_R``, but the driving points are given by
    their current positions rather than by displacements.
    - Position: each query point is shifted by the mean displacement
      (current - rest) of its top-k driving points.
    - Orientation: the rotation of a rigid fit from the rest to the current
      neighbour positions is composed onto ``query_rotation``.

    Args:
        query_points: [n, 3] positions to move.
        query_rotation: quaternions with last dimension 4 (real part first),
            presumably [n, 4]; confirm against callers.
        drive_origin_pts: [m, 3] rest positions of the driving points.
        drive_current_points: [m, 3] current positions of the driving points.
        top_k_index: [n, top_k] indices (< m) into the driving points.

    Returns:
        (ret_points, new_rotation):
        - ret_points: [n, 3] moved positions.
        - new_rotation: quaternions R_fit * query_rotation, standardized to
          a non-negative real part.
    """
    rest = drive_origin_pts[top_k_index]  # [n, top_k, 3]
    current = drive_current_points[top_k_index]  # [n, top_k, 3]
    moved_points = query_points + (current - rest).mean(dim=1)  # [n, 3]

    rotation, _ = get_rigid_transform(rest, current)
    new_rotation = quaternion_multiply(matrix_to_quaternion(rotation), query_rotation)
    return moved_points, new_rotation