# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Reference View Selection Strategies
This module provides different strategies for selecting a reference view
from multiple input views in multi-view depth estimation.
"""
from typing import Literal
import torch
RefViewStrategy = Literal["first", "middle", "saddle_balanced", "saddle_sim_range"]
def select_reference_view(
    x: torch.Tensor,
    strategy: RefViewStrategy = "saddle_balanced",
) -> torch.Tensor:
    """
    Select a reference view from multiple views using the specified strategy.
    Args:
        x: Input tensor of shape (B, S, N, C) where
            B = batch size
            S = number of views
            N = number of tokens
            C = channel dimension
        strategy: Selection strategy, one of:
            - "first": Always select the first view
            - "middle": Select the middle view
            - "saddle_balanced": Select view with balanced features across multiple metrics
            - "saddle_sim_range": Select view with largest similarity range
    Returns:
        b_idx: Tensor of shape (B,) containing the selected view index for each batch
    Raises:
        ValueError: If ``strategy`` is not one of the supported strategies.
    """
    B, S, N, C = x.shape
    # For a single view there is nothing to choose.
    if S <= 1:
        return torch.zeros(B, dtype=torch.long, device=x.device)
    # Simple position-based strategies.
    if strategy == "first":
        return torch.zeros(B, dtype=torch.long, device=x.device)
    elif strategy == "middle":
        return torch.full((B,), S // 2, dtype=torch.long, device=x.device)
    # Feature-based strategies operate on the first token of each view
    # (presumably a global/class token — confirm against the backbone).
    class_tok = x[:, :, 0]  # B S C
    # clamp_min guards against an all-zero token producing NaNs on division.
    img_class_feat = class_tok / class_tok.norm(dim=-1, keepdim=True).clamp_min(1e-8)
    if strategy == "saddle_balanced":
        # Select the view whose metrics sit closest to the middle of the pack.
        sim = torch.matmul(img_class_feat, img_class_feat.transpose(1, 2))  # B S S
        # Mean cosine similarity to the *other* views: subtract the actual
        # diagonal (self-similarity) rather than assuming it is exactly 1.
        sim_score = (sim.sum(dim=-1) - sim.diagonal(dim1=1, dim2=2)) / (S - 1)  # B S
        feat_norm = class_tok.norm(dim=-1)  # B S
        feat_var = img_class_feat.var(dim=-1)  # B S
        # Normalize each metric to [0, 1] per batch so they are comparable.
        def normalize_metric(metric):
            min_val = metric.min(dim=1, keepdim=True).values
            max_val = metric.max(dim=1, keepdim=True).values
            return (metric - min_val) / (max_val - min_val + 1e-8)
        sim_score_norm = normalize_metric(sim_score)
        norm_norm = normalize_metric(feat_norm)
        var_norm = normalize_metric(feat_var)
        # Pick the view closest to the median (0.5) across all three metrics.
        balance_score = (
            (sim_score_norm - 0.5).abs() +
            (norm_norm - 0.5).abs() +
            (var_norm - 0.5).abs()
        )
        b_idx = balance_score.argmin(dim=1)
    elif strategy == "saddle_sim_range":
        # Select the view with the largest cross-view similarity range
        # (max - min over the other views).
        sim = torch.matmul(img_class_feat, img_class_feat.transpose(1, 2))  # B S S
        # Exclude self-similarity by masking the diagonal with -inf/+inf.
        # (Zeroing it instead would leave a spurious 0 that wins the max/min
        # whenever all cross-view similarities share a sign.)
        eye = torch.eye(S, dtype=torch.bool, device=sim.device)
        sim_max = sim.masked_fill(eye, float("-inf")).max(dim=-1).values  # B S
        sim_min = sim.masked_fill(eye, float("inf")).min(dim=-1).values  # B S
        b_idx = (sim_max - sim_min).argmax(dim=1)
    else:
        raise ValueError(
            f"Unknown reference view selection strategy: {strategy}. "
            f"Must be one of: 'first', 'middle', 'saddle_balanced', 'saddle_sim_range'"
        )
    return b_idx
def reorder_by_reference(
    x: torch.Tensor,
    b_idx: torch.Tensor,
) -> torch.Tensor:
    """
    Reorder views to place the selected reference view first.
    Args:
        x: Input tensor of shape (B, S, ...)
        b_idx: Reference view indices of shape (B,)
    Returns:
        Tensor of the same shape with the reference view at position 0 and the
        remaining views in their original relative order.
    Example:
        If b_idx = [2] and S = 5 (views [0,1,2,3,4]),
        result order is [2,0,1,3,4] (ref_idx first, then others in order)
    """
    B, S = x.shape[0], x.shape[1]
    # For a single view there is nothing to reorder.
    if S <= 1:
        return x
    # positions: (B, S), each row is [0, 1, ..., S-1].
    positions = torch.arange(S, device=x.device).unsqueeze(0).expand(B, -1)  # B S
    ref = b_idx.unsqueeze(1)  # B 1
    # Positions 1..ref take the views originally at 0..ref-1 (shift by -1);
    # positions after ref are unchanged. (torch.where builds a fresh tensor,
    # so no preliminary clone is needed.)
    reorder_indices = torch.where(
        (positions > 0) & (positions <= ref),
        positions - 1,
        positions,
    )
    # Position 0 takes the reference view itself.
    reorder_indices[:, 0] = b_idx
    # Gather per batch using advanced indexing.
    batch_indices = torch.arange(B, device=x.device).unsqueeze(1)  # B 1
    return x[batch_indices, reorder_indices]
def restore_original_order(
    x: torch.Tensor,
    b_idx: torch.Tensor,
) -> torch.Tensor:
    """
    Restore original view order after processing.
    Args:
        x: Reordered tensor of shape (B, S, ...)
        b_idx: Original reference view indices of shape (B,)
    Returns:
        Tensor with original view order restored
    Example:
        If original order was [0, 1, 2, 3, 4] and b_idx=2,
        reordered becomes [2, 0, 1, 3, 4] (reference at position 0),
        restore should return [0, 1, 2, 3, 4] (original order).
    """
    B, S = x.shape[:2]
    # A single view was never moved.
    if S <= 1:
        return x
    ref = b_idx.unsqueeze(1)  # B 1
    # targets: (B, S), each row is [0, 1, ..., S-1].
    targets = torch.arange(S, device=x.device).unsqueeze(0).expand(B, -1)  # B S
    # Views originally before the reference were shifted one slot right, so
    # target t < ref comes from current slot t + 1; targets after the
    # reference never moved.
    source = torch.where(targets < ref, targets + 1, targets)
    # The reference view itself sits in slot 0 of the reordered tensor.
    source = source.scatter(dim=1, index=ref, src=torch.zeros_like(ref))
    # Gather per batch using advanced indexing.
    rows = torch.arange(B, device=x.device).unsqueeze(1)  # B 1
    return x[rows, source]