| | from dataclasses import dataclass |
| | from typing import Literal |
| |
|
| | import torch |
| | from jaxtyping import Float, Int64 |
| | from torch import Tensor |
| |
|
| | from .three_view_hack import add_third_context_index |
| | from .view_sampler import ViewSampler |
| |
|
| |
|
| | @dataclass |
| | class ViewSamplerArbitraryCfg: |
| | name: Literal["arbitrary"] |
| | num_context_views: int |
| | num_target_views: int |
| | context_views: list[int] | None |
| | target_views: list[int] | None |
| |
|
| |
|
class ViewSamplerArbitrary(ViewSampler[ViewSamplerArbitraryCfg]):
    """View sampler that picks context and target indices uniformly at random.

    Either index set can be pinned to fixed values through the config's
    ``context_views`` / ``target_views`` lists.
    """

    def sample(
        self,
        scene: str,
        extrinsics: Float[Tensor, "view 4 4"],
        intrinsics: Float[Tensor, "view 3 3"],
        device: torch.device = torch.device("cpu"),
    ) -> tuple[
        Int64[Tensor, " context_view"],
        Int64[Tensor, " target_view"],
        Float[Tensor, " overlap"],
    ]:
        """Arbitrarily sample context and target views.

        Args:
            scene: Scene identifier. Not used by this sampler.
            extrinsics: Per-view extrinsics; only the leading (view)
                dimension is read, to learn how many views exist.
            intrinsics: Per-view intrinsics. Not used by this sampler.
            device: Device on which the returned tensors are created.

        Returns:
            A tuple ``(index_context, index_target, overlap)``. The index
            tensors are int64. ``overlap`` is a fixed one-element float32
            tensor holding 0.5; no overlap is actually computed here.

        Raises:
            AssertionError: If a fixed index list in the config does not
                match the configured number of views (apart from the
                two-to-three context view expansion handled below).
        """
        num_views, _, _ = extrinsics.shape

        # Random indices are always drawn first, even when they are
        # overridden below, so the global RNG advances identically in
        # both cases. torch.randint samples independently, so duplicate
        # indices are possible.
        index_context = torch.randint(
            0,
            num_views,
            size=(self.cfg.num_context_views,),
            device=device,
        )

        # Allow the context views to be fixed.
        fixed_context = self.cfg.context_views
        if fixed_context is not None:
            index_context = torch.tensor(
                fixed_context, dtype=torch.int64, device=device
            )

            if self.cfg.num_context_views == 3 and len(fixed_context) == 2:
                # Two fixed views were supplied but three are requested:
                # delegate to the three-view helper to supply the third.
                index_context = add_third_context_index(index_context)
            else:
                assert len(fixed_context) == self.cfg.num_context_views, (
                    f"context_views has {len(fixed_context)} entries but "
                    f"num_context_views is {self.cfg.num_context_views}"
                )

        index_target = torch.randint(
            0,
            num_views,
            size=(self.cfg.num_target_views,),
            device=device,
        )

        # Allow the target views to be fixed.
        fixed_target = self.cfg.target_views
        if fixed_target is not None:
            assert len(fixed_target) == self.cfg.num_target_views, (
                f"target_views has {len(fixed_target)} entries but "
                f"num_target_views is {self.cfg.num_target_views}"
            )
            index_target = torch.tensor(
                fixed_target, dtype=torch.int64, device=device
            )

        overlap = torch.tensor([0.5], dtype=torch.float32, device=device)

        return index_context, index_target, overlap

    @property
    def num_context_views(self) -> int:
        """Number of context views, as set in the config."""
        return self.cfg.num_context_views

    @property
    def num_target_views(self) -> int:
        """Number of target views, as set in the config."""
        return self.cfg.num_target_views
| |
|