Spaces:
Running on Zero
Running on Zero
| """ | |
| Rotation representation utilities for viewpoint prediction. | |
| Implements 6D rotation representation from: | |
| "On the Continuity of Rotation Representations in Neural Networks" | |
| Zhou et al., CVPR 2019 | |
| The 6D representation uses the first two columns of a rotation matrix, | |
| which is continuous and can be converted to a valid rotation matrix | |
| via Gram-Schmidt orthonormalization. | |
| """ | |
| import torch | |
| import numpy as np | |
def normalize_vector(v, dim=-1, eps=1e-6):
    """Scale ``v`` to unit L2 length along ``dim``, guarding against tiny norms.

    The norm is clamped from below to ``eps`` before dividing, so a vector
    whose length is smaller than ``eps`` is divided by ``eps`` instead of its
    true length (a zero vector therefore stays zero rather than becoming NaN).

    Args:
        v: Input tensor of shape (..., D).
        dim: Dimension along which the norm is taken.
        eps: Lower bound applied to the norm before the division.

    Returns:
        Tensor with the same shape as ``v``.
    """
    # keepdim=True keeps the reduced axis so the division broadcasts over it.
    length = v.norm(dim=dim, keepdim=True).clamp(min=eps)
    return v / length
def rotation_6d_to_matrix(d6):
    """Build a 3x3 matrix from a 6D rotation representation via Gram-Schmidt.

    The first three values are treated as the first column and the remaining
    values as a second column that need not be orthogonal to it. The first
    column is normalized, the second is made orthogonal to it and normalized,
    and the third is their cross product.

    Args:
        d6: 6D rotation representation, shape (..., 6).

    Returns:
        Tensor of shape (..., 3, 3) whose columns are the three vectors
        produced above.
    """
    raw_first = d6[..., 0:3]
    raw_second = d6[..., 3:]

    # Column 1: unit-length version of the first raw vector.
    b1 = normalize_vector(raw_first, eps=1e-6)

    # Column 2: remove the component of the second raw vector along b1,
    # then normalize with the same norm clamp as column 1.
    along_b1 = torch.sum(b1 * raw_second, dim=-1, keepdim=True)
    b2 = normalize_vector(raw_second - along_b1 * b1, eps=1e-6)

    # Column 3: cross product of the first two columns.
    b3 = torch.cross(b1, b2, dim=-1)

    # Stacking on the last axis places b1, b2, b3 as matrix columns.
    return torch.stack((b1, b2, b3), dim=-1)
def matrix_to_rotation_6d(matrix):
    """Flatten the first two columns of a 3x3 matrix into a 6D vector.

    The output is laid out column by column:
    [col0[0], col0[1], col0[2], col1[0], col1[1], col1[2]].

    Args:
        matrix: Rotation matrix of shape (..., 3, 3).

    Returns:
        6D representation of shape (..., 6).
    """
    # Each slice has shape (..., 3); concatenating on the last axis gives (..., 6).
    leading_columns = [matrix[..., :, k] for k in (0, 1)]
    return torch.cat(leading_columns, dim=-1)
def rotation_matrix_to_axis_angle(matrix):
    """Return the rotation angle encoded by a rotation matrix.

    Despite the name, only the angle is computed; the rotation axis is not
    returned. The angle follows from trace(R) = 1 + 2 * cos(theta).

    Args:
        matrix: Rotation matrix of shape (..., 3, 3).

    Returns:
        Angle in radians, shape (...).
    """
    diagonal = [matrix[..., k, k] for k in range(3)]
    trace = diagonal[0] + diagonal[1] + diagonal[2]

    # Clamp so rounding error cannot push the cosine outside acos' domain.
    cosine = ((trace - 1.0) / 2.0).clamp(-1.0, 1.0)
    return cosine.acos()
def geodesic_loss(pred_matrix, gt_matrix, eps=1e-6):
    """Compute the mean squared geodesic (angular) distance on SO(3).

    For each pair, the rotation angle of the relative rotation
    R_pred^T @ R_gt is obtained from the trace formula

        angle = arccos((trace(R_pred^T @ R_gt) - 1) / 2)

    and the loss is mean(angle ** 2).

    The trace is evaluated as the elementwise sum of pred * gt, using
    trace(A^T B) = sum_ij A_ij * B_ij, so no matrix product is needed and
    any leading batch shape is accepted.

    The cosine is clamped to [-1 + eps, 1 - eps] rather than [-1, 1]. The
    derivative of arccos is infinite at +/-1, so with an exact clamp a
    prediction equal to the target yields 0 * inf = NaN gradients. The
    margin keeps gradients finite; the cost is that the loss for identical
    rotations is a small positive value (about 2 * eps) instead of exactly 0.
    Pass eps=0.0 to clamp to exactly [-1, 1].

    Args:
        pred_matrix: Predicted rotation matrices, shape (..., 3, 3).
        gt_matrix: Ground-truth rotation matrices, shape (..., 3, 3).
        eps: Margin kept between the cosine and +/-1 before arccos.

    Returns:
        Mean squared rotation angle in radians^2 (scalar tensor).
    """
    # trace(R_pred^T @ R_gt) without forming the product.
    trace = (pred_matrix * gt_matrix).sum(dim=(-2, -1))

    cos_angle = (trace - 1.0) / 2.0
    # Stay strictly inside (-1, 1) so acos has a finite derivative.
    cos_angle = torch.clamp(cos_angle, -1.0 + eps, 1.0 - eps)
    angle = torch.acos(cos_angle)

    return (angle ** 2).mean()
