Spaces:
Runtime error
Runtime error
| import torch | |
| from einops import rearrange | |
| from jaxtyping import Float | |
| from torch import Tensor | |
| def generate_wobble_transformation( | |
| radius: Float[Tensor, "*#batch"], | |
| t: Float[Tensor, " time_step"], | |
| num_rotations: int = 1, | |
| scale_radius_with_t: bool = True, | |
| ) -> Float[Tensor, "*batch time_step 4 4"]: | |
| # Generate a translation in the image plane. | |
| tf = torch.eye(4, dtype=torch.float32, device=t.device) | |
| tf = tf.broadcast_to((*radius.shape, t.shape[0], 4, 4)).clone() | |
| radius = radius[..., None] | |
| if scale_radius_with_t: | |
| radius = radius * t | |
| tf[..., 0, 3] = torch.sin(2 * torch.pi * num_rotations * t) * radius | |
| tf[..., 1, 3] = -torch.cos(2 * torch.pi * num_rotations * t) * radius | |
| return tf | |
| def generate_wobble( | |
| extrinsics: Float[Tensor, "*#batch 4 4"], | |
| radius: Float[Tensor, "*#batch"], | |
| t: Float[Tensor, " time_step"], | |
| ) -> Float[Tensor, "*batch time_step 4 4"]: | |
| tf = generate_wobble_transformation(radius, t) | |
| return rearrange(extrinsics, "... i j -> ... () i j") @ tf | |