from dataclasses import dataclass
from typing import Literal

import torch
from jaxtyping import Float, Int64
from torch import Tensor

from .view_sampler import ViewSampler


@dataclass
class ViewSamplerBoundedCfg:
    name: Literal["bounded"]
    num_context_views: int
    num_target_views: int
    min_distance_between_context_views: int
    max_distance_between_context_views: int
    min_distance_to_context_views: int
    warm_up_steps: int
    initial_min_distance_between_context_views: int
    initial_max_distance_between_context_views: int


class ViewSamplerBounded(ViewSampler[ViewSamplerBoundedCfg]):
    def schedule(self, initial: int, final: int) -> int:
        # Linearly interpolate from `initial` to `final` over the first
        # `warm_up_steps` steps, then clamp at `final`.
        fraction = self.global_step / self.cfg.warm_up_steps
        return min(initial + int((final - initial) * fraction), final)
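
    # As a worked example (values are illustrative, not from any shipped
    # config): with initial=25, final=45, and warm_up_steps=1500, a call at
    # global_step=750 gives fraction=0.5, so the result is
    # min(25 + int(20 * 0.5), 45) = 35.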

    def sample(
        self,
        scene: str,
        extrinsics: Float[Tensor, "view 4 4"],
        intrinsics: Float[Tensor, "view 3 3"],
        device: torch.device = torch.device("cpu"),
        min_view_dist: int | None = None,
        max_view_dist: int | None = None,
        **kwargs,
    ) -> tuple[
        Int64[Tensor, " context_view"],  # indices for context views
        Int64[Tensor, " target_view"],  # indices for target views
    ]:
        num_views, _, _ = extrinsics.shape

        # Compute the context view spacing based on the current global step.
        if self.stage == "test":
            # When testing, always use the maximum spacing between the
            # context views.
            max_gap = self.cfg.max_distance_between_context_views
            min_gap = self.cfg.max_distance_between_context_views
        elif self.cfg.warm_up_steps > 0:
            max_gap = self.schedule(
                self.cfg.initial_max_distance_between_context_views,
                self.cfg.max_distance_between_context_views,
            )
            min_gap = self.schedule(
                self.cfg.initial_min_distance_between_context_views,
                self.cfg.min_distance_between_context_views,
            )
        else:
            max_gap = self.cfg.max_distance_between_context_views
            min_gap = self.cfg.min_distance_between_context_views

        # Clamp the gap so both context views fit within the sequence, and
        # so target views can keep their minimum distance to the context.
        if not self.cameras_are_circular:
            max_gap = min(num_views - 1, max_gap)
        min_gap = max(2 * self.cfg.min_distance_to_context_views, min_gap)

        # Allow the caller to override the gap bounds.
        if min_view_dist is not None:
            min_gap = min_view_dist
        if max_view_dist is not None:
            max_gap = max_view_dist

        if max_gap < min_gap:
            raise ValueError("Example does not have enough frames!")

        # Pick the gap between the context views.
        context_gap = torch.randint(
            min_gap,
            max_gap + 1,
            size=tuple(),
            device=device,
        ).item()

        # Pick the left and right context indices.
        index_context_left = torch.randint(
            num_views if self.cameras_are_circular else num_views - context_gap,
            size=tuple(),
            device=device,
        ).item()
        if self.stage == "test":
            # When testing, always anchor the left context view at frame 0.
            index_context_left = index_context_left * 0
        index_context_right = index_context_left + context_gap

        if self.is_overfitting:
            # When overfitting, always use the same maximally spaced pair.
            index_context_left *= 0
            index_context_right *= 0
            index_context_right += max_gap

        # Pick the target view indices.
        if self.stage == "test":
            # When testing, use every view between the context views as a
            # target.
            index_target = torch.arange(
                index_context_left,
                index_context_right + 1,
                device=device,
            )
        else:
            # When training or validating, pick random target views that
            # keep the minimum distance to the context views.
            index_target = torch.randint(
                index_context_left + self.cfg.min_distance_to_context_views,
                index_context_right + 1 - self.cfg.min_distance_to_context_views,
                size=(self.cfg.num_target_views,),
                device=device,
            )
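        # As an illustrative draw (made-up numbers): with num_views=100,
        # context_gap=45, and index_context_left=10, the right context view
        # lands at index 55; with min_distance_to_context_views=3, targets
        # are sampled uniformly from indices 13 through 52 inclusive.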

        # Apply modulo for circular datasets so indices stay in range.
        if self.cameras_are_circular:
            index_target %= num_views
            index_context_right %= num_views

        return (
            torch.tensor((index_context_left, index_context_right)),
            index_target,
        )

    @property
    def num_context_views(self) -> int:
        # This sampler always picks exactly two context views (left/right).
        return 2

    @property
    def num_target_views(self) -> int:
        return self.cfg.num_target_views
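

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the sampler itself. It assumes
    # the base `ViewSampler` stores `cfg`, `stage`, `is_overfitting`, and
    # `cameras_are_circular` as plain attributes (everything `sample` reads
    # above), and it bypasses the real base-class constructor, defined in
    # view_sampler.py, purely for illustration. All numbers are made up.
    sampler = object.__new__(ViewSamplerBounded)
    sampler.cfg = ViewSamplerBoundedCfg(
        name="bounded",
        num_context_views=2,
        num_target_views=4,
        min_distance_between_context_views=45,
        max_distance_between_context_views=45,
        min_distance_to_context_views=3,
        warm_up_steps=0,
        initial_min_distance_between_context_views=25,
        initial_max_distance_between_context_views=25,
    )
    sampler.stage = "train"
    sampler.is_overfitting = False
    sampler.cameras_are_circular = False

    extrinsics = torch.eye(4).expand(100, 4, 4)  # 100 dummy camera poses
    intrinsics = torch.eye(3).expand(100, 3, 3)
    context, target = sampler.sample("demo_scene", extrinsics, intrinsics)
    print(context)  # shape (2,): left and right context indices
    print(target)  # shape (4,): target view indices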