| | from dataclasses import dataclass |
| | from typing import Literal |
| |
|
| | import torch |
| | from jaxtyping import Float, Int64 |
| | from torch import Tensor |
| |
|
| | from .view_sampler import ViewSampler |
| |
|
| |
|
@dataclass
class ViewSamplerBoundedCfg:
    """Configuration for the bounded view sampler (see ViewSamplerBounded)."""

    name: Literal["bounded"]
    # Number of context views to sample per example.
    num_context_views: int
    # Number of target views to sample per example.
    num_target_views: int
    # Hard bounds on the frame-index gap between the two outer context views.
    min_distance_between_context_views: int
    max_distance_between_context_views: int
    # Minimum frame-index distance between any target view and a context view.
    min_distance_to_context_views: int
    # Steps over which scheduled quantities ramp from their initial values.
    warm_up_steps: int
    # Gap bounds used at the start of the warm-up schedule.
    initial_min_distance_between_context_views: int
    initial_max_distance_between_context_views: int
    max_img_per_gpu: int
    # Per-context-view-count multipliers used by num_ctxt_gap_mapping to
    # derive the [min_gap, max_gap] range for each number of context views.
    min_gap_multiplier: int
    max_gap_multiplier: int
| |
|
class ViewSamplerBounded(ViewSampler[ViewSamplerBoundedCfg]):
    """Sample context/target view indices with a bounded gap between the two
    outer context views.

    The gap between the left and right context views is drawn uniformly from a
    per-context-view-count range (see ``num_ctxt_gap_mapping``); target views
    are drawn from frames between the two outer context views.
    """

    def schedule(self, initial: int, final: int) -> int:
        """Linearly interpolate from `initial` to `final` over the warm-up
        period, clamped at `final` once warm-up is complete.

        NOTE(review): raises ZeroDivisionError if cfg.warm_up_steps == 0 —
        presumably configs always use a positive value; confirm.
        """
        fraction = self.global_step / self.cfg.warm_up_steps
        return min(initial + int((final - initial) * fraction), final)

    def sample(
        self,
        scene: str,
        num_context_views: int,
        extrinsics: Float[Tensor, "view 4 4"],
        intrinsics: Float[Tensor, "view 3 3"],
        device: torch.device = torch.device("cpu"),
    ) -> tuple[
        Int64[Tensor, " context_view"],
        Int64[Tensor, " target_view"],
        Float[Tensor, " overlap"],
    ]:
        """Sample context and target view indices for one scene.

        Args:
            scene: Scene identifier (unused by this sampler).
            num_context_views: Number of context views to return (>= 2).
            extrinsics: Per-view camera-to-world matrices; only the view
                count is used here.
            intrinsics: Per-view intrinsics (unused by this sampler).
            device: Device on which the returned index tensors are created.

        Returns:
            (context indices, target indices, overlap placeholder).

        Raises:
            ValueError: if the scene has too few frames for the configured
                gap range, or too few frames between the outer context views
                to fit the requested extra context views.
        """
        num_views, _, _ = extrinsics.shape

        # Gap bounds for the requested number of context views.
        # NOTE(review): a former test-stage override of these bounds was dead
        # code (immediately clobbered by this lookup) and has been removed.
        min_gap, max_gap = self.num_ctxt_gap_mapping[num_context_views]
        max_gap = min(max_gap, num_views - 1)
        if not self.cameras_are_circular:
            # Leave room for the min target-to-context distance on each side.
            min_gap = max(2 * self.cfg.min_distance_to_context_views, min_gap)
        if max_gap < min_gap:
            raise ValueError("Example does not have enough frames!")

        # Draw the gap between the two outer context views.
        context_gap = torch.randint(
            min_gap,
            max_gap + 1,
            size=tuple(),
            device=device,
        ).item()

        # Pick the left context view; circular cameras may start anywhere,
        # otherwise the right view must still fit in range.
        index_context_left = torch.randint(
            num_views if self.cameras_are_circular else num_views - context_gap,
            size=tuple(),
            device=device,
        ).item()
        if self.stage == "test":
            # Deterministic evaluation: always start at frame 0.
            index_context_left = 0
        index_context_right = index_context_left + context_gap

        if self.is_overfitting:
            # Overfit to a fixed pair spanning the maximum gap.
            index_context_left = 0
            index_context_right = max_gap

        if self.stage == "test":
            # Evaluate on every frame between (and including) the two
            # outer context views.
            index_target = torch.arange(
                index_context_left,
                index_context_right + 1,
                device=device,
            )
        else:
            # Sample target views strictly inside the context interval,
            # keeping the configured minimum distance from each side.
            index_target = torch.randint(
                index_context_left + self.cfg.min_distance_to_context_views,
                index_context_right + 1 - self.cfg.min_distance_to_context_views,
                size=(self.cfg.num_target_views,),
                device=device,
            )

        # Wrap indices for circular camera paths.
        if self.cameras_are_circular:
            index_target %= num_views
            index_context_right %= num_views

        # Extra (inner) context views, strictly between the two outer ones.
        # torch.randperm guarantees distinct offsets and cannot loop forever,
        # unlike the previous rejection-sampling loop, which also crashed
        # when circular wrap-around moved the right index below the left.
        if num_context_views > 2:
            num_extra_views = num_context_views - 2
            num_slots = context_gap - 1
            if num_extra_views > num_slots:
                raise ValueError(
                    "Not enough frames between the context views for "
                    f"{num_extra_views} extra context views!"
                )
            offsets = torch.randperm(num_slots)[:num_extra_views] + 1
            extra_views_t = index_context_left + offsets
            if self.cameras_are_circular:
                extra_views_t %= num_views
            extra_views = extra_views_t.tolist()
        else:
            extra_views = []

        # Overlap is not computed by this sampler; a constant placeholder
        # keeps the return signature consistent with other samplers.
        overlap = torch.tensor([0.5], dtype=torch.float32, device=device)

        return (
            # device=device so all returned tensors live on the same device
            # (previously the context indices were always created on CPU).
            torch.tensor(
                (index_context_left, *extra_views, index_context_right),
                device=device,
            ),
            index_target,
            overlap,
        )

    @property
    def num_context_views(self) -> int:
        """Configured number of context views."""
        return self.cfg.num_context_views

    @property
    def num_target_views(self) -> int:
        """Configured number of target views."""
        return self.cfg.num_target_views

    @property
    def num_ctxt_gap_mapping(self) -> dict:
        """Map each possible context-view count (2..num_context_views) to a
        [min_gap, max_gap] pair derived from the configured multipliers and
        clamped by the configured absolute distance bounds."""
        mapping = {}
        for num_ctxt in range(2, self.cfg.num_context_views + 1):
            mapping[num_ctxt] = [
                min(
                    num_ctxt * self.cfg.min_gap_multiplier,
                    self.cfg.min_distance_between_context_views,
                ),
                min(
                    max(num_ctxt * self.cfg.max_gap_multiplier, num_ctxt**2),
                    self.cfg.max_distance_between_context_views,
                ),
            ]
        return mapping
| |
|