"""
Rotation representation utilities for viewpoint prediction.

Implements the 6D rotation representation from:
"On the Continuity of Rotation Representations in Neural Networks"
Zhou et al., CVPR 2019

The 6D representation uses the first two columns of a rotation matrix, which
is continuous and can be converted to a valid rotation matrix via
Gram-Schmidt orthonormalization.
"""

import torch
import numpy as np


def normalize_vector(v, dim=-1, eps=1e-6):
    """Normalize vectors to unit length, guarding against near-zero norms.

    Args:
        v: Input tensor of shape (..., D).
        dim: Dimension to normalize along.
        eps: Lower bound applied to the norm before dividing, so an all-zero
            input yields a zero vector instead of NaN/inf.

    Returns:
        Tensor with the same shape as ``v``.
    """
    norm = torch.norm(v, dim=dim, keepdim=True)
    norm = torch.clamp(norm, min=eps)  # Prevent division by near-zero
    return v / norm


def rotation_6d_to_matrix(d6):
    """Convert the 6D rotation representation to a 3x3 rotation matrix.

    The first 3-vector is normalized to give the first column. The second
    3-vector has its component along the first column removed (Gram-Schmidt)
    and is normalized to give the second column. The third column is the
    cross product of the first two, so the result has determinant +1 for
    non-degenerate input.

    Args:
        d6: Tensor of shape (..., 6). Values [0:3] are the first column,
            values [3:6] the (not necessarily orthogonal) second column.

    Returns:
        Rotation matrix of shape (..., 3, 3).

    Note:
        If the first 3-vector is (near) zero, or the second one is parallel
        to it (residual norm below 1e-6), the norm clamp keeps the output
        finite but it is not a valid rotation.
    """
    a1 = d6[..., :3]
    a2 = d6[..., 3:]

    b1 = normalize_vector(a1, eps=1e-6)
    # Gram-Schmidt: remove the component of a2 along b1, then normalize.
    b2 = normalize_vector(a2 - (b1 * a2).sum(dim=-1, keepdim=True) * b1, eps=1e-6)
    # Third column completes a right-handed orthonormal basis (det = +1).
    b3 = torch.cross(b1, b2, dim=-1)

    # Stack as columns: result[..., :, k] is b_{k+1}.
    return torch.stack([b1, b2, b3], dim=-1)


def matrix_to_rotation_6d(matrix):
    """Convert a 3x3 rotation matrix to the 6D representation.

    Extracts the first two columns and concatenates them column-wise:
    [col1[0], col1[1], col1[2], col2[0], col2[1], col2[2]]. This is the
    inverse of ``rotation_6d_to_matrix`` for valid rotation matrices.

    Args:
        matrix: Rotation matrix of shape (..., 3, 3).

    Returns:
        6D representation of shape (..., 6).
    """
    col1 = matrix[..., :, 0]  # First column, shape (..., 3)
    col2 = matrix[..., :, 1]  # Second column, shape (..., 3)
    return torch.cat([col1, col2], dim=-1)


def rotation_matrix_to_axis_angle(matrix, cos_bound=0.0):
    """Return the rotation angle of a rotation matrix.

    Despite the name, only the angle part of the axis-angle representation
    is computed (no axis is returned). It uses trace(R) = 1 + 2*cos(theta).

    Args:
        matrix: Rotation matrix of shape (..., 3, 3).
        cos_bound: Margin for clamping cos(theta) to
            [-1 + cos_bound, 1 - cos_bound]. The default 0.0 gives exact
            angles, which is appropriate for evaluation. A small positive value
            keeps the derivative of acos finite, which is required when the
            result is back-propagated (see ``geodesic_loss``).

    Returns:
        Angle in radians, in [0, pi], shape (...).

    Raises:
        ValueError: If ``cos_bound`` is not in [0, 1).
    """
    if not 0.0 <= cos_bound < 1.0:
        raise ValueError(f"cos_bound must be in [0, 1), got {cos_bound}")
    trace = matrix[..., 0, 0] + matrix[..., 1, 1] + matrix[..., 2, 2]
    cos_angle = (trace - 1.0) / 2.0
    # Rounding can push |cos| slightly past 1, where acos would return NaN.
    cos_angle = torch.clamp(cos_angle, -1.0 + cos_bound, 1.0 - cos_bound)
    return torch.acos(cos_angle)


def geodesic_loss(pred_matrix, gt_matrix, cos_bound=1e-7):
    """Compute the mean squared geodesic (angular) distance on SO(3).

    For each pair, theta = arccos((trace(R_pred^T @ R_gt) - 1) / 2) is the
    angle of the relative rotation, and the loss is mean(theta ** 2).
    Equivalently, since ||log(R_pred^T @ R_gt)||_F = sqrt(2) * theta, this is
    half the mean squared Frobenius norm of the log map.

    Args:
        pred_matrix: Predicted rotation matrices, shape (..., 3, 3),
            e.g. (B, 3, 3).
        gt_matrix: Ground-truth rotation matrices, shape broadcastable
            against ``pred_matrix``.
        cos_bound: Cosine clamp margin forwarded to
            ``rotation_matrix_to_axis_angle``. It must be positive to avoid NaN
            gradients. Without it, an exact match (cos == 1) back-propagates
            0 * inf through acos. The default 1e-7 is still representable
            below 1.0 in float32, and it floors theta at roughly 5e-4 rad.
            NOTE(review): under float16/bfloat16 inputs this margin rounds
            away; cast to float32 first if using mixed precision.

    Returns:
        Scalar tensor: mean squared angle in radians^2.
    """
    relative_rotation = torch.matmul(pred_matrix.transpose(-2, -1), gt_matrix)
    angle = rotation_matrix_to_axis_angle(relative_rotation, cos_bound=cos_bound)
    return (angle ** 2).mean()


