File size: 1,658 Bytes
78d2329
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
from typing import Any

from ...misc.step_tracker import StepTracker
from ..data_types import Stage
from .view_sampler import ViewSampler
from .view_sampler_all import ViewSamplerAll, ViewSamplerAllCfg
from .view_sampler_ids import ViewSamplerIDs, ViewSamplerIDsCfg
from .view_sampler_arbitrary import ViewSamplerArbitrary, ViewSamplerArbitraryCfg
from .view_sampler_bounded import ViewSamplerBounded, ViewSamplerBoundedCfg
from .view_sampler_evaluation import ViewSamplerEvaluation, ViewSamplerEvaluationCfg
from .view_sampler_bounded_v2 import ViewSamplerBoundedV2, ViewSamplerBoundedV2Cfg
from optgs.dataset.view_sampler.view_sampler_dense import ViewSamplerDense, ViewSamplerDenseCfg

# Registry mapping a view-sampler config's ``name`` to the ViewSampler
# subclass implementing it. Values are the classes themselves (not
# instances); ``get_view_sampler`` looks one up and instantiates it.
VIEW_SAMPLERS: dict[str, type[ViewSampler[Any]]] = {
    "all": ViewSamplerAll,
    "ids": ViewSamplerIDs,
    "dense": ViewSamplerDense,  # colmap datasets
    "arbitrary": ViewSamplerArbitrary,
    "bounded": ViewSamplerBounded,
    "evaluation": ViewSamplerEvaluation,  # during evaluation
    "boundedv2": ViewSamplerBoundedV2,  # during training
}

# Union of every supported view-sampler config type; this is the type of the
# ``cfg`` argument accepted by get_view_sampler, which dispatches on
# ``cfg.name`` via VIEW_SAMPLERS.
# NOTE(review): member order is kept as-is in case a config loader resolves
# this union by trying members in order — confirm before reordering.
ViewSamplerCfg = (
        ViewSamplerArbitraryCfg
        | ViewSamplerBoundedCfg
        | ViewSamplerEvaluationCfg
        | ViewSamplerAllCfg
        | ViewSamplerBoundedV2Cfg
        | ViewSamplerDenseCfg
        | ViewSamplerIDsCfg
)


def get_view_sampler(
        cfg: ViewSamplerCfg,
        stage: Stage,
        overfit: bool,
        cameras_are_circular: bool,
        step_tracker: StepTracker | None,
) -> ViewSampler[Any]:
    """Instantiate the view sampler selected by ``cfg.name``.

    The sampler class is looked up in ``VIEW_SAMPLERS`` and constructed with
    all arguments forwarded positionally, in the order received.

    Args:
        cfg: View-sampler config; its ``name`` selects the sampler class.
        stage: Dataset stage, forwarded to the sampler constructor.
        overfit: Forwarded to the sampler constructor.
        cameras_are_circular: Forwarded to the sampler constructor.
        step_tracker: Optional step tracker, forwarded to the constructor.

    Returns:
        A new instance of the selected ViewSampler subclass.

    Raises:
        KeyError: If ``cfg.name`` is not a key of ``VIEW_SAMPLERS``; the
            message lists the registered sampler names.
    """
    print("Using view sampler:", cfg.name)
    try:
        sampler_cls = VIEW_SAMPLERS[cfg.name]
    except KeyError:
        # Keep KeyError for existing callers, but make the failure actionable
        # instead of surfacing only the bad key.
        available = ", ".join(sorted(VIEW_SAMPLERS))
        raise KeyError(
            f"Unknown view sampler {cfg.name!r}; expected one of: {available}"
        ) from None
    return sampler_cls(
        cfg,
        stage,
        overfit,
        cameras_are_circular,
        step_tracker,
    )