File size: 3,969 Bytes
78d2329
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Generic, TypeVar, Literal

import torch
from jaxtyping import Float, Int64
from torch import Tensor
from typeguard import value

from ...misc.step_tracker import StepTracker
from ..data_types import Stage

# Type parameter for the configuration object a ViewSampler holds in ``cfg``.
T = TypeVar("T")


@dataclass
class ViewSamplerCfg:
    """Base configuration fields for a view sampler."""

    # Discriminator tag; only the literal "base" is allowed here.
    # NOTE(review): presumably used to select the sampler implementation — confirm.
    name: Literal["base"]
    # Presumably the number of context views a sampler should return — confirm.
    num_context_views: int
    # Presumably the number of target views a sampler should return — confirm.
    num_target_views: int


class ViewSampler(ABC, Generic[T]):
    """Abstract base for strategies that choose context and target views of a scene.

    Subclasses implement ``_sample_impl`` together with the
    ``num_context_views`` and ``num_target_views`` properties; callers use
    ``sample``.

    The instance can additionally hold one set of context/target indices via
    the write-once ``context_indices`` / ``target_indices`` setters. ``sample``
    does not populate them, so they stay ``None`` until assigned explicitly.
    """

    cfg: T
    stage: Stage
    is_overfitting: bool
    cameras_are_circular: bool
    step_tracker: StepTracker | None
    # Write-once storage behind the index properties; None until a setter runs.
    _all_context_indices: Int64[Tensor, " context_view"] | None
    _all_target_indices: Int64[Tensor, " target_view"] | None

    def __init__(
        self,
        cfg: T,
        stage: Stage,
        is_overfitting: bool,
        cameras_are_circular: bool,
        step_tracker: StepTracker | None,
    ) -> None:
        """Store the sampler settings; no indices are recorded yet.

        Args:
            cfg: Sampler configuration object.
            stage: Pipeline stage this sampler is used for.
            is_overfitting: Overfitting flag, stored for use by subclasses.
            cameras_are_circular: Camera-layout flag, stored for subclasses.
            step_tracker: Optional tracker backing ``global_step``.
        """
        self.cfg = cfg
        self.stage = stage
        self.is_overfitting = is_overfitting
        self.cameras_are_circular = cameras_are_circular
        self.step_tracker = step_tracker

        self._all_context_indices = None
        self._all_target_indices = None

    @property
    def all_context_indices(self) -> Int64[Tensor, " context_view"] | None:
        """Return the stored context indices, or ``None`` if none were set."""
        return self._all_context_indices

    @property
    def context_indices(self) -> Int64[Tensor, " context_view"] | None:
        """Return the stored context indices, or ``None`` if none were set."""
        return self._all_context_indices

    @context_indices.setter
    def context_indices(self, indices: Int64[Tensor, " context_view"]) -> None:
        """Store context indices; may only be done once per instance.

        Raises:
            RuntimeError: If context indices have already been set.
        """
        if self._all_context_indices is not None:
            raise RuntimeError("Context indices have already been set.")
        self._all_context_indices = indices

    @property
    def target_indices(self) -> Int64[Tensor, " target_view"] | None:
        """Return the stored target indices, or ``None`` if none were set."""
        return self._all_target_indices

    @target_indices.setter
    def target_indices(self, indices: Int64[Tensor, " target_view"]) -> None:
        """Store target indices; may only be done once per instance.

        Raises:
            RuntimeError: If target indices have already been set.
        """
        if self._all_target_indices is not None:
            raise RuntimeError("Target indices have already been set.")
        self._all_target_indices = indices

    def sample_subset(self, extrinsics, intrinsics, device):
        """Placeholder hook: currently does nothing and returns ``None``.

        NOTE(review): no implementation is visible here — confirm whether
        subclasses are expected to override this or whether it can be removed.
        """

    @abstractmethod
    def _sample_impl(
        self,
        scene: str,
        extrinsics: Float[Tensor, "view 4 4"],
        intrinsics: Float[Tensor, "view 3 3"],
        device: torch.device = torch.device("cpu"),
        **kwargs,
    ) -> tuple[
        Int64[Tensor, " context_view"],  # indices for context views
        Int64[Tensor, " target_view"],  # indices for target views
    ]:
        """Choose context and target view indices for one scene.

        Subclasses implement this; external callers should use ``sample``.

        Args:
            scene: Scene identifier.
            extrinsics: Per-view 4x4 extrinsic matrices.
            intrinsics: Per-view 3x3 intrinsic matrices.
            device: Presumably the device for the returned index tensors.
            **kwargs: Extra sampler-specific options forwarded by ``sample``.

        Returns:
            ``(context_indices, target_indices)``.
        """

    def sample(
        self,
        scene: str,
        extrinsics: Float[Tensor, "view 4 4"],
        intrinsics: Float[Tensor, "view 3 3"],
        device: torch.device = torch.device("cpu"),
        **kwargs,
    ) -> tuple[
        Int64[Tensor, " context_view"],  # indices for context views
        Int64[Tensor, " target_view"],  # indices for target views
    ]:
        """Choose context and target view indices for one scene.

        Delegates to ``_sample_impl`` with the same arguments and returns its
        result unchanged.
        """
        context_indices, target_indices = self._sample_impl(
            scene=scene,
            extrinsics=extrinsics,
            intrinsics=intrinsics,
            device=device,
            **kwargs,
        )
        # NOTE: the result is deliberately not stored on the instance. The
        # index setters are write-once, so storing here would make every
        # second call to sample() raise RuntimeError.
        return context_indices, target_indices

    @property
    @abstractmethod
    def num_target_views(self) -> int:
        """Number of target views this sampler produces (subclass-defined)."""

    @property
    @abstractmethod
    def num_context_views(self) -> int:
        """Number of context views this sampler produces (subclass-defined)."""

    @property
    def global_step(self) -> int:
        """Current step from ``step_tracker``, or 0 when no tracker is attached."""
        return 0 if self.step_tracker is None else self.step_tracker.get_step()

    def new_instance(self) -> "ViewSampler":
        """Create a fresh sampler of the same concrete class and settings.

        The new instance shares ``cfg`` and ``step_tracker`` with this one
        (they are passed by reference) but starts with no stored indices.
        """
        # type(self) preserves the concrete subclass directly.
        return type(self)(
            cfg=self.cfg,
            stage=self.stage,
            is_overfitting=self.is_overfitting,
            cameras_are_circular=self.cameras_are_circular,
            step_tracker=self.step_tracker,
        )