from dataclasses import dataclass
from typing import Literal
import torch
from jaxtyping import Float, Int64
from torch import Tensor
from .view_sampler import ViewSampler
@dataclass
class ViewSamplerBoundedCfg:
    """Configuration for the bounded view sampler.

    All distances are expressed in frame indices (view positions within the
    sequence), not metric units.
    """

    name: Literal["bounded"]  # discriminator used to select this sampler
    # NOTE(review): the sampler's `num_context_views` property always returns 2
    # regardless of this field — presumably validated/consumed elsewhere; confirm.
    num_context_views: int
    # Number of target views drawn per example (size of the target index tensor).
    num_target_views: int
    # Final bounds on the gap between the left and right context views.
    min_distance_between_context_views: int
    max_distance_between_context_views: int
    # Minimum offset of any target view from each context view.
    min_distance_to_context_views: int
    # Number of steps over which the gap bounds are interpolated from the
    # initial values below to the final values above (0 disables warm-up).
    warm_up_steps: int
    # Gap bounds at global step 0, used during warm-up.
    initial_min_distance_between_context_views: int
    initial_max_distance_between_context_views: int
class ViewSamplerBounded(ViewSampler[ViewSamplerBoundedCfg]):
    """Samples two context views separated by a bounded gap, plus target
    views lying between them."""

    def schedule(self, initial: int, final: int) -> int:
        """Linearly interpolate from ``initial`` to ``final`` over warm-up.

        At ``global_step == 0`` this returns ``initial``; at
        ``global_step >= warm_up_steps`` it returns ``final``. The clamp is
        direction-aware: the original ``min(..., final)`` only clamped
        correctly for increasing schedules — a decreasing schedule
        (``final < initial``) would overshoot below ``final`` after warm-up.

        Args:
            initial: Value at the start of warm-up.
            final: Value after warm-up completes.

        Returns:
            The interpolated integer value, clamped at ``final``.
        """
        fraction = self.global_step / self.cfg.warm_up_steps
        value = initial + int((final - initial) * fraction)
        # Clamp toward `final` in whichever direction the schedule moves.
        return min(value, final) if final >= initial else max(value, final)

    def sample(
        self,
        scene: str,
        extrinsics: Float[Tensor, "view 4 4"],
        intrinsics: Float[Tensor, "view 3 3"],
        device: torch.device = torch.device("cpu"),
        min_view_dist: int | None = None,
        max_view_dist: int | None = None,
        **kwargs,
    ) -> tuple[
        Int64[Tensor, " context_view"],  # indices for context views
        Int64[Tensor, " target_view"],  # indices for target views
    ]:
        """Pick two context view indices and a set of target view indices.

        Args:
            scene: Scene identifier (unused here; kept for interface parity).
            extrinsics: Per-view camera-to-world matrices; only the leading
                view count is used.
            intrinsics: Per-view intrinsics (unused here).
            device: Device on which the returned index tensors are created.
            min_view_dist: Optional per-call override for the minimum context
                gap (useful for mixed-dataset training).
            max_view_dist: Optional per-call override for the maximum context
                gap.
            **kwargs: Ignored; accepted for interface compatibility.

        Returns:
            A ``(context_indices, target_indices)`` pair of int64 tensors on
            ``device``.

        Raises:
            ValueError: If the maximum gap ends up smaller than the minimum
                gap (the example has too few frames).
        """
        num_views, _, _ = extrinsics.shape

        # Compute the context view spacing based on the current global step.
        if self.stage == "test":
            # When testing, always use the full gap.
            max_gap = self.cfg.max_distance_between_context_views
            min_gap = self.cfg.max_distance_between_context_views
        elif self.cfg.warm_up_steps > 0:
            max_gap = self.schedule(
                self.cfg.initial_max_distance_between_context_views,
                self.cfg.max_distance_between_context_views,
            )
            min_gap = self.schedule(
                self.cfg.initial_min_distance_between_context_views,
                self.cfg.min_distance_between_context_views,
            )
        else:
            max_gap = self.cfg.max_distance_between_context_views
            min_gap = self.cfg.min_distance_between_context_views

        # Pick the gap between the context views.
        if not self.cameras_are_circular:
            max_gap = min(num_views - 1, max_gap)
        min_gap = max(2 * self.cfg.min_distance_to_context_views, min_gap)

        # Per-call overrides of min_gap / max_gap — useful when mixing
        # datasets that need different view distances.
        if min_view_dist is not None:
            min_gap = min_view_dist
        if max_view_dist is not None:
            max_gap = max_view_dist

        if max_gap < min_gap:
            raise ValueError("Example does not have enough frames!")
        context_gap = torch.randint(
            min_gap,
            max_gap + 1,
            size=tuple(),
            device=device,
        ).item()

        # Pick the left and right context indices.
        index_context_left = torch.randint(
            num_views if self.cameras_are_circular else num_views - context_gap,
            size=tuple(),
            device=device,
        ).item()
        if self.stage == "test":
            # Deterministic start at frame 0 during testing.
            index_context_left = index_context_left * 0
        index_context_right = index_context_left + context_gap

        if self.is_overfitting:
            # Overfit to the fixed pair (0, max_gap).
            index_context_left *= 0
            index_context_right *= 0
            index_context_right += max_gap

        # Pick the target view indices.
        if self.stage == "test":
            # When testing, pick all views between the context views.
            index_target = torch.arange(
                index_context_left,
                index_context_right + 1,
                device=device,
            )
        else:
            # When training or validating (visualizing), pick at random,
            # keeping min_distance_to_context_views away from each context.
            index_target = torch.randint(
                index_context_left + self.cfg.min_distance_to_context_views,
                index_context_right + 1 - self.cfg.min_distance_to_context_views,
                size=(self.cfg.num_target_views,),
                device=device,
            )

        # Apply modulo for circular datasets.
        if self.cameras_are_circular:
            index_target %= num_views
            index_context_right %= num_views

        # Fix: place the context indices on the requested device (and pin the
        # dtype to int64 per the annotation) so both returned tensors live on
        # the same device as `index_target`; previously this defaulted to CPU.
        return (
            torch.tensor(
                (index_context_left, index_context_right),
                dtype=torch.int64,
                device=device,
            ),
            index_target,
        )

    @property
    def num_context_views(self) -> int:
        # This sampler always yields exactly two context views (left, right).
        return 2

    @property
    def num_target_views(self) -> int:
        return self.cfg.num_target_views