Spaces:
Runtime error
Runtime error
| import json | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| from typing import Literal | |
| import torch | |
| from dacite import Config, from_dict | |
| from jaxtyping import Float, Int64 | |
| from torch import Tensor | |
| from ...evaluation.evaluation_index_generator import IndexEntry | |
| from ...global_cfg import get_cfg | |
| from ...misc.step_tracker import StepTracker | |
| from ..types import Stage | |
| from .three_view_hack import add_third_context_index | |
| from .view_sampler import ViewSampler | |
@dataclass
class ViewSamplerEvaluationCfg:
    """Configuration for the evaluation view sampler.

    Declared as a dataclass so that it gets a generated ``__init__`` and can
    be built field-by-field (a bare annotated class has no constructor
    arguments and would never carry these values).
    """

    # Discriminator value identifying this sampler configuration.
    name: Literal["evaluation"]
    # Location of the JSON file that is read when the sampler is constructed.
    index_path: Path
    # Number of context views requested by the configuration.
    num_context_views: int
class ViewSamplerEvaluation(ViewSampler[ViewSamplerEvaluationCfg]):
    """View sampler that returns pre-computed context/target indices per scene.

    The indices are read once, at construction time, from the JSON file at
    ``cfg.index_path``. That file maps scene names to either ``null`` or a
    mapping that is parsed into an ``IndexEntry``.
    """

    # Scene name -> parsed entry, or None when the JSON value was null.
    index: dict[str, IndexEntry | None]

    def __init__(
        self,
        cfg: ViewSamplerEvaluationCfg,
        stage: Stage,
        is_overfitting: bool,
        cameras_are_circular: bool,
        step_tracker: StepTracker | None,
    ) -> None:
        """Initialise the base sampler and load the evaluation index.

        All arguments are forwarded unchanged to ``ViewSampler.__init__``;
        ``cfg.index_path`` is additionally opened and parsed here.

        Raises:
            OSError: If the index file cannot be opened.
            json.JSONDecodeError: If the index file is not valid JSON.
        """
        super().__init__(cfg, stage, is_overfitting, cameras_are_circular, step_tracker)

        # JSON only has lists; ``cast=[tuple]`` lets dacite cast values for
        # tuple-typed fields while building each IndexEntry.
        dacite_config = Config(cast=[tuple])

        # JSON text is UTF-8 by specification, so do not depend on the
        # platform's locale encoding when reading it.
        with cfg.index_path.open("r", encoding="utf-8") as f:
            self.index = {
                k: None if v is None else from_dict(IndexEntry, v, dacite_config)
                for k, v in json.load(f).items()
            }

    def sample(
        self,
        scene: str,
        extrinsics: Float[Tensor, "view 4 4"],
        intrinsics: Float[Tensor, "view 3 3"],
        device: torch.device = torch.device("cpu"),
        **kwargs,
    ) -> tuple[
        Int64[Tensor, " context_view"],  # indices for context views
        Int64[Tensor, " target_view"],  # indices for target views
        Float[Tensor, "1"],  # constant zeros tensor of shape (1,)
    ]:
        """Return the stored context and target view indices for ``scene``.

        ``extrinsics``, ``intrinsics`` and ``kwargs`` are accepted but not
        read; the result depends only on the loaded index.

        Args:
            scene: Key to look up in the loaded index.
            extrinsics: Unused by this sampler.
            intrinsics: Unused by this sampler.
            device: Device on which the two index tensors are created.

        Returns:
            A 3-tuple ``(context_indices, target_indices, zeros)``. The first
            two are int64 tensors built from ``entry.context`` and
            ``entry.target``. The third is ``torch.zeros(1)``; it is created
            on the default device, not on ``device``.
            NOTE(review): the third element looks like a placeholder for a
            value other samplers compute -- confirm against the callers.

        Raises:
            ValueError: If ``scene`` is missing from the index or its entry
                is ``None``.
        """
        # ``dict.get`` folds "scene absent" and "scene present but null" into
        # the same None check below.
        entry = self.index.get(scene)
        if entry is None:
            raise ValueError(f"No indices available for scene {scene}.")

        context_indices = torch.tensor(entry.context, dtype=torch.int64, device=device)
        target_indices = torch.tensor(entry.target, dtype=torch.int64, device=device)
        return context_indices, target_indices, torch.zeros(1)

    # NOTE(review): both accessors below are plain methods here. If the
    # ViewSampler base class declares them as properties, an ``@property``
    # decorator may be missing -- confirm against the base class and callers.
    def num_context_views(self) -> int:
        """Return 0; this sampler does not report a context-view count."""
        return 0

    def num_target_views(self) -> int:
        """Return 0; this sampler does not report a target-view count."""
        return 0