Spaces:
Running
on
Zero
Running
on
Zero
| # Copyright (c) 2025 ByteDance Ltd. and/or its affiliates | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| """ | |
| Alignment utilities for depth estimation and metric scaling. | |
| """ | |
| from typing import Tuple | |
| import torch | |
def least_squares_scale_scalar(
    a: torch.Tensor, b: torch.Tensor, eps: float = 1e-12
) -> torch.Tensor:
    """
    Compute least squares scale factor s such that a ≈ s * b.

    The closed-form solution is s = <a, b> / <b, b>, evaluated over all
    elements of the (flattened) inputs. Inputs with different floating
    point dtypes are promoted to their common dtype before the dot
    products, so the result has the promoted dtype.

    Args:
        a: First tensor
        b: Second tensor (same shape and device as ``a``)
        eps: Lower bound applied to <b, b> to avoid division by zero
    Returns:
        Scalar (0-dim) tensor containing the scale factor
    Raises:
        ValueError: If tensors have mismatched shapes or devices
        TypeError: If tensors are not floating point
    """
    if a.shape != b.shape:
        raise ValueError(f"Shape mismatch: {a.shape} vs {b.shape}")
    if a.device != b.device:
        raise ValueError(f"Device mismatch: {a.device} vs {b.device}")
    if not a.is_floating_point() or not b.is_floating_point():
        raise TypeError("Tensors must be floating point type")

    # torch.dot requires both operands to share a dtype; promote so that
    # e.g. float32 vs float64 inputs work instead of failing inside dot.
    common_dtype = torch.promote_types(a.dtype, b.dtype)
    a_flat = a.reshape(-1).to(common_dtype)
    b_flat = b.reshape(-1).to(common_dtype)

    # Compute dot products for least squares solution
    num = torch.dot(a_flat, b_flat)
    den = torch.dot(b_flat, b_flat).clamp_min(eps)
    return num / den
def compute_sky_mask(sky_prediction: torch.Tensor, threshold: float = 0.3) -> torch.Tensor:
    """
    Derive the non-sky mask from a sky prediction tensor.

    An element is treated as non-sky when its prediction value is strictly
    below ``threshold``.

    Args:
        sky_prediction: Sky prediction tensor
        threshold: Values strictly below this are marked as non-sky
    Returns:
        Boolean mask where True indicates non-sky regions
    """
    non_sky_mask = torch.lt(sky_prediction, threshold)
    return non_sky_mask
def compute_alignment_mask(
    depth_conf: torch.Tensor,
    non_sky_mask: torch.Tensor,
    depth: torch.Tensor,
    metric_depth: torch.Tensor,
    median_conf: torch.Tensor,
    min_depth_threshold: float = 1e-3,
    min_metric_depth_threshold: float = 1e-2,
) -> torch.Tensor:
    """
    Select the elements that are usable for depth alignment.

    An element is kept only if all of the following hold:
      * its confidence is at least ``median_conf``,
      * it lies in a non-sky region,
      * its metric depth is strictly above ``min_metric_depth_threshold``,
      * its predicted depth is strictly above ``min_depth_threshold``.

    Args:
        depth_conf: Depth confidence tensor
        non_sky_mask: Non-sky region mask
        depth: Predicted depth tensor
        metric_depth: Metric depth tensor
        median_conf: Confidence threshold (compared with ``>=``)
        min_depth_threshold: Minimum depth threshold (exclusive)
        min_metric_depth_threshold: Minimum metric depth threshold (exclusive)
    Returns:
        Boolean mask for valid alignment regions
    """
    confident = depth_conf >= median_conf
    metric_valid = metric_depth > min_metric_depth_threshold
    depth_valid = depth > min_depth_threshold

    # Combine the criteria left to right: confidence, sky, metric depth, depth.
    mask = confident & non_sky_mask
    mask = mask & metric_valid
    return mask & depth_valid
def sample_tensor_for_quantile(tensor: torch.Tensor, max_samples: int = 100000) -> torch.Tensor:
    """
    Randomly subsample a tensor so quantile computation runs on fewer elements.

    Args:
        tensor: Input tensor to sample
        max_samples: Maximum number of samples to take
    Returns:
        The input tensor itself (unchanged, original shape) when it holds at
        most ``max_samples`` elements; otherwise a 1-D tensor of
        ``max_samples`` elements drawn without replacement from the
        flattened input.
    """
    total = tensor.numel()
    if total > max_samples:
        # Take the head of a random permutation: distinct flat indices.
        chosen = torch.randperm(total, device=tensor.device)[:max_samples]
        return tensor.reshape(-1)[chosen]
    return tensor
def apply_metric_scaling(
    depth: torch.Tensor, intrinsics: torch.Tensor, scale_factor: float = 300.0
) -> torch.Tensor:
    """
    Scale depth by the mean focal length divided by ``scale_factor``.

    The focal length is the average of the ``[0, 0]`` and ``[1, 1]`` entries
    of each intrinsics matrix. The resulting per-matrix factor is broadcast
    over two trailing axes of ``depth``.
    # NOTE(review): presumably intrinsics is (B, S, 3, 3) and depth is
    # (B, S, H, W) — confirm against callers.

    Args:
        depth: Input depth tensor
        intrinsics: Camera intrinsics tensor, indexed as ``[:, :, row, col]``
        scale_factor: Divisor applied to the focal length
    Returns:
        Scaled depth tensor
    """
    fx = intrinsics[:, :, 0, 0]
    fy = intrinsics[:, :, 1, 1]
    mean_focal = (fx + fy) / 2

    # Append two singleton axes so the factor broadcasts against depth.
    factor = mean_focal[:, :, None, None] / scale_factor
    return depth * factor
def set_sky_regions_to_max_depth(
    depth: torch.Tensor,
    depth_conf: torch.Tensor,
    non_sky_mask: torch.Tensor,
    max_depth: float = 200.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Overwrite sky regions with a fixed depth and a confidence of 1.0.

    Both inputs are cloned first, so the tensors passed in are left untouched.

    Args:
        depth: Depth tensor
        depth_conf: Depth confidence tensor
        non_sky_mask: Non-sky region mask (sky is its inverse)
        max_depth: Depth value written into sky regions
    Returns:
        Tuple of (updated_depth, updated_depth_conf)
    """
    # Sky is wherever the non-sky mask is not set.
    sky_region = ~non_sky_mask

    updated_depth = depth.clone()
    updated_depth[sky_region] = max_depth

    updated_conf = depth_conf.clone()
    updated_conf[sky_region] = 1.0

    return updated_depth, updated_conf