| | |
| | from dataclasses import dataclass |
| | from typing import Literal |
| |
|
| | import torch |
| | import copy |
| | from jaxtyping import Float, Int64 |
| | from torch import Tensor |
| | import random |
| | from .view_sampler import ViewSampler |
| |
|
| |
|
@dataclass
class ViewSamplerRankCfg:
    """Configuration for the rank-based view sampler (ViewSamplerRank)."""

    # Selects this sampler variant; presumably matched by a config dispatcher
    # elsewhere in the project — not visible here.
    name: Literal["rank"]
    num_context_views: int
    num_target_views: int
    # "Distance" fields are gaps in ranking positions (nearest-neighbour rank
    # relative to a reference view), not metric distances.
    min_distance_between_context_views: int
    max_distance_between_context_views: int
    min_distance_to_context_views: int
    # NOTE(review): the warm-up schedule and the "initial_*" bounds are not
    # consumed anywhere in this file — presumably used by a training loop or
    # curriculum elsewhere; confirm before relying on them.
    warm_up_steps: int
    initial_min_distance_between_context_views: int
    initial_max_distance_between_context_views: int
    max_img_per_gpu: int
| |
|
| |
|
def rotation_angle(R1, R2):
    """Return the geodesic angle, in degrees, between two 3x3 rotation matrices."""
    relative = R1.T @ R2
    # trace(R) = 1 + 2*cos(theta); clamp guards acos against numerical drift
    # just outside [-1, 1].
    cos_theta = torch.clamp((torch.trace(relative) - 1) / 2, -1.0, 1.0)
    theta = torch.acos(cos_theta)
    return theta * 180 / torch.pi
| |
|
def extrinsic_distance(extrinsic1, extrinsic2, lambda_t=1.0):
    """Pose distance between two 4x4 extrinsic matrices.

    Sum of the rotation angle normalized to [0, 1] (degrees / 180) and the
    translation gap weighted by lambda_t.
    """
    rot_term = rotation_angle(extrinsic1[:3, :3], extrinsic2[:3, :3]) / 180
    trans_term = torch.norm(extrinsic1[:3, 3] - extrinsic2[:3, 3])
    return rot_term + lambda_t * trans_term
| |
|
def rotation_angle_batch(R1, R2):
    """All-pairs relative rotation between two batches of 3x3 rotation matrices.

    Args:
        R1: (N, 3, 3) rotation matrices.
        R2: (M, 3, 3) rotation matrices.

    Returns:
        (N, M) tensor of geodesic angles normalized to [0, 1]
        (i.e. degrees / 180), entry [i, j] for the pair (R1[i], R2[j]).
    """
    # Broadcast R1[i]^T against R2[j] for every (i, j) pair.
    relative = torch.matmul(R1.transpose(-2, -1).unsqueeze(1), R2.unsqueeze(0))
    # trace(R) = 1 + 2*cos(theta); clamp protects acos from round-off.
    trace = relative.diagonal(dim1=-2, dim2=-1).sum(-1)
    cos_theta = torch.clamp((trace - 1) / 2, -1.0, 1.0)
    angle_deg = torch.acos(cos_theta) * 180 / torch.pi
    return angle_deg / 180.0
| |
|
def extrinsic_distance_batch(extrinsics, lambda_t=1.0):
    """All-pairs pose distance for a batch of (N, 4, 4) extrinsics.

    Entry [i, j] is the normalized rotation angle between views i and j plus
    lambda_t times the L2 gap between their translation columns.
    """
    rotations = extrinsics[:, :3, :3]
    centers = extrinsics[:, :3, 3]

    # (N, N) matrix of normalized rotation angles.
    rot_term = rotation_angle_batch(rotations, rotations)

    # (N, N) matrix of pairwise translation distances via broadcasting.
    trans_term = torch.norm(centers.unsqueeze(1) - centers.unsqueeze(0), dim=2)

    return rot_term + lambda_t * trans_term
