Spaces:
Runtime error
Runtime error
| from typing import Union | |
| import numpy | |
| import torch | |
| from detrsmpl.core.conventions.joints_mapping.standard_joint_angles import ( | |
| TRANSFORMATION_AA_TO_SJA, | |
| TRANSFORMATION_SJA_TO_AA, | |
| ) | |
| from .logger import get_root_logger | |
| try: | |
| from pytorch3d.transforms import ( | |
| axis_angle_to_matrix, | |
| axis_angle_to_quaternion, | |
| euler_angles_to_matrix, | |
| matrix_to_euler_angles, | |
| matrix_to_quaternion, | |
| matrix_to_rotation_6d, | |
| quaternion_to_axis_angle, | |
| quaternion_to_matrix, | |
| rotation_6d_to_matrix, | |
| ) | |
| except (ImportError, ModuleNotFoundError): | |
| import traceback | |
| logger = get_root_logger() | |
| stack_str = '' | |
| for line in traceback.format_stack(): | |
| if 'frozen' not in line: | |
| stack_str += line + '\n' | |
| import_exception = traceback.format_exc() + '\n' | |
| warning_msg = stack_str + import_exception + \ | |
| 'If pytorch3d is not required,' +\ | |
| ' this warning could be ignored.' | |
| logger.warning(warning_msg) | |
class Compose:
    """Chain several rotation-conversion callables into a single callable.

    The input may be a ``torch.Tensor`` or a ``numpy.ndarray``; arrays are
    converted to a float32 tensor before the chain runs and the result is
    converted back to an array afterwards.
    """

    def __init__(self, transforms: list):
        """Composes several transforms together. This transform does not
        support torchscript.

        Args:
            transforms (list): (list of transform functions)
        """
        self.transforms = transforms

    @staticmethod
    def _accepts_convention(transform) -> bool:
        """Return True if ``transform`` declares a ``convention`` parameter.

        Only real parameters are inspected. Looking at
        ``__code__.co_varnames`` would also match local variables of the
        transform and fails for callables that have no ``__code__``.
        """
        import inspect

        try:
            parameters = inspect.signature(transform).parameters
        except (TypeError, ValueError):
            # No introspectable signature: assume no convention argument.
            return False
        return 'convention' in parameters

    def __call__(self,
                 rotation: Union[torch.Tensor, numpy.ndarray],
                 convention: str = 'xyz',
                 **kwargs):
        """Apply every transform in order to ``rotation``.

        Args:
            rotation (Union[torch.Tensor, numpy.ndarray]): input rotation
                representation.
            convention (str, optional): permutation of the letters
                'x', 'y' and 'z' (case-insensitive). It is forwarded, in
                upper case, as the second positional argument to every
                transform that declares a ``convention`` parameter.
                Defaults to 'xyz'.
            **kwargs: extra keyword arguments forwarded to every transform.

        Raises:
            ValueError: if ``convention`` is not a permutation of 'xyz'.
            TypeError: if ``rotation`` is neither a tensor nor an ndarray.

        Returns:
            Union[torch.Tensor, numpy.ndarray]: converted rotation, of the
                same container type as the input.
        """
        convention = convention.lower()
        if not (set(convention) == set('xyz') and len(convention) == 3):
            raise ValueError(f'Invalid convention {convention}.')
        if isinstance(rotation, numpy.ndarray):
            data_type = 'numpy'
            rotation = torch.FloatTensor(rotation)
        elif isinstance(rotation, torch.Tensor):
            data_type = 'tensor'
        else:
            raise TypeError(
                'Type of rotation should be torch.Tensor or numpy.ndarray')
        for t in self.transforms:
            if self._accepts_convention(t):
                rotation = t(rotation, convention.upper(), **kwargs)
            else:
                rotation = t(rotation, **kwargs)
        if data_type == 'numpy':
            rotation = rotation.detach().cpu().numpy()
        return rotation
def aa_to_rotmat(
    axis_angle: Union[torch.Tensor, numpy.ndarray]
) -> Union[torch.Tensor, numpy.ndarray]:
    """Convert axis angles to rotation matrices.

    Args:
        axis_angle (Union[torch.Tensor, numpy.ndarray]): input shape
            should be (..., 3). ndim of input is unlimited.

    Raises:
        ValueError: if the last dimension of ``axis_angle`` is not 3.

    Returns:
        Union[torch.Tensor, numpy.ndarray]: shape would be (..., 3, 3).
    """
    if axis_angle.shape[-1] != 3:
        raise ValueError(
            f'Invalid input axis angles shape {axis_angle.shape}.')
    t = Compose([axis_angle_to_matrix])
    return t(axis_angle)
