| import torch |
| from einops import einsum, reduce, repeat |
| from jaxtyping import Float |
| from torch import Tensor |
|
|
| from ..types import BatchedExample |
|
|
|
|
def compute_depth_for_disparity(
    extrinsics: Float[Tensor, "batch view 4 4"],
    intrinsics: Float[Tensor, "batch view 3 3"],
    image_shape: tuple[int, int],
    disparity: float,
    delta_min: float = 1e-6,
) -> Float[Tensor, " batch"]:
    """Compute the depth at which moving the maximum distance between cameras
    corresponds to the specified disparity (in pixels).

    Args:
        extrinsics: Per-view 4x4 camera matrices. The translation column
            (``[:3, 3]``) is used as the camera origin.
            NOTE(review): this presumes camera-to-world extrinsics; confirm.
        intrinsics: Per-view 3x3 camera matrices. Only the upper-left 2x2 block
            is used. NOTE(review): the math below presumes intrinsics that are
            normalized by the image size; confirm against the data loader.
        image_shape: Image size as ``(height, width)`` in pixels.
        disparity: Target disparity in pixels.
        delta_min: Lower clamp applied to every pairwise camera distance, so the
            baseline (and therefore the returned depth) is never exactly zero
            when cameras coincide or there is only one view.

    Returns:
        One depth value per batch element.
    """
    # Use the furthest distance between any two cameras as the baseline.
    origins = extrinsics[:, :, :3, 3]
    deltas = (origins[:, None, :, :] - origins[:, :, None, :]).norm(dim=-1)
    deltas = deltas.clip(min=delta_min)
    baselines = reduce(deltas, "b v ov -> b", "max")

    # Compute a single pixel's size at depth 1. The helper tensor is created
    # with the dtype and device of the intrinsics it is contracted with, so
    # inputs that are not float32 do not mix dtypes in the contraction.
    h, w = image_shape
    pixel_size = 1 / torch.tensor(
        (w, h), dtype=intrinsics.dtype, device=intrinsics.device
    )
    pixel_size = einsum(
        intrinsics[..., :2, :2].inverse(), pixel_size, "... i j, j -> ... i"
    )

    # Average over views and over the x/y axes to get one scalar pixel size per
    # batch element. Averaging x and y is only meaningful for (roughly) square
    # pixels.
    mean_pixel_size = reduce(pixel_size, "b v xy -> b", "mean")

    return baselines / (disparity * mean_pixel_size)
|
|
|
|
def apply_bounds_shim(
    batch: BatchedExample,
    near_disparity: float,
    far_disparity: float,
) -> BatchedExample:
    """Attach near and far depth planes to a batch.

    The bounds are derived from the context views only (their extrinsics,
    intrinsics and image size) and then shared with the target views, which
    assumes that all of an example's views look at roughly the same thing.

    Args:
        batch: Example with "context" and "target" entries, each holding an
            "image" whose shape unpacks into five dimensions; the second one is
            used as the view count.
        near_disparity: Disparity (in pixels) that defines the near plane.
        far_disparity: Disparity (in pixels) that defines the far plane.

    Returns:
        A shallow copy of ``batch`` whose "context" and "target" entries carry
        additional "near" and "far" values, one per view.
    """
    context_views = batch["context"]
    _, num_context, _, height, width = context_views["image"].shape

    # One depth per batch element for each of the two disparities.
    near, far = (
        compute_depth_for_disparity(
            context_views["extrinsics"],
            context_views["intrinsics"],
            (height, width),
            disparity,
        )
        for disparity in (near_disparity, far_disparity)
    )

    target_views = batch["target"]
    _, num_target, _, _, _ = target_views["image"].shape

    def with_bounds(views, num_views):
        # Broadcast the per-batch bounds across this group's views.
        return {
            **views,
            "near": repeat(near, "b -> b v", v=num_views),
            "far": repeat(far, "b -> b v", v=num_views),
        }

    return {
        **batch,
        "context": with_bounds(context_views, num_context),
        "target": with_bounds(target_views, num_target),
    }
|
|