| import torch |
| from einops import rearrange |
| from jaxtyping import Float |
| from torch import Tensor |
|
|
|
|
| |
def quaternion_to_matrix(
    quaternions: Float[Tensor, "*batch 4"],
    eps: float = 1e-8,
) -> Float[Tensor, "*batch 3 3"]:
    """Convert quaternions stored as (x, y, z, w) into 3x3 rotation matrices.

    The last component is treated as the real part. The formula divides by
    the squared quaternion norm (plus ``eps`` to avoid dividing by zero), so
    inputs do not have to be unit length.

    Args:
        quaternions: Tensor whose last dimension holds (x, y, z, w).
        eps: Small constant added to the squared norm before dividing.

    Returns:
        Tensor with the last dimension replaced by a ``3 x 3`` matrix.
    """
    x, y, z, w = torch.unbind(quaternions, dim=-1)
    # Scale factor 2 / |q|^2 folds the normalisation into every entry.
    s = 2 / ((quaternions * quaternions).sum(dim=-1) + eps)

    # Row-major entries of the rotation matrix.
    entries = [
        1 - s * (y * y + z * z), s * (x * y - z * w), s * (x * z + y * w),
        s * (x * y + z * w), 1 - s * (x * x + z * z), s * (y * z - x * w),
        s * (x * z - y * w), s * (y * z + x * w), 1 - s * (x * x + y * y),
    ]
    flat = torch.stack(entries, dim=-1)
    # Split the trailing 9 entries into a 3 x 3 matrix.
    return flat.reshape(*flat.shape[:-1], 3, 3)
|
|
|
|
def build_covariance(
    scale: Float[Tensor, "*#batch 3"],
    rotation_xyzw: Float[Tensor, "*#batch 4"],
) -> Float[Tensor, "*batch 3 3"]:
    """Build covariance matrices ``R S Sᵀ Rᵀ`` from per-axis scales and rotations.

    ``S`` is the diagonal matrix with ``scale`` on its diagonal, and ``R`` is
    the rotation matrix returned by ``quaternion_to_matrix(rotation_xyzw)``.
    Leading batch dimensions of the two inputs are broadcast against each other.

    Args:
        scale: Per-axis scale factors; last dimension has size 3.
        rotation_xyzw: Quaternions in (x, y, z, w) order; last dimension has
            size 4. They are converted with ``quaternion_to_matrix``.

    Returns:
        The ``3 x 3`` matrices ``M Mᵀ`` where ``M = R S``.
    """
    rotation = quaternion_to_matrix(rotation_xyzw)
    # R @ diag(s) only scales column j of R by s[j]. Broadcasting scale over a
    # new row axis does that directly, without building the diagonal matrix
    # or running a batched matmul for it.
    rotation_scaled = rotation * scale[..., None, :]
    # S is diagonal (S == Sᵀ), so R S Sᵀ Rᵀ == (R S)(R S)ᵀ: one matmul suffices.
    return rotation_scaled @ rearrange(rotation_scaled, "... i j -> ... j i")
|
|