| |
|
| |
|
def compute_ranking(extrinsics, lambda_t=1.0, normalize=True, batched=True):
    """Rank every view by pose distance to every other view.

    Args:
        extrinsics: (N, 4, 4) extrinsic matrices. (Which convention —
            camera-to-world vs world-to-camera — is expected is not visible
            here; confirm at the call site.)
        lambda_t: weight of the translation term relative to rotation.
        normalize: if True, rescale translations by the mean camera-center
            norm so the distance is scale-invariant across scenes.
        batched: use the vectorized tensor path instead of the O(N^2)
            Python loop (both produce the same distances).

    Returns:
        (ranking, dists): ranking[i] lists view indices sorted by increasing
        distance to view i; dists is the (N, N) pairwise distance matrix.
    """
    if normalize:
        # Fix: work on a tensor clone instead of copy.deepcopy — idiomatic
        # and cheaper for torch tensors (the original also deep-copied
        # camera_center needlessly, since only its norm is used).
        extrinsics = extrinsics.clone()
        camera_center_scale = torch.norm(extrinsics[:, :3, 3], dim=1)
        avg_scale = torch.mean(camera_center_scale)
        # NOTE(review): avg_scale == 0 (all cameras at the origin) would
        # divide by zero here, as in the original.
        extrinsics[:, :3, 3] = extrinsics[:, :3, 3] / avg_scale

    if batched:
        dists = extrinsic_distance_batch(extrinsics, lambda_t=lambda_t)
    else:
        N = extrinsics.shape[0]
        dists = torch.zeros((N, N), device=extrinsics.device)
        for i in range(N):
            for j in range(N):
                dists[i, j] = extrinsic_distance(extrinsics[i], extrinsics[j], lambda_t=lambda_t)
    ranking = torch.argsort(dists, dim=1)
    return ranking, dists
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| |
|
class ViewSamplerRank(ViewSampler[ViewSamplerRankCfg]):
    """Samples context/target views by ranking all views by pose distance to a
    randomly chosen reference view (via compute_ranking)."""

    def sample(
        self,
        scene: str,
        num_context_views: int,
        extrinsics: Float[Tensor, "view 4 4"],
        intrinsics: Float[Tensor, "view 3 3"],
        device: torch.device = torch.device("cpu"),
    ) -> tuple[
        Int64[Tensor, " context_view"],
        Int64[Tensor, " target_view"],
        Float[Tensor, " overlap"],
    ]:
        """Pick context and target view indices for one scene.

        Returns (context_indices, target_indices, overlap); overlap is a
        zero placeholder — this sampler does not compute it.
        """
        num_views, _, _ = extrinsics.shape
        # Clone so the in-place normalization inside compute_ranking cannot
        # mutate the caller's tensor.
        extrinsics = extrinsics.clone()
        ranking, dists = compute_ranking(extrinsics, lambda_t=1.0, normalize=True, batched=True)
        # Reference view is chosen uniformly at random; every other view is
        # then addressed by its rank (closeness) to this view.
        reference_view = random.sample(range(num_views), 1)[0]
        refview_ranking = ranking[reference_view]
        # Gap bounds (in ranking positions) for the requested context count.
        min_gap, max_gap = self.num_ctxt_gap_mapping[num_context_views]

        # Clamp so the right context index stays within the ranking.
        max_gap = min(max_gap, num_views - 1)
        # NOTE(review): if min_gap > max_gap after clamping (too few views),
        # the random.sample below raises ValueError — confirm callers
        # guarantee enough views per scene.
        index_context_left = reference_view
        rightmost_index = random.sample(range(min_gap, max_gap + 1), 1)[0]

        # The right context view is the rightmost_index-th closest view to
        # the reference; targets come from the views ranked strictly between.
        index_context_right = refview_ranking[rightmost_index].item()

        middle_indices = refview_ranking[1: rightmost_index].tolist()
        # NOTE(review): raises ValueError when len(middle_indices) <
        # num_target_views; the gap mapping presumably prevents this — verify.
        index_target = random.sample(middle_indices, self.num_target_views)

        remaining_indices = [idx for idx in middle_indices if idx not in index_target]

        # Context views beyond the left/right pair are drawn from the
        # leftover middle views (possibly fewer than requested).
        extra_views = []
        num_extra_views = num_context_views - 2
        if num_extra_views > 0 and remaining_indices:
            extra_views = random.sample(remaining_indices, min(num_extra_views, len(remaining_indices)))
        else:
            extra_views = []

        # Overlap is not estimated by this sampler.
        overlap = torch.zeros(1)

        return (
            torch.tensor((index_context_left, *extra_views, index_context_right)),
            torch.tensor(index_target),
            overlap
        )

    @property
    def num_context_views(self) -> int:
        return self.cfg.num_context_views

    @property
    def num_target_views(self) -> int:
        return self.cfg.num_target_views

    @property
    def num_ctxt_gap_mapping_target(self) -> dict:
        # Maps num_context_views -> [min_gap, max_gap], a variant that also
        # reserves room for the target views. NOTE(review): not referenced in
        # this file — confirm whether it is used elsewhere or dead code.
        mapping = dict()
        for num_ctxt in range(2, self.cfg.num_context_views + 1):
            mapping[num_ctxt] = [max(num_ctxt * 2, self.cfg.num_target_views + num_ctxt), max(self.cfg.num_target_views + num_ctxt, min(num_ctxt ** 2, self.cfg.max_distance_between_context_views))]
        return mapping

    @property
    def num_ctxt_gap_mapping(self) -> dict:
        # Maps num_context_views -> [min_gap, max_gap] ranking gaps used by
        # sample() above, bounded by the configured min/max distances.
        mapping = dict()
        for num_ctxt in range(2, self.cfg.num_context_views + 1):
            mapping[num_ctxt] = [min(num_ctxt * 3, self.cfg.min_distance_between_context_views), min(max(num_ctxt * 5, num_ctxt ** 2), self.cfg.max_distance_between_context_views)]
        return mapping
| |
|
| | |
| |
|