def rotation_angle_error(pred_matrix, gt_matrix):
    """Compute the rotation error in degrees for evaluation.

    Args:
        pred_matrix: Predicted rotation matrices, shape (..., 3, 3),
            e.g. (B, 3, 3).
        gt_matrix: Ground-truth rotation matrices, shape broadcastable
            against ``pred_matrix``.

    Returns:
        Angular error in degrees, shape (...), e.g. (B,).
    """
    relative_rotation = torch.matmul(pred_matrix.transpose(-2, -1), gt_matrix)
    # Exact angles (no cos margin): this is a metric, not a training loss.
    angle_rad = rotation_matrix_to_axis_angle(relative_rotation)
    return angle_rad * 180.0 / np.pi


def batch_rotation_6d_to_matrix(d6_batch):
    """Batched alias of ``rotation_6d_to_matrix``.

    Args:
        d6_batch: 6D rotation representation, shape (B, 6).

    Returns:
        Rotation matrices of shape (B, 3, 3).
    """
    return rotation_6d_to_matrix(d6_batch)


def test_rotation_conversion():
    """Self-check for the conversion and loss utilities; prints a summary.

    Checks the following:
    - matrix -> 6D -> matrix round-trips on random proper rotations;
    - arbitrary 6D vectors map to orthonormal matrices with det +1;
    - ``geodesic_loss`` has finite gradients at an exact match.

    Pass/fail is judged on the worst sample, not the batch mean.
    """
    print("Testing rotation conversion...")

    batch_size = 5
    tol = 1e-5
    eye = torch.eye(3).unsqueeze(0).expand(batch_size, 3, 3)

    def ortho_error(rot):
        # Per-matrix Frobenius norm of R^T R - I (0 for orthonormal R).
        return torch.norm(torch.bmm(rot.transpose(-2, -1), rot) - eye, dim=(1, 2))

    # Test 1: valid rotation matrices from QR decomposition.
    Q, _ = torch.linalg.qr(torch.randn(batch_size, 3, 3))
    # Negating a 3x3 matrix flips the sign of its determinant, so this turns
    # reflections (det = -1) into proper rotations (det = +1).
    R_gt = Q * torch.det(Q).unsqueeze(-1).unsqueeze(-1)

    det_gt = torch.det(R_gt)
    ortho_error_gt = ortho_error(R_gt)
    print(f"Ground truth check - Det: {det_gt.mean():.6f}, Ortho error: {ortho_error_gt.mean():.6f}")

    # Convert to 6D and back.
    R_reconstructed = rotation_6d_to_matrix(matrix_to_rotation_6d(R_gt))

    recon_error = torch.norm(R_gt - R_reconstructed, dim=(1, 2))
    print(f"Reconstruction error: {recon_error.mean().item():.6f} (should be ~0)")

    ortho_error_recon = ortho_error(R_reconstructed)
    print(f"Orthogonality error: {ortho_error_recon.mean().item():.6f} (should be ~0)")

    det_recon = torch.det(R_reconstructed)
    print(f"Determinant: {det_recon.mean().item():.6f} (should be ~1)")

    # Test 2: arbitrary 6D vectors must map to valid rotations.
    R_from_6d = rotation_6d_to_matrix(torch.randn(batch_size, 6))
    det_from_6d = torch.det(R_from_6d)
    ortho_error_from_6d = ortho_error(R_from_6d)
    print(f"\n6D->matrix check - Det: {det_from_6d.mean():.6f}, Ortho error: {ortho_error_from_6d.mean():.6f}")

    # Test 3: an exact match (trace == 3) must not produce NaN gradients.
    R_pred = torch.eye(3).unsqueeze(0).repeat(batch_size, 1, 1).requires_grad_(True)
    geodesic_loss(R_pred, eye).backward()
    grad_finite = bool(torch.isfinite(R_pred.grad).all())
    print(f"Geodesic loss gradient finite at exact match: {grad_finite}")

    # Overall result, judged on the worst sample in each batch.
    test_passed = (
        recon_error.max() < tol
        and ortho_error_recon.max() < tol
        and (det_recon - 1.0).abs().max() < tol
        and ortho_error_from_6d.max() < tol
        and (det_from_6d - 1.0).abs().max() < tol
        and grad_finite
    )
    print("\nTests passed!" if test_passed else "Tests FAILED!")


if __name__ == "__main__":
    test_rotation_conversion()