# viewtoken-harmon-demo: src/models/rotation_utils.py
# Last commit: becf13a by XinxuanLu — "Initial demo" (verified)
"""
Rotation representation utilities for viewpoint prediction.
Implements 6D rotation representation from:
"On the Continuity of Rotation Representations in Neural Networks"
Zhou et al., CVPR 2019
The 6D representation uses the first two columns of a rotation matrix,
which is continuous and can be converted to a valid rotation matrix
via Gram-Schmidt orthonormalization.
"""
import torch
import numpy as np
def normalize_vector(v, dim=-1, eps=1e-6):
    """Scale ``v`` to unit length along ``dim``.

    The norm is floored at ``eps`` before dividing, so (near-)zero vectors
    never cause a division by zero; such vectors come back scaled by
    ``1 / eps`` instead of being blown up to inf/NaN.

    Args:
        v: Tensor of shape (..., D).
        dim: Axis along which the Euclidean norm is taken.
        eps: Lower bound applied to the norm (1e-6 is comfortably above
            float32 round-off).

    Returns:
        Tensor with the same shape as ``v``.
    """
    safe_length = v.norm(dim=dim, keepdim=True).clamp(min=eps)
    return v / safe_length
def rotation_6d_to_matrix(d6):
    """Map a 6D rotation encoding to a 3x3 rotation matrix.

    The first three entries are the (unnormalized) first column and the last
    three the (unnormalized, not necessarily orthogonal) second column. The
    columns are orthonormalized with Gram-Schmidt and the third column is
    their cross product, giving a right-handed frame.

    Both normalizations floor the vector length at 1e-6. Note that if the
    second column is (near-)parallel to the first, its orthogonal residual is
    ~0, so the second and third output columns collapse toward zero and the
    result is not a valid rotation.

    Args:
        d6: Tensor of shape (..., 6).

    Returns:
        Tensor of shape (..., 3, 3) whose columns are the orthonormal axes.
    """
    x_raw = d6[..., :3]
    y_raw = d6[..., 3:]

    x_axis = normalize_vector(x_raw, eps=1e-6)

    # Remove the component of y_raw that lies along x_axis (Gram-Schmidt).
    along_x = (x_axis * y_raw).sum(dim=-1, keepdim=True)
    y_perp = y_raw - along_x * x_axis
    y_axis = normalize_vector(y_perp, dim=-1, eps=1e-6)

    z_axis = torch.cross(x_axis, y_axis, dim=-1)

    # Columns of the output matrix are x, y, z.
    return torch.stack((x_axis, y_axis, z_axis), dim=-1)
def matrix_to_rotation_6d(matrix):
    """Encode a rotation matrix as its first two columns (6D representation).

    The output is laid out column-major: the three entries of column 0
    followed by the three entries of column 1.

    Args:
        matrix: Tensor of shape (..., 3, 3).

    Returns:
        Tensor of shape (..., 6).
    """
    # (..., 3, 2) -> (..., 2, 3) so each row is one column of the input,
    # then merge the last two axes to lay the columns out one after another.
    first_two_columns = matrix[..., :2].transpose(-2, -1)
    return first_two_columns.flatten(start_dim=-2)
def rotation_matrix_to_axis_angle(matrix, *, eps=0.0):
    """Return the rotation angle (radians) of a rotation matrix.

    Despite the name, only the angle theta of the axis-angle pair is
    computed; the rotation axis is not returned. Uses the identity
    trace(R) = 1 + 2 * cos(theta).

    Args:
        matrix: Rotation matrix of shape (..., 3, 3).
        eps: Margin keeping cos(theta) inside [-1 + eps, 1 - eps] before
            ``acos``. The default 0.0 is plain clamping to [-1, 1]. The
            derivative of ``acos`` is infinite at +-1, so pass a small
            positive value (e.g. 1e-7 for float32) when gradients must stay
            finite at theta == 0 or theta == pi. This slightly biases angles
            very close to those endpoints.

    Returns:
        Angle in radians in [0, pi], shape (...).
    """
    trace = matrix[..., 0, 0] + matrix[..., 1, 1] + matrix[..., 2, 2]
    cos_angle = (trace - 1.0) / 2.0
    # Round-off can push cos slightly outside [-1, 1], where acos is NaN.
    cos_angle = torch.clamp(cos_angle, -1.0 + eps, 1.0 - eps)
    return torch.acos(cos_angle)
def geodesic_loss(pred_matrix, gt_matrix, *, eps=1e-7):
    """Mean squared geodesic (angular) distance between rotations on SO(3).

    For each pair the relative rotation R_rel = R_pred^T @ R_gt has angle
        theta = arccos((trace(R_rel) - 1) / 2),
    which is the geodesic distance on SO(3). Equivalently,
    ||log(R_rel)||_F = sqrt(2) * theta. The returned loss is mean(theta ** 2).

    cos(theta) is clamped to [-1 + eps, 1 - eps] rather than [-1, 1]. The
    derivative of arccos is infinite at +-1, so with plain clamping a
    prediction that exactly matches its target (theta == 0) yields
    0 * inf = NaN gradients. The margin adds a negligible floor to the loss:
    about 2.4e-7 rad^2 for float32 inputs at eps=1e-7.
    NOTE(review): in float16/bfloat16, 1 - 1e-7 rounds to 1.0, so the
    margin has no effect there; compute in float32 if using reduced precision.

    Args:
        pred_matrix: Predicted rotation matrices, shape (B, 3, 3). Any
            broadcast-compatible (..., 3, 3) shape also works.
        gt_matrix: Ground-truth rotation matrices, same shape convention.
        eps: Keyword-only clamp margin for cos(theta) (see above).

    Returns:
        Scalar tensor: mean squared angle in radians^2.
    """
    # matmul == bmm for (B, 3, 3) inputs, and also accepts unbatched or
    # multi-batch-dim inputs.
    relative_rotation = torch.matmul(pred_matrix.transpose(-2, -1), gt_matrix)
    trace = (
        relative_rotation[..., 0, 0]
        + relative_rotation[..., 1, 1]
        + relative_rotation[..., 2, 2]
    )
    cos_angle = torch.clamp((trace - 1.0) / 2.0, -1.0 + eps, 1.0 - eps)
    angle = torch.acos(cos_angle)
    return (angle ** 2).mean()