def aa_to_quat(
    axis_angle: Union[torch.Tensor, numpy.ndarray]
) -> Union[torch.Tensor, numpy.ndarray]:
    """Convert axis angles to quaternions.

    Args:
        axis_angle (Union[torch.Tensor, numpy.ndarray]): input shape
            should be (..., 3). ndim of input is unlimited.

    Raises:
        ValueError: if the last dimension of ``axis_angle`` is not 3.

    Returns:
        Union[torch.Tensor, numpy.ndarray]: shape would be (..., 4).
    """
    if axis_angle.shape[-1] != 3:
        raise ValueError(f'Invalid input axis angles {axis_angle.shape}.')
    t = Compose([axis_angle_to_quaternion])
    return t(axis_angle)
def ee_to_rotmat(euler_angle: Union[torch.Tensor, numpy.ndarray],
                 convention: str = 'xyz'
                 ) -> Union[torch.Tensor, numpy.ndarray]:
    """Convert euler angles to rotation matrices.

    Args:
        euler_angle (Union[torch.Tensor, numpy.ndarray]): input shape
            should be (..., 3). ndim of input is unlimited.
        convention (str, optional): Convention string of three letters
            from {"x", "y", "z"}. Defaults to 'xyz'.

    Raises:
        ValueError: if the last dimension of ``euler_angle`` is not 3.

    Returns:
        Union[torch.Tensor, numpy.ndarray]: shape would be (..., 3, 3).
    """
    if euler_angle.shape[-1] != 3:
        raise ValueError(
            f'Invalid input euler angles shape {euler_angle.shape}.')
    t = Compose([euler_angles_to_matrix])
    return t(euler_angle, convention.upper())
def rotmat_to_ee(
        matrix: Union[torch.Tensor, numpy.ndarray],
        convention: str = 'xyz') -> Union[torch.Tensor, numpy.ndarray]:
    """Convert rotation matrices to euler angles.

    Args:
        matrix (Union[torch.Tensor, numpy.ndarray]): input shape
            should be (..., 3, 3). ndim of input is unlimited.
        convention (str, optional): Convention string of three letters
            from {"x", "y", "z"}. Defaults to 'xyz'.

    Raises:
        ValueError: if the last two dimensions of ``matrix`` are not (3, 3).

    Returns:
        Union[torch.Tensor, numpy.ndarray]: shape would be (..., 3).
    """
    # Slicing (rather than indexing) keeps inputs with fewer than two
    # dimensions on the ValueError path instead of raising IndexError.
    if tuple(matrix.shape[-2:]) != (3, 3):
        raise ValueError(f'Invalid rotation matrix shape {matrix.shape}.')
    t = Compose([matrix_to_euler_angles])
    return t(matrix, convention.upper())
def rotmat_to_quat(
    matrix: Union[torch.Tensor, numpy.ndarray]
) -> Union[torch.Tensor, numpy.ndarray]:
    """Convert rotation matrices to quaternions.

    Args:
        matrix (Union[torch.Tensor, numpy.ndarray]): input shape
            should be (..., 3, 3). ndim of input is unlimited.

    Raises:
        ValueError: if the last two dimensions of ``matrix`` are not (3, 3).

    Returns:
        Union[torch.Tensor, numpy.ndarray]: shape would be (..., 4).
    """
    # Slicing (rather than indexing) keeps inputs with fewer than two
    # dimensions on the ValueError path instead of raising IndexError.
    if tuple(matrix.shape[-2:]) != (3, 3):
        raise ValueError(f'Invalid rotation matrix shape {matrix.shape}.')
    t = Compose([matrix_to_quaternion])
    return t(matrix)
