Spaces:
Sleeping
Sleeping
| from dataclasses import dataclass | |
| from typing import Literal | |
| import numpy as np | |
| import torch | |
| from jaxtyping import Float, Int64 | |
| from torch import Tensor | |
| from .view_sampler import ViewSampler, ViewSamplerCfg | |
class ViewSamplerDenseCfg(ViewSamplerCfg):
    """Configuration for the dense (strided) view sampler.

    Exactly one of ``target_every`` and ``context_every`` must be positive;
    the positive one is the stride used to pick that kind of view, and the
    other kind receives the remaining views.
    """

    # NOTE(review): this file imports `dataclass`, but no decorator is visible
    # on this class here. `__post_init__` is only invoked by a
    # dataclass-generated `__init__` -- confirm how this config is built.

    # Discriminator value selecting this sampler.
    name: Literal["dense"]
    # When positive, every `target_every`-th view is used as a target view.
    target_every: int
    # When positive, every `context_every`-th view is used as a context view.
    context_every: int
    # How a view set is thinned out when more views exist than were requested.
    sample_views_strategy: Literal["random", "neighbors"] = "random"

    def __post_init__(self):
        """Validate that exactly one of the two strides is enabled.

        Raises:
            ValueError: If both strides are positive or neither is.
        """
        # An explicit raise (rather than `assert`) keeps the check alive when
        # Python runs with -O, which strips assert statements.
        if (self.target_every > 0) == (self.context_every > 0):
            raise ValueError(
                "Either target_every or context_every must be set, but not both."
            )
class ViewSamplerDense(ViewSampler[ViewSamplerDenseCfg]):
    """View sampler that splits a scene's views into strided context/target sets.

    One kind of view (target or context, depending on the config) is taken at
    a fixed stride over all view indices; the other kind receives every
    remaining view. Each set is then optionally thinned out to the configured
    number of views.
    """

    def _sample_impl(
        self,
        scene: str,
        extrinsics: Float[Tensor, "view 4 4"],
        intrinsics: Float[Tensor, "view 3 3"],
        device: torch.device = torch.device("cpu"),
        **kwargs,
    ) -> tuple[
        Int64[Tensor, " context_view"],  # indices for context views
        Int64[Tensor, " target_view"],  # indices for target views
    ]:
        """Sample context and target views.

        Args:
            scene: Scene identifier (not used by this sampler).
            extrinsics: Per-view 4x4 matrices; only the number of views is
                read here, the tensor itself is forwarded to the subsampler.
            intrinsics: Per-view 3x3 matrices (not used by this sampler).
            device: Device on which the index tensors are created.
            **kwargs: Ignored.

        Returns:
            A ``(index_context, index_target)`` pair of int64 index tensors.
            When no subsampling takes place the indices are in ascending order.

        Raises:
            ValueError: If neither ``target_every`` nor ``context_every`` is
                positive, or the sampling strategy is unknown.
            NotImplementedError: If the "neighbors" strategy is requested and
                a set actually has to be thinned out.
        """
        num_views, _, _ = extrinsics.shape
        all_views = torch.arange(num_views, device=device)

        # `target_every` takes precedence when both happen to be positive,
        # matching the branch order this sampler has always used.
        if self.cfg.target_every > 0:
            stride, strided_is_target = self.cfg.target_every, True
        elif self.cfg.context_every > 0:
            stride, strided_is_target = self.cfg.context_every, False
        else:
            raise ValueError("Either target_every or context_every must be set to a positive integer.")

        # A boolean mask yields both halves of the split in ascending order
        # and keeps the int64 dtype even when one half is empty. Going through
        # a Python set gives unspecified ordering, and an empty set turns into
        # a float32 tensor.
        on_stride = torch.zeros(num_views, dtype=torch.bool, device=device)
        on_stride[::stride] = True
        strided_views = all_views[on_stride]
        remaining_views = all_views[~on_stride]

        if strided_is_target:
            target_views, context_views = strided_views, remaining_views
        else:
            context_views, target_views = strided_views, remaining_views

        strategy = self.cfg.sample_views_strategy
        index_context = self._subsample_views(
            extrinsics, context_views, self.cfg.num_context_views, strategy
        )
        # The first context view is offered as the center for the target
        # selection; guard against an empty context set instead of failing
        # with an IndexError.
        center_idx = int(index_context[0]) if index_context.numel() > 0 else None
        index_target = self._subsample_views(
            extrinsics,
            target_views,
            self.cfg.num_target_views,
            strategy,
            center_idx=center_idx,
        )
        return index_context, index_target

    @staticmethod
    def _subsample_views(
        extrinsics: Tensor,
        index_views: Tensor,
        num_views_to_sample: int,
        strategy: str,
        center_idx: int | None = None,
    ) -> Tensor:
        """Reduce ``index_views`` to at most ``num_views_to_sample`` entries.

        Args:
            extrinsics: Per-view matrices; currently unused because the only
                strategy that would read them is not implemented.
            index_views: 1-D tensor of candidate view indices.
            num_views_to_sample: Number of views to keep. ``-1``, or any value
                not smaller than ``len(index_views)``, keeps every candidate.
            strategy: ``"random"`` or ``"neighbors"``.
            center_idx: Reserved for the "neighbors" strategy; unused for now.

        Returns:
            ``index_views`` itself when nothing has to be dropped, otherwise a
            randomly chosen subset of it.

        Raises:
            NotImplementedError: For the "neighbors" strategy.
            ValueError: For any other unknown strategy.
        """
        if num_views_to_sample == -1 or num_views_to_sample >= len(index_views):
            return index_views

        if strategy == "random":
            # The permutation is drawn on the CPU so seeded runs keep using
            # the CPU generator, then moved next to the indices it selects.
            chosen = torch.randperm(len(index_views))[:num_views_to_sample]
            return index_views[chosen.to(index_views.device)]

        if strategy == "neighbors":
            # NOTE: a draft that ranked views by translation distance plus
            # relative rotation angle to a center view used to sit behind this
            # raise, where it could never execute; the strategy stays disabled.
            raise NotImplementedError("The 'neighbors' sampling strategy is not implemented.")

        raise ValueError(f"Unknown sampling strategy: {strategy}")

    # NOTE(review): no `@property` decorator is visible on the two accessors
    # below, so they are kept as plain methods -- confirm against the
    # `ViewSampler` base class.
    def num_context_views(self) -> int:
        """Return the configured number of context views."""
        return self.cfg.num_context_views

    def num_target_views(self) -> int:
        """Return the configured number of target views."""
        return self.cfg.num_target_views