def rotation_angle_error(pred_matrix, gt_matrix):
    """Per-sample rotation error between prediction and target, in degrees.

    Args:
        pred_matrix: Predicted rotation matrix, shape (B, 3, 3).
        gt_matrix: Ground truth rotation matrix, shape (B, 3, 3).

    Returns:
        Rotation error in degrees, shape (B,).
    """
    # Relative rotation R_pred^T @ R_gt; its angle is the error.
    delta = pred_matrix.transpose(-2, -1).bmm(gt_matrix)
    radians = rotation_matrix_to_axis_angle(delta)
    return radians * 180.0 / np.pi
def batch_rotation_6d_to_matrix(d6_batch):
    """Convert a batch of 6D rotation representations to rotation matrices.

    Thin alias that forwards to ``rotation_6d_to_matrix``.

    Args:
        d6_batch: 6D rotation representation, shape (B, 6).

    Returns:
        Rotation matrices of shape (B, 3, 3).
    """
    matrices = rotation_6d_to_matrix(d6_batch)
    return matrices
def test_rotation_conversion():
    """Self-check for the 6D <-> matrix conversions; prints a report.

    Two checks are run on random inputs:
      1. Random rotation matrices survive a matrix -> 6D -> matrix round trip.
      2. Random 6D vectors map to orthogonal matrices with determinant +1.

    The outcome is printed only; nothing is returned or raised on failure.
    """
    print("Testing rotation conversion...")

    n = 5
    identity = torch.eye(3).unsqueeze(0).expand(n, 3, 3)

    def ortho_residual(mats):
        # Frobenius norm of (M^T M - I) for each matrix in the batch.
        gram = torch.bmm(mats.transpose(-2, -1), mats)
        return torch.norm(gram - identity, dim=(1, 2))

    # --- Test 1: round trip starting from valid rotation matrices ---------
    # QR of a random matrix yields an orthogonal Q; scaling a 3x3 Q by its
    # determinant (+1 or -1) turns a reflection into a proper rotation.
    Q, _ = torch.linalg.qr(torch.randn(n, 3, 3))
    R_gt = Q * torch.det(Q).unsqueeze(-1).unsqueeze(-1)

    det_gt = torch.det(R_gt)
    ortho_error_gt = ortho_residual(R_gt)
    print(f"Ground truth check - Det: {det_gt.mean():.6f}, Ortho error: {ortho_error_gt.mean():.6f}")

    R_roundtrip = rotation_6d_to_matrix(matrix_to_rotation_6d(R_gt))

    error = torch.norm(R_gt - R_roundtrip, dim=(1, 2))
    print(f"Reconstruction error: {error.mean().item():.6f} (should be ~0)")

    ortho_error = ortho_residual(R_roundtrip)
    print(f"Orthogonality error: {ortho_error.mean().item():.6f} (should be ~0)")

    det = torch.det(R_roundtrip)
    print(f"Determinant: {det.mean().item():.6f} (should be ~1)")

    # --- Test 2: arbitrary 6D vectors must yield valid rotations ----------
    R_from_6d = rotation_6d_to_matrix(torch.randn(n, 6))
    det_from_6d = torch.det(R_from_6d)
    ortho_error_from_6d = ortho_residual(R_from_6d)
    print(f"\n6D->matrix check - Det: {det_from_6d.mean():.6f}, Ortho error: {ortho_error_from_6d.mean():.6f}")

    # --- Verdict ----------------------------------------------------------
    checks = (
        error.mean() < 1e-5,
        ortho_error.mean() < 1e-5,
        abs(det.mean() - 1.0) < 1e-5,
        ortho_error_from_6d.mean() < 1e-5,
        abs(det_from_6d.mean() - 1.0) < 1e-5,
    )
    test_passed = all(checks)
    print("\nTests passed!" if test_passed else "Tests FAILED!")
if __name__ == "__main__":
    # Executed as a script: run the conversion self-check and print its report.
    test_rotation_conversion()