def rotmat_to_rot6d(
    matrix: Union[torch.Tensor, numpy.ndarray]
) -> Union[torch.Tensor, numpy.ndarray]:
    """Convert rotation matrices to rotation 6d representations.

    Args:
        matrix (Union[torch.Tensor, numpy.ndarray]): input shape
            should be (..., 3, 3). ndim of input is unlimited.

    Raises:
        ValueError: if the last two dimensions of ``matrix`` are not (3, 3).

    Returns:
        Union[torch.Tensor, numpy.ndarray]: shape would be (..., 6).

    [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
    On the Continuity of Rotation Representations in Neural Networks.
    IEEE Conference on Computer Vision and Pattern Recognition, 2019.
    Retrieved from http://arxiv.org/abs/1812.07035
    """
    # Slicing (rather than indexing) keeps inputs with fewer than two
    # dimensions on the ValueError path instead of raising IndexError.
    if tuple(matrix.shape[-2:]) != (3, 3):
        raise ValueError(f'Invalid rotation matrix shape {matrix.shape}.')
    t = Compose([matrix_to_rotation_6d])
    return t(matrix)
def quat_to_aa(
    quaternions: Union[torch.Tensor, numpy.ndarray]
) -> Union[torch.Tensor, numpy.ndarray]:
    """Convert quaternions to axis angles.

    Args:
        quaternions (Union[torch.Tensor, numpy.ndarray]): input shape
            should be (..., 4). ndim of input is unlimited.

    Raises:
        ValueError: if the last dimension of ``quaternions`` is not 4.

    Returns:
        Union[torch.Tensor, numpy.ndarray]: shape would be (..., 3).
    """
    if quaternions.shape[-1] != 4:
        raise ValueError(f'Invalid input quaternions {quaternions.shape}.')
    t = Compose([quaternion_to_axis_angle])
    return t(quaternions)
def quat_to_rotmat(
    quaternions: Union[torch.Tensor, numpy.ndarray]
) -> Union[torch.Tensor, numpy.ndarray]:
    """Convert quaternions to rotation matrices.

    Args:
        quaternions (Union[torch.Tensor, numpy.ndarray]): input shape
            should be (..., 4). ndim of input is unlimited.

    Raises:
        ValueError: if the last dimension of ``quaternions`` is not 4.

    Returns:
        Union[torch.Tensor, numpy.ndarray]: shape would be (..., 3, 3).
    """
    if quaternions.shape[-1] != 4:
        raise ValueError(
            f'Invalid input quaternions shape {quaternions.shape}.')
    t = Compose([quaternion_to_matrix])
    return t(quaternions)
def rot6d_to_rotmat(
    rotation_6d: Union[torch.Tensor, numpy.ndarray]
) -> Union[torch.Tensor, numpy.ndarray]:
    """Convert rotation 6d representations to rotation matrices.

    Args:
        rotation_6d (Union[torch.Tensor, numpy.ndarray]): input shape
            should be (..., 6). ndim of input is unlimited.

    Raises:
        ValueError: if the last dimension of ``rotation_6d`` is not 6.

    Returns:
        Union[torch.Tensor, numpy.ndarray]: shape would be (..., 3, 3).

    [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
    On the Continuity of Rotation Representations in Neural Networks.
    IEEE Conference on Computer Vision and Pattern Recognition, 2019.
    Retrieved from http://arxiv.org/abs/1812.07035
    """
    if rotation_6d.shape[-1] != 6:
        raise ValueError(f'Invalid input rotation_6d {rotation_6d.shape}.')
    t = Compose([rotation_6d_to_matrix])
    return t(rotation_6d)
def aa_to_ee(axis_angle: Union[torch.Tensor, numpy.ndarray],
             convention: str = 'xyz') -> Union[torch.Tensor, numpy.ndarray]:
    """Convert axis angles to euler angles.

    The conversion goes through rotation matrices.

    Args:
        axis_angle (Union[torch.Tensor, numpy.ndarray]): input shape
            should be (..., 3). ndim of input is unlimited.
        convention (str, optional): Convention string of three letters
            from {"x", "y", "z"}. Defaults to 'xyz'.

    Raises:
        ValueError: if the last dimension of ``axis_angle`` is not 3.

    Returns:
        Union[torch.Tensor, numpy.ndarray]: shape would be (..., 3).
    """
    if axis_angle.shape[-1] != 3:
        raise ValueError(
            f'Invalid input axis_angle shape {axis_angle.shape}.')
    t = Compose([axis_angle_to_matrix, matrix_to_euler_angles])
    return t(axis_angle, convention)
