# Scraped from a Hugging Face Space file view (Space status: Sleeping).
# File size: 3,969 bytes.
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Generic, Literal, TypeVar

import torch
from jaxtyping import Float, Int64
from torch import Tensor
from typeguard import value  # NOTE(review): typeguard does not appear to export `value`; verify this import resolves.

from ...misc.step_tracker import StepTracker
from ..data_types import Stage
# Type variable for the configuration type stored in `ViewSampler.cfg`.
T = TypeVar("T")
@dataclass
class ViewSamplerCfg:
    """Configuration dataclass for a view sampler.

    NOTE(review): presumably passed as ``cfg`` to a ``ViewSampler`` subclass
    and extended by sampler-specific configs — confirm against callers.
    """

    # Sampler identifier; this base config only admits the literal "base".
    name: Literal["base"]
    # Requested number of context views (exact use is up to the concrete sampler).
    num_context_views: int
    # Requested number of target views (exact use is up to the concrete sampler).
    num_target_views: int
class ViewSampler(ABC, Generic[T]):
    """Abstract base class for choosing context and target view indices.

    Concrete subclasses implement ``_sample_impl`` together with the
    ``num_context_views`` and ``num_target_views`` properties; callers use
    ``sample``. The class also exposes write-once ``context_indices`` and
    ``target_indices`` attributes.
    """

    cfg: T
    stage: Stage
    is_overfitting: bool
    cameras_are_circular: bool
    step_tracker: StepTracker | None

    def __init__(
        self,
        cfg: T,
        stage: Stage,
        is_overfitting: bool,
        cameras_are_circular: bool,
        step_tracker: StepTracker | None,
    ) -> None:
        """Store the sampler configuration and runtime context.

        Args:
            cfg: Sampler configuration object.
            stage: Stage the sampler is used in (project ``Stage`` type).
            is_overfitting: Flag stored for subclasses to consult.
            cameras_are_circular: Flag stored for subclasses to consult.
            step_tracker: Optional tracker read by ``global_step``; when
                ``None``, ``global_step`` returns 0.
        """
        self.cfg = cfg
        self.stage = stage
        self.is_overfitting = is_overfitting
        self.cameras_are_circular = cameras_are_circular
        self.step_tracker = step_tracker
        # Backing storage for the write-once index properties; None until set.
        self._all_context_indices = None
        self._all_target_indices = None

    @property
    def all_context_indices(self) -> Int64[Tensor, " context_view"]:
        """Return the stored context indices (``None`` until they are set)."""
        return self._all_context_indices

    @property
    def context_indices(self) -> Int64[Tensor, " context_view"]:
        """Return the stored context indices (``None`` until they are set)."""
        return self._all_context_indices

    @context_indices.setter
    def context_indices(self, indices: Int64[Tensor, " context_view"]):
        """Store the context indices; they can only be set once.

        Raises:
            RuntimeError: If context indices have already been set.
        """
        if self._all_context_indices is not None:
            raise RuntimeError("Context indices have already been set.")
        self._all_context_indices = indices

    @property
    def target_indices(self) -> Int64[Tensor, " target_view"]:
        """Return the stored target indices (``None`` until they are set)."""
        return self._all_target_indices

    @target_indices.setter
    def target_indices(self, indices: Int64[Tensor, " target_view"]):
        """Store the target indices; they can only be set once.

        Raises:
            RuntimeError: If target indices have already been set.
        """
        if self._all_target_indices is not None:
            raise RuntimeError("Target indices have already been set.")
        self._all_target_indices = indices

    def sample_subset(self, extrinsics, intrinsics, device):
        """Placeholder hook with no base-class implementation; returns ``None``.

        NOTE(review): this is not abstract, so subclasses are not forced to
        override it — confirm whether it is still used.
        """
        return None

    @abstractmethod
    def _sample_impl(
        self,
        scene: str,
        extrinsics: Float[Tensor, "view 4 4"],
        intrinsics: Float[Tensor, "view 3 3"],
        device: torch.device = torch.device("cpu"),
        **kwargs,
    ) -> tuple[
        Int64[Tensor, " context_view"],  # indices for context views
        Int64[Tensor, " target_view"],  # indices for target views
    ]:
        """Select context and target view indices for one scene.

        Subclasses must override this; ``sample`` forwards all of its
        arguments here unchanged.

        Args:
            scene: Scene identifier.
            extrinsics: One 4x4 extrinsics matrix per view.
            intrinsics: One 3x3 intrinsics matrix per view.
            device: Torch device argument (presumably where the returned
                index tensors should live — confirm in subclasses).
            **kwargs: Additional sampler-specific options.

        Returns:
            A ``(context_indices, target_indices)`` pair.
        """

    def sample(
        self,
        scene: str,
        extrinsics: Float[Tensor, "view 4 4"],
        intrinsics: Float[Tensor, "view 3 3"],
        device: torch.device = torch.device("cpu"),
        **kwargs,
    ) -> tuple[
        Int64[Tensor, " context_view"],  # indices for context views
        Int64[Tensor, " target_view"],  # indices for target views
    ]:
        """Return ``(context_indices, target_indices)`` for ``scene``.

        Delegates to ``_sample_impl`` with the same arguments. Unpacking the
        result enforces that the implementation returns exactly two items.
        """
        context_indices, target_indices = self._sample_impl(
            scene=scene,
            extrinsics=extrinsics,
            intrinsics=intrinsics,
            device=device,
            **kwargs,
        )
        # The result is not stored through the ``context_indices`` /
        # ``target_indices`` setters: those are write-once, so storing here
        # would make any second call to ``sample`` raise RuntimeError.
        return context_indices, target_indices

    @property
    @abstractmethod
    def num_target_views(self) -> int:
        """Number of target views this sampler produces (subclass-defined)."""

    @property
    @abstractmethod
    def num_context_views(self) -> int:
        """Number of context views this sampler produces (subclass-defined)."""

    @property
    def global_step(self) -> int:
        """Current step from ``step_tracker``, or 0 when no tracker is set."""
        if self.step_tracker is None:
            return 0
        return self.step_tracker.get_step()

    def new_instance(self) -> "ViewSampler":
        """Create a new instance of the same ViewSampler class with the same configuration.

        The new sampler receives the same ``cfg`` and ``step_tracker`` objects
        (shared references, not copies) and the same flags and stage.
        """
        # type(self) resolves to the concrete subclass, so the copy has the
        # same sampling behavior as this instance.
        return type(self)(
            cfg=self.cfg,
            stage=self.stage,
            is_overfitting=self.is_overfitting,
            cameras_are_circular=self.cameras_are_circular,
            step_tracker=self.step_tracker,
        )
|