from dataclasses import dataclass
from typing import Literal

import torch
from jaxtyping import Float, Int64
from torch import Tensor

from .view_sampler import ViewSampler


@dataclass
class ViewSamplerBoundedCfg:
    """Configuration for ViewSamplerBounded.

    All distances below are measured in frame indices, not metric units.
    """

    name: Literal["bounded"]
    num_context_views: int
    num_target_views: int
    # Bounds on the index gap between the two context views.
    min_distance_between_context_views: int
    max_distance_between_context_views: int
    # Minimum index distance a sampled target view keeps from each context view.
    min_distance_to_context_views: int
    # Number of steps over which the context gap is linearly warmed up
    # (0 disables warm-up and the final bounds are used immediately).
    warm_up_steps: int
    initial_min_distance_between_context_views: int
    initial_max_distance_between_context_views: int


class ViewSamplerBounded(ViewSampler[ViewSamplerBoundedCfg]):
    """Samples two context views a bounded index distance apart, plus target
    views whose indices lie strictly between (or, at test time, spanning) the
    two context views.
    """

    def schedule(self, initial: int, final: int) -> int:
        """Linearly interpolate from `initial` toward `final` over the warm-up
        period, clamped so the result never overshoots `final`.

        Assumes `self.cfg.warm_up_steps > 0`; callers guard this.
        """
        fraction = self.global_step / self.cfg.warm_up_steps
        return min(initial + int((final - initial) * fraction), final)

    def sample(
        self,
        scene: str,
        extrinsics: Float[Tensor, "view 4 4"],
        intrinsics: Float[Tensor, "view 3 3"],
        device: torch.device = torch.device("cpu"),
        min_view_dist: int | None = None,
        max_view_dist: int | None = None,
        **kwargs,
    ) -> tuple[
        Int64[Tensor, " context_view"],  # indices for context views
        Int64[Tensor, " target_view"],  # indices for target views
    ]:
        """Pick context- and target-view indices for one example.

        Args:
            scene: Scene identifier (unused here; part of the sampler API).
            extrinsics: Per-view camera-to-world matrices; only the view count
                is read from its shape.
            intrinsics: Per-view intrinsics (unused here; presumably required
                by the `ViewSampler` interface — confirm against the base class).
            device: Device on which the sampling tensors are created.
            min_view_dist: Optional per-dataset override for the minimum
                context gap (useful for mixed-dataset training).
            max_view_dist: Optional per-dataset override for the maximum
                context gap.

        Returns:
            A pair `(context_indices, target_indices)` of int64 tensors.

        Raises:
            ValueError: If the example does not have enough frames to satisfy
                the requested gap.
        """
        num_views, _, _ = extrinsics.shape

        # Compute the context view spacing based on the current global step.
        if self.stage == "test":
            # When testing, always use the full gap.
            max_gap = self.cfg.max_distance_between_context_views
            min_gap = self.cfg.max_distance_between_context_views
        elif self.cfg.warm_up_steps > 0:
            max_gap = self.schedule(
                self.cfg.initial_max_distance_between_context_views,
                self.cfg.max_distance_between_context_views,
            )
            min_gap = self.schedule(
                self.cfg.initial_min_distance_between_context_views,
                self.cfg.min_distance_between_context_views,
            )
        else:
            max_gap = self.cfg.max_distance_between_context_views
            min_gap = self.cfg.min_distance_between_context_views

        # Pick the gap between the context views.
        if not self.cameras_are_circular:
            max_gap = min(num_views - 1, max_gap)
        min_gap = max(2 * self.cfg.min_distance_to_context_views, min_gap)

        # Overwrite min_gap and max_gap; useful for mixed-dataset training
        # where each dataset wants a different view distance.
        # NOTE(review): an override bypasses the
        # 2 * min_distance_to_context_views floor above, so the target
        # torch.randint below can receive an empty range if min_view_dist is
        # too small — confirm callers pass compatible values.
        if min_view_dist is not None:
            min_gap = min_view_dist
        if max_view_dist is not None:
            max_gap = max_view_dist

        # BUGFIX: re-apply the frame-count bound after the overrides.
        # Previously a max_view_dist exceeding num_views - 1 could make
        # context_gap >= num_views, so torch.randint below would be called
        # with a non-positive upper bound and crash with an opaque
        # RuntimeError instead of the clear ValueError raised here.
        if not self.cameras_are_circular:
            max_gap = min(num_views - 1, max_gap)

        if max_gap < min_gap:
            raise ValueError("Example does not have enough frames!")
        context_gap = torch.randint(
            min_gap,
            max_gap + 1,
            size=tuple(),
            device=device,
        ).item()

        # Pick the left and right context indices.
        index_context_left = torch.randint(
            num_views if self.cameras_are_circular else num_views - context_gap,
            size=tuple(),
            device=device,
        ).item()
        if self.stage == "test":
            index_context_left = index_context_left * 0
        index_context_right = index_context_left + context_gap

        if self.is_overfitting:
            index_context_left *= 0
            index_context_right *= 0
            index_context_right += max_gap

        # Pick the target view indices.
        if self.stage == "test":
            # When testing, pick all.
            index_target = torch.arange(
                index_context_left,
                index_context_right + 1,
                device=device,
            )
        else:
            # When training or validating (visualizing), pick at random.
            index_target = torch.randint(
                index_context_left + self.cfg.min_distance_to_context_views,
                index_context_right + 1 - self.cfg.min_distance_to_context_views,
                size=(self.cfg.num_target_views,),
                device=device,
            )

        # Apply modulo for circular datasets. index_context_left is already
        # < num_views, so only the right index and targets can wrap.
        if self.cameras_are_circular:
            index_target %= num_views
            index_context_right %= num_views

        return (
            torch.tensor((index_context_left, index_context_right)),
            index_target,
        )

    @property
    def num_context_views(self) -> int:
        # This sampler always produces exactly two context views.
        return 2

    @property
    def num_target_views(self) -> int:
        return self.cfg.num_target_views