def aa_to_rot6d(
    axis_angle: Union[torch.Tensor, numpy.ndarray]
) -> Union[torch.Tensor, numpy.ndarray]:
    """Convert axis angles to rotation 6d representations.

    The conversion goes through rotation matrices.

    Args:
        axis_angle (Union[torch.Tensor, numpy.ndarray]): input shape
            should be (..., 3). ndim of input is unlimited.

    Raises:
        ValueError: if the last dimension of ``axis_angle`` is not 3.

    Returns:
        Union[torch.Tensor, numpy.ndarray]: shape would be (..., 6).

    [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
    On the Continuity of Rotation Representations in Neural Networks.
    IEEE Conference on Computer Vision and Pattern Recognition, 2019.
    Retrieved from http://arxiv.org/abs/1812.07035
    """
    if axis_angle.shape[-1] != 3:
        raise ValueError(f'Invalid input axis_angle {axis_angle.shape}.')
    t = Compose([axis_angle_to_matrix, matrix_to_rotation_6d])
    return t(axis_angle)
def ee_to_aa(euler_angle: Union[torch.Tensor, numpy.ndarray],
             convention: str = 'xyz') -> Union[torch.Tensor, numpy.ndarray]:
    """Convert euler angles to axis angles.

    The conversion goes through rotation matrices and quaternions.

    Args:
        euler_angle (Union[torch.Tensor, numpy.ndarray]): input shape
            should be (..., 3). ndim of input is unlimited.
        convention (str, optional): Convention string of three letters
            from {"x", "y", "z"}. Defaults to 'xyz'.

    Raises:
        ValueError: if the last dimension of ``euler_angle`` is not 3.

    Returns:
        Union[torch.Tensor, numpy.ndarray]: shape would be (..., 3).
    """
    if euler_angle.shape[-1] != 3:
        raise ValueError(f'Invalid input euler_angle {euler_angle.shape}.')
    t = Compose([
        euler_angles_to_matrix, matrix_to_quaternion, quaternion_to_axis_angle
    ])
    return t(euler_angle, convention)
def ee_to_quat(euler_angle: Union[torch.Tensor, numpy.ndarray],
               convention: str = 'xyz'
               ) -> Union[torch.Tensor, numpy.ndarray]:
    """Convert euler angles to quaternions.

    The conversion goes through rotation matrices.

    Args:
        euler_angle (Union[torch.Tensor, numpy.ndarray]): input shape
            should be (..., 3). ndim of input is unlimited.
        convention (str, optional): Convention string of three letters
            from {"x", "y", "z"}. Defaults to 'xyz'.

    Raises:
        ValueError: if the last dimension of ``euler_angle`` is not 3.

    Returns:
        Union[torch.Tensor, numpy.ndarray]: shape would be (..., 4).
    """
    if euler_angle.shape[-1] != 3:
        raise ValueError(f'Invalid input euler_angle {euler_angle.shape}.')
    t = Compose([euler_angles_to_matrix, matrix_to_quaternion])
    return t(euler_angle, convention)
def ee_to_rot6d(euler_angle: Union[torch.Tensor, numpy.ndarray],
                convention: str = 'xyz'
                ) -> Union[torch.Tensor, numpy.ndarray]:
    """Convert euler angles to rotation 6d representations.

    The conversion goes through rotation matrices.

    Args:
        euler_angle (Union[torch.Tensor, numpy.ndarray]): input shape
            should be (..., 3). ndim of input is unlimited.
        convention (str, optional): Convention string of three letters
            from {"x", "y", "z"}. Defaults to 'xyz'.

    Raises:
        ValueError: if the last dimension of ``euler_angle`` is not 3.

    Returns:
        Union[torch.Tensor, numpy.ndarray]: shape would be (..., 6).

    [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
    On the Continuity of Rotation Representations in Neural Networks.
    IEEE Conference on Computer Vision and Pattern Recognition, 2019.
    Retrieved from http://arxiv.org/abs/1812.07035
    """
    if euler_angle.shape[-1] != 3:
        raise ValueError(f'Invalid input euler_angle {euler_angle.shape}.')
    t = Compose([euler_angles_to_matrix, matrix_to_rotation_6d])
    return t(euler_angle, convention)
