Spaces:
Runtime error
Runtime error
| import numpy as np | |
| import torch | |
| from einops import repeat | |
| from jaxtyping import Float | |
| from scipy.spatial.transform import Rotation as R | |
| from torch import Tensor | |
def generate_spin(
    num_frames: int,
    device: torch.device,
    elevation: float,
    radius: float,
) -> Float[Tensor, "frame 4 4"]:
    """Build one 4x4 homogeneous transform per frame of a full turn about y.

    Every frame is the product ``azimuth @ elevation @ translation``:

    * ``translation`` negates the first two rows of the identity and puts
      ``-radius`` in entry ``[2, 3]`` (an offset along the z axis).
    * ``elevation`` is a fixed rotation about the x axis.
    * ``azimuth`` is a rotation about the y axis whose angle is
      ``2 * pi * i / num_frames`` for frame ``i``, so the angles cover one
      full turn without repeating the starting pose.

    Args:
        num_frames: Number of evenly spaced azimuth steps (output batch size).
        device: Device on which every tensor, including the result, is created.
        elevation: Rotation about the x axis, in degrees.
        radius: Offset magnitude written (negated) into the z translation.

    Returns:
        A float32 tensor of shape ``(num_frames, 4, 4)`` on ``device``.
    """
    # Translate back along the camera's look vector.
    # NOTE(review): negating rows 0 and 1 looks like a camera-axis convention
    # flip (x/y sign change) -- confirm against the renderer's convention.
    tf_translation = torch.eye(4, dtype=torch.float32, device=device)
    tf_translation[:2] *= -1
    tf_translation[2, 3] = -radius

    # Generate the transformation for the azimuth: one rotation vector per
    # frame, each pointing along +y with magnitude equal to the frame's angle.
    phi = 2 * np.pi * (np.arange(num_frames) / num_frames)
    rotation_vectors = np.stack([np.zeros_like(phi), phi, np.zeros_like(phi)], axis=-1)
    azimuth = torch.tensor(
        R.from_rotvec(rotation_vectors).as_matrix(),
        dtype=torch.float32,
        device=device,
    )
    # Tensor.repeat materialises a fresh (num_frames, 4, 4) copy, so writing
    # into it below cannot alias a shared identity matrix.
    tf_azimuth = torch.eye(4, dtype=torch.float32, device=device).repeat(num_frames, 1, 1)
    tf_azimuth[:, :3, :3] = azimuth

    # Generate the transformation for the elevation. The matrix is created
    # with the same dtype/device as everything else instead of relying on an
    # implicit float64 -> float32, CPU -> device copy during slice assignment.
    elevation_rad = np.deg2rad(elevation)
    elevation_matrix = torch.tensor(
        R.from_rotvec(np.array([elevation_rad, 0.0, 0.0])).as_matrix(),
        dtype=torch.float32,
        device=device,
    )
    tf_elevation = torch.eye(4, dtype=torch.float32, device=device)
    tf_elevation[:3, :3] = elevation_matrix

    # (num_frames, 4, 4) @ (4, 4) @ (4, 4) broadcasts over the frame axis.
    return tf_azimuth @ tf_elevation @ tf_translation