File size: 4,885 Bytes
78d2329
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
from dataclasses import dataclass
from typing import Literal

import numpy as np
import torch
from jaxtyping import Float, Int64
from torch import Tensor

from .view_sampler import ViewSampler, ViewSamplerCfg


@dataclass
class ViewSamplerDenseCfg(ViewSamplerCfg):
    """Configuration for the stride-based ("dense") view sampler.

    Exactly one of ``target_every`` / ``context_every`` must be positive; the
    sampler uses that stride to pick one split and assigns the remaining views
    to the other split.
    """

    name: Literal["dense"]
    # Stride for selecting target views; a non-positive value disables it.
    target_every: int
    # Stride for selecting context views; a non-positive value disables it.
    context_every: int
    # Strategy used to subsample each split down to the configured view count.
    sample_views_strategy: Literal["random", "neighbors"] = "random"

    def __post_init__(self):
        uses_target_stride = self.target_every > 0
        uses_context_stride = self.context_every > 0
        # Logical XOR: one stride must be active, never both and never neither.
        assert uses_target_stride ^ uses_context_stride, \
            "Either target_every or context_every must be set, but not both."


class ViewSamplerDense(ViewSampler[ViewSamplerDenseCfg]):
    """Split a scene's views into context/target sets by a fixed stride, then subsample each set."""

    def _sample_impl(
            self,
            scene: str,
            extrinsics: Float[Tensor, "view 4 4"],
            intrinsics: Float[Tensor, "view 3 3"],
            device: torch.device = torch.device("cpu"),
            **kwargs,
    ) -> tuple[
        Int64[Tensor, " context_view"],  # indices for context views
        Int64[Tensor, " target_view"],  # indices for target views
    ]:
        """Sample context and target view indices.

        Exactly one of ``cfg.target_every`` / ``cfg.context_every`` is expected to
        be positive. Every k-th view (starting at view 0) goes to that split and all
        remaining views go to the other split. Each split is then reduced to
        ``cfg.num_context_views`` / ``cfg.num_target_views`` entries using
        ``cfg.sample_views_strategy``; a requested count of ``-1`` (or one at least
        as large as the split) keeps the whole split.

        Args:
            scene: Scene identifier (not used by this sampler).
            extrinsics: Per-view camera matrices; the leading dimension gives the
                number of views.
            intrinsics: Per-view intrinsics (not used by this sampler).
            device: Device on which the index tensors are created.

        Returns:
            ``(context_indices, target_indices)`` as int64 tensors of view indices.
            Either may be empty, e.g. ``target_every == 1`` leaves no context views.

        Raises:
            ValueError: If neither stride is positive, or the strategy is unknown.
            NotImplementedError: If the "neighbors" strategy has to subsample.
        """
        num_views, _, _ = extrinsics.shape
        all_views = torch.arange(num_views, device=device)

        if self.cfg.target_every > 0:
            target_views = all_views[::self.cfg.target_every]
            context_views = self._complement_views(all_views, target_views)
        elif self.cfg.context_every > 0:
            context_views = all_views[::self.cfg.context_every]
            target_views = self._complement_views(all_views, context_views)
        else:
            raise ValueError("Either target_every or context_every must be set to a positive integer.")

        strategy = self.cfg.sample_views_strategy
        index_context = self._subsample_views(
            extrinsics, context_views, self.cfg.num_context_views, strategy)
        # The first context view anchors target selection for proximity-based
        # strategies; guard against an empty context split (e.g. target_every == 1).
        center_view = index_context[0].item() if len(index_context) > 0 else None
        index_target = self._subsample_views(
            extrinsics, target_views, self.cfg.num_target_views, strategy,
            center_view=center_view)

        return index_context, index_target

    @staticmethod
    def _complement_views(all_views: Tensor, chosen: Tensor) -> Tensor:
        """Return the entries of ``all_views`` (an ``arange``) not in ``chosen``, ascending.

        A boolean mask keeps the result int64 even when it is empty and preserves
        ascending order; a Python ``set`` round-trip guarantees neither
        (``torch.tensor([])`` is float32 and set iteration order is unspecified).
        """
        keep = torch.ones_like(all_views, dtype=torch.bool)
        # all_views is arange(num_views), so view indices double as mask positions.
        keep[chosen] = False
        return all_views[keep]

    @staticmethod
    def _subsample_views(
            extrinsics: Tensor,
            index_views: Tensor,
            num_views_to_sample: int,
            strategy: str,
            center_view: int | None = None,
    ) -> Tensor:
        """Select ``num_views_to_sample`` entries of ``index_views`` with ``strategy``.

        ``-1`` or a count >= ``len(index_views)`` returns ``index_views`` unchanged.
        ``center_view`` (a global view index) is reserved for proximity-based
        strategies and is ignored by "random".
        """
        if num_views_to_sample == -1 or num_views_to_sample >= len(index_views):
            return index_views
        if strategy == "random":
            # randperm stays on the default (CPU) generator so seeding behaves as before.
            return index_views[torch.randperm(len(index_views))[:num_views_to_sample]]
        if strategy == "neighbors":
            # A draft ranking exists in `_rank_by_proximity`, but it is intentionally
            # not enabled until its distance metric and extrinsics convention are validated.
            raise NotImplementedError("The 'neighbors' view sampling strategy is not implemented yet.")
        raise ValueError(f"Unknown sampling strategy: {strategy}")

    @staticmethod
    def _rank_by_proximity(
            extrinsics: Tensor,
            index_views: Tensor,
            center_view: int,
    ) -> Tensor:
        """Draft, not yet used: order ``index_views`` by closeness to ``center_view``.

        Closeness is the Euclidean distance between translation columns plus the
        relative rotation angle (radians) between the 3x3 rotation blocks; the two
        terms have different units and are summed unweighted. ``center_view`` is a
        global view index (a row of ``extrinsics``), not a position in ``index_views``.
        NOTE(review): treats the translation column as the camera centre, which holds
        for camera-to-world extrinsics only — confirm the convention before enabling.
        """
        rotations = extrinsics[:, :3, :3]  # [V, 3, 3]
        translations = extrinsics[:, :3, 3]  # [V, 3]
        # Per-view distance to the centre view (the old draft's `[0]` collapsed this to one value).
        dists = torch.norm(translations - translations[center_view], dim=-1)  # [V]
        relative = rotations @ rotations[center_view].transpose(0, 1)  # [V, 3, 3]
        trace = relative.diagonal(dim1=-2, dim2=-1).sum(-1)  # [V]
        # Clamp guards acos against round-off just outside [-1, 1].
        angles = torch.acos(((trace - 1) / 2).clamp(-1.0, 1.0))  # [V]
        scores = (dists + angles)[index_views]
        return index_views[torch.argsort(scores)]

    @property
    def num_context_views(self) -> int:
        """Configured number of context views (``cfg.num_context_views``)."""
        return self.cfg.num_context_views

    @property
    def num_target_views(self) -> int:
        """Configured number of target views (``cfg.num_target_views``)."""
        return self.cfg.num_target_views