| | import numpy as np |
| | import torch |
| | from einops import repeat |
| | from jaxtyping import Float |
| | from scipy.spatial.transform import Rotation as R |
| | from torch import Tensor |
| |
|
| |
|
def generate_spin(
    num_frames: int,
    device: torch.device,
    elevation: float,
    radius: float,
) -> Float[Tensor, "frame 4 4"]:
    """Build ``num_frames`` 4x4 transforms that spin once around the y axis.

    Every frame's transform is ``T_azimuth(phi) @ T_elevation @ T_translation``.
    ``phi`` steps uniformly through ``[0, 2*pi)`` in ``num_frames`` increments,
    so the starting angle is not repeated at the end.

    Args:
        num_frames: Number of transforms to generate.
        device: Device on which every intermediate and the result are allocated.
        elevation: Rotation about the x axis, in degrees, shared by all frames.
        radius: Offset along z; the translation component of ``T_translation``
            is ``(0, 0, -radius)``.

    Returns:
        A float32 tensor of shape ``(num_frames, 4, 4)`` on ``device``.
    """
    # Fixed offset: diag(-1, -1, 1, 1) (rows 0 and 1 negated, i.e. a
    # 180-degree turn about z) plus a translation of -radius along z.
    tf_translation = torch.eye(4, dtype=torch.float32, device=device)
    tf_translation[:2] *= -1
    tf_translation[2, 3] = -radius

    # Per-frame azimuth: rotation about the y axis by phi.
    phi = 2 * np.pi * (np.arange(num_frames) / num_frames)
    rotation_vectors = np.stack([np.zeros_like(phi), phi, np.zeros_like(phi)], axis=-1)
    azimuth_matrices = R.from_rotvec(rotation_vectors).as_matrix()
    tf_azimuth = torch.eye(4, dtype=torch.float32, device=device)
    tf_azimuth = repeat(tf_azimuth, "i j -> b i j", b=num_frames).clone()
    tf_azimuth[:, :3, :3] = torch.tensor(
        azimuth_matrices, dtype=torch.float32, device=device
    )

    # Shared elevation: rotation about the x axis. Use new local names rather
    # than rebinding the float `elevation` parameter, and keep the angle in
    # float64 for scipy (it computes in float64 regardless) before a single,
    # explicit cast to float32 on the target device.
    elevation_rad = np.deg2rad(elevation)
    elevation_matrix = R.from_rotvec(np.array([elevation_rad, 0.0, 0.0])).as_matrix()
    tf_elevation = torch.eye(4, dtype=torch.float32, device=device)
    tf_elevation[:3, :3] = torch.tensor(
        elevation_matrix, dtype=torch.float32, device=device
    )

    return tf_azimuth @ tf_elevation @ tf_translation
| |
|