File size: 7,694 Bytes
18b382b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Reference View Selection Strategies

This module provides different strategies for selecting a reference view
from multiple input views in multi-view depth estimation.
"""

from typing import Literal

import torch

RefViewStrategy = Literal["first", "middle", "saddle_balanced", "saddle_sim_range"]


def select_reference_view(
    x: torch.Tensor,
    strategy: RefViewStrategy = "saddle_balanced",
) -> torch.Tensor:
    """
    Select a reference view from multiple views using the specified strategy.

    Args:
        x: Input tensor of shape (B, S, N, C) where
           B = batch size
           S = number of views
           N = number of tokens
           C = channel dimension
        strategy: Selection strategy, one of:
            - "first": Always select the first view
            - "middle": Select the middle view
            - "saddle_balanced": Select view with balanced features across multiple metrics
            - "saddle_sim_range": Select view with largest similarity range

    Returns:
        b_idx: Tensor of shape (B,) containing the selected view index for each batch
    """
    B, S, N, C = x.shape

    # For single view, no reordering needed
    if S <= 1:
        return torch.zeros(B, dtype=torch.long, device=x.device)

    # Simple position-based strategies
    if strategy == "first":
        return torch.zeros(B, dtype=torch.long, device=x.device)

    elif strategy == "middle":
        return torch.full((B,), S // 2, dtype=torch.long, device=x.device)

    # Feature-based strategies require normalized class tokens
    # Extract and normalize class tokens (first token of each view)
    img_class_feat = x[:, :, 0] / x[:, :, 0].norm(dim=-1, keepdim=True)  # B S C

    if strategy == "saddle_balanced":
        # Select view with balanced features across multiple metrics
        # Compute similarity matrix
        sim = torch.matmul(img_class_feat, img_class_feat.transpose(1, 2))  # B S S
        sim_no_diag = sim - torch.eye(S, device=sim.device).unsqueeze(0)
        sim_score = sim_no_diag.sum(dim=-1) / (S - 1)  # B S

        feat_norm = x[:, :, 0].norm(dim=-1)  # B S
        feat_var = img_class_feat.var(dim=-1)  # B S

        # Normalize all metrics to [0, 1]
        def normalize_metric(metric):
            min_val = metric.min(dim=1, keepdim=True).values
            max_val = metric.max(dim=1, keepdim=True).values
            return (metric - min_val) / (max_val - min_val + 1e-8)

        sim_score_norm = normalize_metric(sim_score)
        norm_norm = normalize_metric(feat_norm)
        var_norm = normalize_metric(feat_var)

        # Select view closest to the median (0.5) across all metrics
        balance_score = (
            (sim_score_norm - 0.5).abs() +
            (norm_norm - 0.5).abs() +
            (var_norm - 0.5).abs()
        )
        b_idx = balance_score.argmin(dim=1)

    elif strategy == "saddle_sim_range":
        # Select view with largest similarity range (max - min)
        sim = torch.matmul(img_class_feat, img_class_feat.transpose(1, 2))  # B S S
        sim_no_diag = sim - torch.eye(S, device=sim.device).unsqueeze(0)

        sim_max = sim_no_diag.max(dim=-1).values  # B S
        sim_min = sim_no_diag.min(dim=-1).values  # B S
        sim_range = sim_max - sim_min
        b_idx = sim_range.argmax(dim=1)

    else:
        raise ValueError(
            f"Unknown reference view selection strategy: {strategy}. "
            f"Must be one of: 'first', 'middle', 'saddle_balanced', 'saddle_sim_range'"
        )

    return b_idx


def reorder_by_reference(
    x: torch.Tensor,
    b_idx: torch.Tensor,
) -> torch.Tensor:
    """
    Reorder views to place the selected reference view first.

    Views before the reference shift right by one slot; views after it are
    untouched, so the relative order of all non-reference views is preserved.

    Args:
        x: Input tensor of shape (B, S, N, C)
        b_idx: Reference view indices of shape (B,)

    Returns:
        Reordered tensor with reference view at position 0

    Example:
        If b_idx = [2] and S = 5 (views [0,1,2,3,4]),
        result order is [2,0,1,3,4] (ref_idx first, then others in order)
    """
    B, S = x.shape[0], x.shape[1]

    # For single view, no reordering needed
    if S <= 1:
        return x

    # Position indices: (B, S) where each row is [0, 1, 2, ..., S-1]
    positions = torch.arange(S, device=x.device).unsqueeze(0).expand(B, -1)  # B S
    b_idx_expanded = b_idx.unsqueeze(1)  # B 1

    # Build the gather indices:
    #   positions 1..ref_idx pull from original indices 0..ref_idx-1 (shift -1)
    #   positions > ref_idx keep their own index
    # (torch.where allocates a fresh tensor, so no defensive clone is needed.)
    reorder_indices = torch.where(
        (positions > 0) & (positions <= b_idx_expanded),
        positions - 1,
        positions,
    )
    # Position 0 pulls the reference view itself.
    reorder_indices[:, 0] = b_idx

    # Gather using advanced indexing: one permutation per batch row.
    batch_indices = torch.arange(B, device=x.device).unsqueeze(1)  # B 1
    return x[batch_indices, reorder_indices]


def restore_original_order(
    x: torch.Tensor,
    b_idx: torch.Tensor,
) -> torch.Tensor:
    """
    Restore original view order after processing.

    Undoes the permutation applied by `reorder_by_reference`: the view at
    slot 0 goes back to slot `b_idx`, and the views that were shifted right
    to make room slide back left.

    Args:
        x: Reordered tensor of shape (B, S, ...)
        b_idx: Original reference view indices of shape (B,)

    Returns:
        Tensor with original view order restored

    Example:
        If original order was [0, 1, 2, 3, 4] and b_idx=2,
        reordered becomes [2, 0, 1, 3, 4] (reference at position 0),
        restore should return [0, 1, 2, 3, 4] (original order).
    """
    B, S = x.shape[0], x.shape[1]

    # A single view cannot be out of order.
    if S <= 1:
        return x

    slots = torch.arange(S, device=x.device).unsqueeze(0).expand(B, -1)  # B S
    ref = b_idx.unsqueeze(1)  # B 1

    # Reconstruct the forward permutation used during reordering:
    # slot 0 holds view ref, slots 1..ref hold views 0..ref-1, the rest
    # hold their own index.
    fwd = torch.where((slots > 0) & (slots <= ref), slots - 1, slots)
    fwd = fwd.scatter(1, torch.zeros_like(ref), ref)

    # Inverting a permutation is an argsort: inv[v] = slot where value v sits.
    inv = fwd.argsort(dim=1)  # B S

    # Per-batch gather of the inverse permutation.
    rows = torch.arange(B, device=x.device).unsqueeze(1)  # B 1
    return x[rows, inv]