def rotmat_to_aa(
        matrix: Union[torch.Tensor, numpy.ndarray],
        convention: str = 'xyz') -> Union[torch.Tensor, numpy.ndarray]:
    """Convert rotation matrices to axis angles.

    The conversion goes through quaternions.

    Args:
        matrix (Union[torch.Tensor, numpy.ndarray]): input shape
            should be (..., 3, 3). ndim of input is unlimited.
        convention (str, optional): Unused by this conversion; kept in the
            signature so existing callers that pass it keep working.
            Defaults to 'xyz'.

    Raises:
        ValueError: if the last two dimensions of ``matrix`` are not (3, 3).

    Returns:
        Union[torch.Tensor, numpy.ndarray]: shape would be (..., 3).
    """
    # Slicing (rather than indexing) keeps inputs with fewer than two
    # dimensions on the ValueError path instead of raising IndexError.
    if tuple(matrix.shape[-2:]) != (3, 3):
        raise ValueError(f'Invalid rotation matrix shape {matrix.shape}.')
    t = Compose([matrix_to_quaternion, quaternion_to_axis_angle])
    return t(matrix)
def quat_to_ee(quaternions: Union[torch.Tensor, numpy.ndarray],
               convention: str = 'xyz'
               ) -> Union[torch.Tensor, numpy.ndarray]:
    """Convert quaternions to euler angles.

    The conversion goes through rotation matrices.

    Args:
        quaternions (Union[torch.Tensor, numpy.ndarray]): input shape
            should be (..., 4). ndim of input is unlimited.
        convention (str, optional): Convention string of three letters
            from {"x", "y", "z"}. Defaults to 'xyz'.

    Raises:
        ValueError: if the last dimension of ``quaternions`` is not 4.

    Returns:
        Union[torch.Tensor, numpy.ndarray]: shape would be (..., 3).
    """
    if quaternions.shape[-1] != 4:
        raise ValueError(f'Invalid input quaternions {quaternions.shape}.')
    t = Compose([quaternion_to_matrix, matrix_to_euler_angles])
    return t(quaternions, convention)
def quat_to_rot6d(
    quaternions: Union[torch.Tensor, numpy.ndarray]
) -> Union[torch.Tensor, numpy.ndarray]:
    """Convert quaternions to rotation 6d representations.

    The conversion goes through rotation matrices.

    Args:
        quaternions (Union[torch.Tensor, numpy.ndarray]): input shape
            should be (..., 4). ndim of input is unlimited.

    Raises:
        ValueError: if the last dimension of ``quaternions`` is not 4.

    Returns:
        Union[torch.Tensor, numpy.ndarray]: shape would be (..., 6).

    [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
    On the Continuity of Rotation Representations in Neural Networks.
    IEEE Conference on Computer Vision and Pattern Recognition, 2019.
    Retrieved from http://arxiv.org/abs/1812.07035
    """
    if quaternions.shape[-1] != 4:
        raise ValueError(f'Invalid input quaternions {quaternions.shape}.')
    t = Compose([quaternion_to_matrix, matrix_to_rotation_6d])
    return t(quaternions)
def rot6d_to_aa(
    rotation_6d: Union[torch.Tensor, numpy.ndarray]
) -> Union[torch.Tensor, numpy.ndarray]:
    """Convert rotation 6d representations to axis angles.

    The conversion goes through rotation matrices and quaternions.

    Args:
        rotation_6d (Union[torch.Tensor, numpy.ndarray]): input shape
            should be (..., 6). ndim of input is unlimited.

    Raises:
        ValueError: if the last dimension of ``rotation_6d`` is not 6.

    Returns:
        Union[torch.Tensor, numpy.ndarray]: shape would be (..., 3).

    [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
    On the Continuity of Rotation Representations in Neural Networks.
    IEEE Conference on Computer Vision and Pattern Recognition, 2019.
    Retrieved from http://arxiv.org/abs/1812.07035
    """
    if rotation_6d.shape[-1] != 6:
        raise ValueError(f'Invalid input rotation_6d {rotation_6d.shape}.')
    t = Compose([
        rotation_6d_to_matrix, matrix_to_quaternion, quaternion_to_axis_angle
    ])
    return t(rotation_6d)
