| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| """ |
| Alignment utilities for depth estimation and metric scaling. |
| """ |
|
|
from typing import Optional, Tuple

import torch
|
|
|
|
def least_squares_scale_scalar(
    a: torch.Tensor, b: torch.Tensor, eps: float = 1e-12
) -> torch.Tensor:
    """
    Compute the least-squares scale factor ``s`` such that ``a ≈ s * b``.

    The closed-form minimiser of ``||a - s * b||^2`` is ``<a, b> / <b, b>``.
    Both tensors are flattened before the dot products, so any (matching)
    shape is accepted.

    Args:
        a: Target tensor.
        b: Tensor to be scaled; must match ``a`` in shape and device.
        eps: Lower bound applied to ``<b, b>`` to avoid dividing by zero.
            Note that it may underflow to 0 in low-precision dtypes such as
            float16.

    Returns:
        0-d tensor containing the scale factor, in the promoted floating
        dtype of ``a`` and ``b``.

    Raises:
        ValueError: If tensors have mismatched shapes or devices
        TypeError: If tensors are not floating point
    """
    if a.shape != b.shape:
        raise ValueError(f"Shape mismatch: {a.shape} vs {b.shape}")
    if a.device != b.device:
        raise ValueError(f"Device mismatch: {a.device} vs {b.device}")
    if not a.is_floating_point() or not b.is_floating_point():
        raise TypeError("Tensors must be floating point type")

    # torch.dot requires both operands to share a dtype. Promote mixed
    # floating inputs (e.g. float32 vs float64) instead of failing inside it.
    common_dtype = torch.result_type(a, b)
    a_flat = a.reshape(-1).to(common_dtype)
    b_flat = b.reshape(-1).to(common_dtype)

    num = torch.dot(a_flat, b_flat)
    den = torch.dot(b_flat, b_flat).clamp_min(eps)
    return num / den
|
|
|
|
def compute_sky_mask(sky_prediction: torch.Tensor, threshold: float = 0.3) -> torch.Tensor:
    """
    Derive a "not sky" mask by thresholding a sky prediction.

    Args:
        sky_prediction: Tensor of per-element sky scores.
        threshold: Elements whose score is strictly below this value are
            treated as non-sky.

    Returns:
        Boolean tensor with the same shape as ``sky_prediction``; True marks
        non-sky elements.
    """
    non_sky = torch.lt(sky_prediction, threshold)
    return non_sky
|
|
|
|
def compute_alignment_mask(
    depth_conf: torch.Tensor,
    non_sky_mask: torch.Tensor,
    depth: torch.Tensor,
    metric_depth: torch.Tensor,
    median_conf: torch.Tensor,
    min_depth_threshold: float = 1e-3,
    min_metric_depth_threshold: float = 1e-2,
) -> torch.Tensor:
    """
    Select the elements that are trustworthy enough to use for depth alignment.

    An element is kept only if all of the following hold:
    its confidence is at least ``median_conf``, it is marked non-sky, its
    metric depth exceeds ``min_metric_depth_threshold``, and its predicted
    depth exceeds ``min_depth_threshold``.

    Args:
        depth_conf: Per-element depth confidence.
        non_sky_mask: Mask that is True on non-sky elements.
        depth: Predicted depth.
        metric_depth: Metric (reference) depth.
        median_conf: Confidence cut-off (inclusive).
        min_depth_threshold: Exclusive lower bound on ``depth``.
        min_metric_depth_threshold: Exclusive lower bound on ``metric_depth``.

    Returns:
        Mask that is True where every criterion is satisfied.
    """
    # Combine the criteria one by one, in the same left-to-right order.
    valid = depth_conf >= median_conf
    valid = valid & non_sky_mask
    valid = valid & (metric_depth > min_metric_depth_threshold)
    valid = valid & (depth > min_depth_threshold)
    return valid