def rotation_angle_error(pred_matrix, gt_matrix):
    """Per-sample angular error between two batches of rotations, in degrees.

    Intended for evaluation. The error is the angle of the relative
    rotation R_pred^T @ R_gt.

    Args:
        pred_matrix: Predicted rotation matrices, shape (B, 3, 3).
        gt_matrix: Ground-truth rotation matrices, shape (B, 3, 3).

    Returns:
        Tensor of shape (B,) with errors in degrees.
    """
    relative = torch.bmm(pred_matrix.transpose(-2, -1), gt_matrix)
    radians = rotation_matrix_to_axis_angle(relative)
    return radians * 180.0 / np.pi
def batch_rotation_6d_to_matrix(d6_batch):
    """Convert a batch of 6D rotation encodings to rotation matrices.

    ``rotation_6d_to_matrix`` already handles arbitrary leading dimensions,
    so this is a thin named alias for batched callers.

    Args:
        d6_batch: Tensor of shape (B, 6).

    Returns:
        Tensor of shape (B, 3, 3).
    """
    matrices = rotation_6d_to_matrix(d6_batch)
    return matrices
def test_rotation_conversion():
    """Self-check for the 6D <-> rotation-matrix conversions.

    1. Builds random proper rotations (QR decomposition + sign fix),
       round-trips them through the 6D representation, and checks the
       reconstruction error, orthogonality, and determinant.
    2. Maps random 6D vectors to matrices and checks they are valid rotations.

    Diagnostics are printed as batch means. Pass/fail is judged on the worst
    sample, because a batch mean can hide a single bad reconstruction.

    Returns:
        bool: True if every check is within tolerance.
    """
    print("Testing rotation conversion...")
    tol = 1e-5
    batch_size = 5

    # Test 1: valid rotation matrices from a QR decomposition.
    random_matrices = torch.randn(batch_size, 3, 3)
    Q, _ = torch.linalg.qr(random_matrices)  # Q is orthogonal, det(Q) = +-1
    # For 3x3 matrices det(-Q) = -det(Q), so scaling by det(Q) turns any
    # reflection (det = -1) into a proper rotation (det = +1).
    q_det = torch.det(Q)
    R_gt = Q * q_det.unsqueeze(-1).unsqueeze(-1)

    # Sanity-check the ground truth itself.
    eye = torch.eye(3).unsqueeze(0).expand(batch_size, 3, 3)
    det_gt = torch.det(R_gt)
    ortho_error_gt = torch.norm(torch.bmm(R_gt.transpose(-2, -1), R_gt) - eye, dim=(1, 2))
    print(f"Ground truth check - Det: {det_gt.mean():.6f}, Ortho error: {ortho_error_gt.mean():.6f}")

    # Round-trip: matrix -> 6D -> matrix.
    d6 = matrix_to_rotation_6d(R_gt)
    R_reconstructed = rotation_6d_to_matrix(d6)

    error = torch.norm(R_gt - R_reconstructed, dim=(1, 2))
    print(f"Reconstruction error: {error.mean().item():.6f} (should be ~0)")

    ortho_error = torch.norm(torch.bmm(R_reconstructed.transpose(-2, -1), R_reconstructed) - eye, dim=(1, 2))
    print(f"Orthogonality error: {ortho_error.mean().item():.6f} (should be ~0)")

    det_rec = torch.det(R_reconstructed)
    print(f"Determinant: {det_rec.mean().item():.6f} (should be ~1)")

    # Test 2: arbitrary 6D vectors must map to valid rotations.
    random_6d = torch.randn(batch_size, 6)
    R_from_6d = rotation_6d_to_matrix(random_6d)
    det_from_6d = torch.det(R_from_6d)
    ortho_error_from_6d = torch.norm(torch.bmm(R_from_6d.transpose(-2, -1), R_from_6d) - eye, dim=(1, 2))
    print(f"\n6D->matrix check - Det: {det_from_6d.mean():.6f}, Ortho error: {ortho_error_from_6d.mean():.6f}")

    # Judge on the worst sample rather than the batch mean.
    test_passed = bool(
        error.max() < tol
        and ortho_error.max() < tol
        and (det_rec - 1.0).abs().max() < tol
        and ortho_error_from_6d.max() < tol
        and (det_from_6d - 1.0).abs().max() < tol
    )
    print("\nTests passed!" if test_passed else "Tests FAILED!")
    return test_passed
if __name__ == "__main__":
    # Run the conversion self-check when this module is executed directly.
    test_rotation_conversion()