def rot6d_to_ee(rotation_6d: Union[torch.Tensor, numpy.ndarray],
                convention: str = 'xyz'
                ) -> Union[torch.Tensor, numpy.ndarray]:
    """Convert rotation 6d representations to euler angles.

    The conversion goes through rotation matrices.

    Args:
        rotation_6d (Union[torch.Tensor, numpy.ndarray]): input shape
            should be (..., 6). ndim of input is unlimited.
        convention (str, optional): Convention string of three letters
            from {"x", "y", "z"}. Defaults to 'xyz'.

    Raises:
        ValueError: if the last dimension of ``rotation_6d`` is not 6.

    Returns:
        Union[torch.Tensor, numpy.ndarray]: shape would be (..., 3).

    [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
    On the Continuity of Rotation Representations in Neural Networks.
    IEEE Conference on Computer Vision and Pattern Recognition, 2019.
    Retrieved from http://arxiv.org/abs/1812.07035
    """
    if rotation_6d.shape[-1] != 6:
        raise ValueError(f'Invalid input rotation_6d {rotation_6d.shape}.')
    t = Compose([rotation_6d_to_matrix, matrix_to_euler_angles])
    return t(rotation_6d, convention)
def rot6d_to_quat(
    rotation_6d: Union[torch.Tensor, numpy.ndarray]
) -> Union[torch.Tensor, numpy.ndarray]:
    """Convert rotation 6d representations to quaternions.

    The conversion goes through rotation matrices.

    Args:
        rotation_6d (Union[torch.Tensor, numpy.ndarray]): input shape
            should be (..., 6). ndim of input is unlimited.

    Raises:
        ValueError: if the last dimension of ``rotation_6d`` is not 6.

    Returns:
        Union[torch.Tensor, numpy.ndarray]: shape would be (..., 4).

    [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
    On the Continuity of Rotation Representations in Neural Networks.
    IEEE Conference on Computer Vision and Pattern Recognition, 2019.
    Retrieved from http://arxiv.org/abs/1812.07035
    """
    if rotation_6d.shape[-1] != 6:
        raise ValueError(
            f'Invalid input rotation_6d shape {rotation_6d.shape}.')
    t = Compose([rotation_6d_to_matrix, matrix_to_quaternion])
    return t(rotation_6d)
def aa_to_sja(
    axis_angle: Union[torch.Tensor, numpy.ndarray],
    R_t: Union[torch.Tensor, numpy.ndarray] = TRANSFORMATION_AA_TO_SJA,
    R_t_inv: Union[torch.Tensor, numpy.ndarray] = TRANSFORMATION_SJA_TO_AA
) -> Union[torch.Tensor, numpy.ndarray]:
    """Convert axis-angles to standard joint angles.

    Args:
        axis_angle (Union[torch.Tensor, numpy.ndarray]): input shape
            should be (..., 21, 3), ndim of input is unlimited.
        R_t (Union[torch.Tensor, numpy.ndarray]): input shape
            should be (..., 21, 3, 3). Transformation matrices from
            original axis-angle coordinate system to
            standard joint angle coordinate system,
            ndim of input is unlimited.
        R_t_inv (Union[torch.Tensor, numpy.ndarray]): input shape
            should be (..., 21, 3, 3). Transformation matrices from
            standard joint angle coordinate system to
            original axis-angle coordinate system,
            ndim of input is unlimited.

    Raises:
        ValueError: if any input does not have the trailing shape listed
            above.

    Returns:
        Union[torch.Tensor, numpy.ndarray]: shape would be (..., 21, 3).
    """
    def _aa_to_sja(aa, R_t, R_t_inv):
        # Compose converts only the primary input to a tensor, so bring the
        # transformation matrices to the same type/dtype/device here;
        # otherwise ndarray (or mismatched) matrices break the matmul.
        R_t = torch.as_tensor(R_t, dtype=aa.dtype, device=aa.device)
        R_t_inv = torch.as_tensor(R_t_inv, dtype=aa.dtype, device=aa.device)
        R_aa = axis_angle_to_matrix(aa)
        # Change of basis: express each joint rotation in the SJA frame.
        R_sja = R_t @ R_aa @ R_t_inv
        sja = matrix_to_euler_angles(R_sja, convention='XYZ')
        return sja

    if tuple(axis_angle.shape[-2:]) != (21, 3):
        raise ValueError(
            f'Invalid input axis angles shape {axis_angle.shape}.')
    if tuple(R_t.shape[-3:]) != (21, 3, 3):
        raise ValueError(f'Invalid input R_t shape {R_t.shape}.')
    if tuple(R_t_inv.shape[-3:]) != (21, 3, 3):
        raise ValueError(f'Invalid input R_t_inv shape {R_t_inv.shape}.')
    t = Compose([_aa_to_sja])
    return t(axis_angle, R_t=R_t, R_t_inv=R_t_inv)
