| | import json |
| | from dataclasses import asdict, dataclass |
| | from pathlib import Path |
| | from typing import Optional |
| |
|
| | import torch |
| | from einops import rearrange |
| | from lightning.pytorch import LightningModule |
| | from tqdm import tqdm |
| |
|
| | from ..geometry.epipolar_lines import project_rays |
| | from ..geometry.projection import get_world_rays, sample_image_grid |
| | from ..misc.image_io import save_image |
| | from ..visualization.annotation import add_label |
| | from ..visualization.layout import add_border, hcat |
| |
|
| |
|
@dataclass
class EvaluationIndexGeneratorCfg:
    """Settings that control how evaluation view indices are generated."""

    # How many target views to sample between the two chosen context views.
    num_target_views: int
    # Minimum/maximum frame distance allowed between the two context views.
    min_distance: int
    max_distance: int
    # Minimum/maximum mutual image overlap (fraction in [0, 1]) required
    # between the two context views.
    min_overlap: float
    max_overlap: float
    # Directory where evaluation_index.json (and optional previews) is written.
    output_path: Path
    # If True, a side-by-side preview image is saved per scene.
    save_previews: bool
    # Seed for the torch.Generator that drives all random sampling.
    seed: int
|
| |
|
@dataclass
class IndexEntry:
    """One scene's entry in the evaluation index: which views are context
    views and which are targets."""

    # Frame indices of the context views (a pair in this file).
    context: tuple[int, ...]
    # Frame indices of the target views.
    target: tuple[int, ...]
    # Optional overlap annotation; not populated by the generator in this file.
    overlap: Optional[str | float] = None
|
| |
|
class EvaluationIndexGenerator(LightningModule):
    """Sweeps a dataset's test split and, for each scene, picks a pair of
    context views whose mutual image overlap lies within configured bounds,
    then samples distinct target views between them. Results are accumulated
    in `self.index` and written to JSON by `save_index`.
    """

    # Dedicated RNG so index generation is reproducible for a given seed.
    generator: torch.Generator
    cfg: EvaluationIndexGeneratorCfg
    # scene name -> chosen entry, or None if no valid context pair was found.
    index: dict[str, IndexEntry | None]

    def __init__(self, cfg: EvaluationIndexGeneratorCfg) -> None:
        super().__init__()
        self.cfg = cfg
        self.generator = torch.Generator()
        self.generator.manual_seed(cfg.seed)
        self.index = {}

    def test_step(self, batch, batch_idx):
        # Batch layout (from the unpack below): image is (batch, views, c, h, w).
        # Only batch size 1 is supported, so everything below indexes [0].
        b, v, _, h, w = batch["target"]["image"].shape
        assert b == 1
        extrinsics = batch["target"]["extrinsics"][0]
        intrinsics = batch["target"]["intrinsics"][0]
        scene = batch["scene"][0]

        # Try candidate starting frames in random order until one yields a
        # valid partner; `break` at the bottom exits once a pair is found.
        context_indices = torch.randperm(v, generator=self.generator)
        for context_index in tqdm(context_indices, "Finding context pair"):
            # Cast a ray through every pixel of the candidate context view.
            xy, _ = sample_image_grid((h, w), self.device)
            context_origins, context_directions = get_world_rays(
                rearrange(xy, "h w xy -> (h w) xy"),
                extrinsics[context_index],
                intrinsics[context_index],
            )

            # Collect all partner frames whose overlap with the candidate
            # falls inside [min_overlap, max_overlap].
            valid_indices = []
            # Search forward (+1) then backward (-1) from the starting frame.
            for step in (1, -1):
                min_distance = self.cfg.min_distance
                max_distance = self.cfg.max_distance
                current_index = context_index + step * min_distance

                while 0 <= current_index.item() < v:
                    # Measure overlap symmetrically: project each view's rays
                    # onto the other image and count the fraction that land
                    # inside it.
                    current_origins, current_directions = get_world_rays(
                        rearrange(xy, "h w xy -> (h w) xy"),
                        extrinsics[current_index],
                        intrinsics[current_index],
                    )
                    projection_onto_current = project_rays(
                        context_origins,
                        context_directions,
                        extrinsics[current_index],
                        intrinsics[current_index],
                    )
                    projection_onto_context = project_rays(
                        current_origins,
                        current_directions,
                        extrinsics[context_index],
                        intrinsics[context_index],
                    )
                    overlap_a = projection_onto_context["overlaps_image"].float().mean()
                    overlap_b = projection_onto_current["overlaps_image"].float().mean()

                    # Use the smaller of the two directional overlaps as the
                    # conservative pair overlap.
                    overlap = min(overlap_a, overlap_b)
                    delta = (current_index - context_index).abs()

                    min_overlap = self.cfg.min_overlap
                    max_overlap = self.cfg.max_overlap
                    if min_overlap <= overlap <= max_overlap:
                        valid_indices.append(
                            (current_index.item(), overlap_a, overlap_b)
                        )

                    # Stop scanning this direction once overlap drops below
                    # the minimum or the frames are too far apart.
                    # NOTE(review): assumes overlap shrinks roughly
                    # monotonically with frame distance — verify for the
                    # datasets in use.
                    if overlap < min_overlap or delta > max_distance:
                        break

                    current_index += step

            if valid_indices:
                # Pick one valid partner uniformly at random; size=tuple()
                # yields a 0-dim tensor usable as a list index.
                num_options = len(valid_indices)
                chosen = torch.randint(
                    0, num_options, size=tuple(), generator=self.generator
                )
                chosen, overlap_a, overlap_b = valid_indices[chosen]

                context_left = min(chosen, context_index.item())
                context_right = max(chosen, context_index.item())
                delta = context_right - context_left

                # Rejection-sample target views (inclusive of both context
                # frames) until all draws are distinct.
                # NOTE(review): loops forever if num_target_views exceeds the
                # number of frames in [context_left, context_right] — confirm
                # callers guarantee a wide enough range.
                while True:
                    target_views = torch.randint(
                        context_left,
                        context_right + 1,
                        (self.cfg.num_target_views,),
                        generator=self.generator,
                    )
                    if (target_views.unique(return_counts=True)[1] == 1).all():
                        break

                target = tuple(sorted(target_views.tolist()))
                self.index[scene] = IndexEntry(
                    context=(context_left, context_right),
                    target=target,
                )

                # Optionally save a labeled side-by-side preview of the pair.
                if self.cfg.save_previews:
                    preview_path = self.cfg.output_path / "previews"
                    preview_path.mkdir(exist_ok=True, parents=True)
                    a = batch["target"]["image"][0, chosen]
                    a = add_label(a, f"Overlap: {overlap_a * 100:.1f}%")
                    b = batch["target"]["image"][0, context_index]
                    b = add_label(b, f"Overlap: {overlap_b * 100:.1f}%")
                    vis = add_border(add_border(hcat(a, b)), 1, 0)
                    vis = add_label(vis, f"Distance: {delta} frames")
                    save_image(add_border(vis), preview_path / f"{scene}.png")
                # A valid pair was found for this scene; stop searching.
                break
        else:
            # No starting frame produced a valid pair: record the scene as
            # unusable for evaluation.
            self.index[scene] = None

    def save_index(self) -> None:
        """Write the accumulated index to <output_path>/evaluation_index.json."""
        self.cfg.output_path.mkdir(exist_ok=True, parents=True)
        with (self.cfg.output_path / "evaluation_index.json").open("w") as f:
            json.dump(
                {k: None if v is None else asdict(v) for k, v in self.index.items()}, f
            )
| |
|