Spaces:
Runtime error
Runtime error
| import torch | |
| from einops import einsum, reduce, repeat | |
| from jaxtyping import Float | |
| from torch import Tensor | |
| from ..types import BatchedExample | |
def compute_depth_for_disparity(
    extrinsics: Float[Tensor, "batch view 4 4"],
    intrinsics: Float[Tensor, "batch view 3 3"],
    image_shape: tuple[int, int],
    disparity: float,
    delta_min: float = 1e-6,  # This prevents motionless scenes from lacking depth.
) -> Float[Tensor, " batch"]:
    """Compute the depth at which moving the maximum distance between cameras
    corresponds to the specified disparity (in pixels).

    Args:
        extrinsics: Per-view 4x4 camera matrices. The translation column
            (``[:3, 3]``) of each matrix is used as that camera's origin.
            NOTE(review): this presumes camera-to-world matrices -- confirm
            against the data loader.
        intrinsics: Per-view 3x3 intrinsics; only the upper-left 2x2 block is
            used. NOTE(review): multiplying its inverse by ``1 / (w, h)``
            suggests the intrinsics are normalized by image size -- confirm.
        image_shape: Image size as ``(height, width)``.
        disparity: Target disparity in pixels.
        delta_min: Lower bound applied to every pairwise camera distance so
            that the baseline is never zero.

    Returns:
        One depth value per batch element.
    """
    # Use the furthest distance between cameras as the baseline.
    origins = extrinsics[:, :, :3, 3]
    deltas = (origins[:, None, :, :] - origins[:, :, None, :]).norm(dim=-1)
    deltas = deltas.clip(min=delta_min)
    baselines = reduce(deltas, "b v ov -> b", "max")

    # Compute a single pixel's size at depth 1.
    h, w = image_shape
    # Match the intrinsics' dtype and device: this tensor is contracted with
    # the inverted intrinsics below, and einsum does not promote mixed dtypes,
    # so a hard-coded float32 breaks for double-precision intrinsics.
    pixel_size = 1 / torch.tensor(
        (w, h), dtype=intrinsics.dtype, device=intrinsics.device
    )
    pixel_size = einsum(
        intrinsics[..., :2, :2].inverse(), pixel_size, "... i j, j -> ... i"
    )

    # This wouldn't make sense with non-square pixels, but then again, non-square
    # pixels don't make much sense anyway.
    mean_pixel_size = reduce(pixel_size, "b v xy -> b", "mean")

    return baselines / (disparity * mean_pixel_size)
def apply_bounds_shim(
    batch: BatchedExample,
    near_disparity: float,
    far_disparity: float,
) -> BatchedExample:
    """Attach near and far depth bounds to a batch's context and target views.

    Both bounds are derived from the context cameras only, via
    ``compute_depth_for_disparity``, and then broadcast to every context and
    target view. This assumes that all of an example's views are of roughly
    the same thing.

    Args:
        batch: Example containing ``"context"`` and ``"target"`` entries. Each
            must provide a 5-dimensional ``"image"`` tensor whose second axis
            is the view count; the context entry must also provide
            ``"extrinsics"`` and ``"intrinsics"``.
        near_disparity: Disparity (in pixels) used to compute the near plane.
        far_disparity: Disparity (in pixels) used to compute the far plane.

    Returns:
        A new mapping with the same entries as ``batch``, where the context and
        target entries are copies extended with ``"near"`` and ``"far"``.
    """
    context_views = batch["context"]
    _, num_context, _, height, width = context_views["image"].shape

    # Near plane first, then far plane; both come from the context cameras.
    near_depth, far_depth = (
        compute_depth_for_disparity(
            context_views["extrinsics"],
            context_views["intrinsics"],
            (height, width),
            pixel_disparity,
        )
        for pixel_disparity in (near_disparity, far_disparity)
    )

    target_views = batch["target"]
    _, num_target, _, _, _ = target_views["image"].shape

    def with_bounds(views, num_views):
        # Copy the per-view bounds across the view axis without mutating input.
        return {
            **views,
            "near": repeat(near_depth, "b -> b v", v=num_views),
            "far": repeat(far_depth, "b -> b v", v=num_views),
        }

    return {
        **batch,
        "context": with_bounds(context_views, num_context),
        "target": with_bounds(target_views, num_target),
    }