def sja_to_aa(
    sja: Union[torch.Tensor, numpy.ndarray],
    R_t: Union[torch.Tensor, numpy.ndarray] = TRANSFORMATION_AA_TO_SJA,
    R_t_inv: Union[torch.Tensor, numpy.ndarray] = TRANSFORMATION_SJA_TO_AA
) -> Union[torch.Tensor, numpy.ndarray]:
    """Convert standard joint angles to axis angles.

    Args:
        sja (Union[torch.Tensor, numpy.ndarray]): input shape
            should be (..., 21, 3). ndim of input is unlimited.
        R_t (Union[torch.Tensor, numpy.ndarray]): input shape
            should be (..., 21, 3, 3). Transformation matrices from
            original axis-angle coordinate system to
            standard joint angle coordinate system
        R_t_inv (Union[torch.Tensor, numpy.ndarray]): input shape
            should be (..., 21, 3, 3). Transformation matrices from
            standard joint angle coordinate system to
            original axis-angle coordinate system

    Raises:
        ValueError: if any input does not have the trailing shape listed
            above.

    Returns:
        Union[torch.Tensor, numpy.ndarray]: shape would be (..., 21, 3).
    """
    def _sja_to_aa(sja, R_t, R_t_inv):
        # Compose converts only the primary input to a tensor, so bring the
        # transformation matrices to the same type/dtype/device here;
        # otherwise ndarray (or mismatched) matrices break the matmul.
        R_t = torch.as_tensor(R_t, dtype=sja.dtype, device=sja.device)
        R_t_inv = torch.as_tensor(
            R_t_inv, dtype=sja.dtype, device=sja.device)
        R_sja = euler_angles_to_matrix(sja, convention='XYZ')
        # Change of basis back to the original axis-angle frame.
        R_aa = R_t_inv @ R_sja @ R_t
        aa = quaternion_to_axis_angle(matrix_to_quaternion(R_aa))
        return aa

    if tuple(sja.shape[-2:]) != (21, 3):
        raise ValueError(
            f'Invalid input standard joint angles shape {sja.shape}.')
    if tuple(R_t.shape[-3:]) != (21, 3, 3):
        raise ValueError(f'Invalid input R_t shape {R_t.shape}.')
    if tuple(R_t_inv.shape[-3:]) != (21, 3, 3):
        raise ValueError(f'Invalid input R_t_inv shape {R_t_inv.shape}.')
    t = Compose([_sja_to_aa])
    return t(sja, R_t=R_t, R_t_inv=R_t_inv)
def make_homegeneous_rotmat_batch(input: torch.Tensor) -> torch.Tensor:
    """Append a row of [0, 0, 0, 1] to every matrix of a batch.

    Args:
        input (torch.Tensor): tensor of shape (batch size, 3, 4).

    Returns:
        torch.Tensor: tensor of shape (batch size, 4, 4) whose last row is
            [0, 0, 0, 1] for every batch element. It has the same dtype
            and device as ``input``.
    """
    batch_size = input.shape[0]
    # new_tensor inherits dtype and device from ``input``; a hard-coded CPU
    # float32 row fails on other devices and promotes lower-precision
    # inputs. The row never requires grad.
    row_append = input.new_tensor([0.0, 0.0, 0.0, 1.0])
    padded_tensor = torch.cat(
        [input, row_append.view(1, 1, 4).expand(batch_size, 1, 4)], dim=1)
    return padded_tensor
def make_homegeneous_rotmat(input: torch.Tensor) -> torch.Tensor:
    """Append a row of [0, 0, 0, 1] to a 3 x 4 tensor.

    Args:
        input (torch.Tensor): tensor of shape (3, 4).

    Returns:
        torch.Tensor: tensor of shape (4, 4) whose last row is
            [0, 0, 0, 1]. It has the same dtype and device as ``input``.
    """
    # torch.cat takes a sequence of tensors. The extra row must be 2-D and
    # is stacked along dim 0 (rows) to turn 3 x 4 into 4 x 4.
    row_append = input.new_tensor([[0.0, 0.0, 0.0, 1.0]])
    padded_tensor = torch.cat([input, row_append], dim=0)
    return padded_tensor