|
|
|
|
def sample_tensor_for_quantile(
    tensor: torch.Tensor,
    max_samples: int = 100000,
    generator: Optional[torch.Generator] = None,
) -> torch.Tensor:
    """
    Sample tensor elements for quantile computation to reduce memory usage.

    If ``tensor`` has at most ``max_samples`` elements, it is returned
    unchanged, keeping its original shape. Otherwise ``max_samples`` elements
    are drawn uniformly without replacement and returned as a 1-D tensor.

    Args:
        tensor: Input tensor to sample.
        max_samples: Maximum number of elements to keep; must be >= 0.
        generator: Optional ``torch.Generator`` for reproducible sampling. It
            must live on the same device as ``tensor``.

    Returns:
        Either ``tensor`` itself or a 1-D tensor of sampled elements.

    Raises:
        ValueError: If ``max_samples`` is negative.
    """
    # A negative value would otherwise slice with ``[:-k]`` and silently keep
    # ``numel - k`` elements instead of failing.
    if max_samples < 0:
        raise ValueError(f"max_samples must be non-negative, got {max_samples}")

    numel = tensor.numel()
    if numel <= max_samples:
        return tensor

    # randperm gives sampling without replacement. Note that it materialises
    # a permutation of all `numel` indices before slicing.
    idx = torch.randperm(numel, device=tensor.device, generator=generator)[:max_samples]
    return tensor.flatten()[idx]
|
|
|
|
def apply_metric_scaling(
    depth: torch.Tensor, intrinsics: torch.Tensor, scale_factor: float = 300.0
) -> torch.Tensor:
    """
    Rescale depth by the mean focal length divided by ``scale_factor``.

    The focal length is averaged from the ``[0, 0]`` and ``[1, 1]`` entries
    of the last two axes of ``intrinsics``. It is divided by
    ``scale_factor``, expanded with two singleton axes after the first two
    dimensions, and multiplied into ``depth``.

    Args:
        depth: Depth tensor; presumably (batch, views, H, W) so that it
            broadcasts against the (batch, views, 1, 1) factor — TODO confirm.
        intrinsics: Camera intrinsics indexed as ``intrinsics[:, :, i, j]``
            (at least 4-D).
        scale_factor: Divisor applied to the mean focal length.

    Returns:
        Scaled depth tensor.
    """
    fx = intrinsics[:, :, 0, 0]
    fy = intrinsics[:, :, 1, 1]
    mean_focal = (fx + fy) / 2
    # Add two singleton axes right after dim 1 so the per-view ratio
    # broadcasts over the remaining axes of `depth`.
    ratio = (mean_focal / scale_factor).unsqueeze(2).unsqueeze(3)
    return depth * ratio
|
|
|
|
def set_sky_regions_to_max_depth(
    depth: torch.Tensor,
    depth_conf: Optional[torch.Tensor],
    non_sky_mask: torch.Tensor,
    max_depth: float = 200.0,
    sky_conf: float = 1.0,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """
    Set sky regions to a fixed maximum depth and a fixed confidence.

    The inputs are cloned first, so the caller's tensors are not modified.

    Args:
        depth: Depth tensor.
        depth_conf: Depth confidence tensor, or ``None`` to skip the
            confidence update.
        non_sky_mask: Mask that is True on non-sky elements; its complement
            selects the elements that are overwritten via boolean indexing.
        max_depth: Depth value written into sky regions.
        sky_conf: Confidence value written into sky regions. Defaults to 1.0,
            the value previously hard-coded.

    Returns:
        Tuple of (updated_depth, updated_depth_conf). ``updated_depth_conf``
        is ``None`` when ``depth_conf`` is ``None``.
    """
    # Compute the sky selector once and reuse it for both tensors.
    sky_mask = ~non_sky_mask

    depth = depth.clone()
    depth[sky_mask] = torch.as_tensor(max_depth, device=depth.device, dtype=depth.dtype)

    if depth_conf is None:
        return depth, None

    depth_conf = depth_conf.clone()
    depth_conf[sky_mask] = torch.as_tensor(
        sky_conf, device=depth_conf.device, dtype=depth_conf.dtype
    )
    return depth, depth_conf
|
|