Spaces:
Runtime error
Runtime error
| import torch | |
| from einops import rearrange | |
| from jaxtyping import Float | |
| from torch import Tensor | |
| # https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/transforms/rotation_conversions.py | |
def quaternion_to_matrix(
    quaternions: Float[Tensor, "*batch 4"],
    eps: float = 1e-8,
) -> Float[Tensor, "*batch 3 3"]:
    """Convert scalar-last quaternions into 3x3 rotation matrices.

    The last axis of ``quaternions`` is read as ``(x, y, z, w)``, i.e. the
    real part comes last (the scipy ordering), unlike the pytorch3d routine
    this was adapted from.

    The input does not need to be normalised: every off-diagonal term is
    scaled by ``2 / (|q|^2 + eps)``. ``eps`` keeps that division finite when
    the quaternion is all zeros.

    Returns a tensor with the same leading dimensions as ``quaternions`` and
    a trailing ``3 x 3`` matrix.
    """
    x, y, z, w = quaternions.unbind(dim=-1)

    # 2 / |q|^2, with eps guarding against a zero-norm quaternion.
    squared_norm = (quaternions * quaternions).sum(dim=-1)
    s = 2 / (squared_norm + eps)

    # Pairwise products shared by several matrix entries.
    xx, yy, zz = x * x, y * y, z * z
    xy, xz, yz = x * y, x * z, y * z
    xw, yw, zw = x * w, y * w, z * w

    # Row-major entries of the rotation matrix.
    entries = (
        1 - s * (yy + zz),
        s * (xy - zw),
        s * (xz + yw),
        s * (xy + zw),
        1 - s * (xx + zz),
        s * (yz - xw),
        s * (xz - yw),
        s * (yz + xw),
        1 - s * (xx + yy),
    )
    flat = torch.stack(entries, dim=-1)
    return rearrange(flat, "... (i j) -> ... i j", i=3, j=3)
def build_covariance(
    scale: Float[Tensor, "*#batch 3"],
    rotation_xyzw: Float[Tensor, "*#batch 4"],
) -> Float[Tensor, "*batch 3 3"]:
    """Build 3x3 covariance matrices from per-axis scales and rotations.

    The covariance is assembled as ``R @ S @ S^T @ R^T`` where ``S`` is the
    diagonal matrix formed from ``scale`` and ``R`` is the rotation matrix of
    ``rotation_xyzw`` (scalar-last quaternion, as consumed by
    ``quaternion_to_matrix``).

    The previous body computed ``S`` but returned only ``R``, so the scale
    never influenced the result.

    Leading batch dimensions of the two inputs are combined by the matrix
    product's broadcasting.
    """
    # (..., 3) -> (..., 3, 3) with the scales on the diagonal.
    scale_matrix = scale.diag_embed()
    rotation = quaternion_to_matrix(rotation_xyzw)
    return (
        rotation
        @ scale_matrix
        @ scale_matrix.transpose(-1, -2)
        @ rotation.transpose(-1, -2)
    )