linhaotong
update
b9f87ab
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Alignment utilities for depth estimation and metric scaling.
"""
from typing import Tuple
import torch
def least_squares_scale_scalar(
    a: torch.Tensor, b: torch.Tensor, eps: float = 1e-12
) -> torch.Tensor:
    """
    Compute least squares scale factor s such that a ≈ s * b.

    The closed-form minimizer of ||a - s * b||^2 is s = <a, b> / <b, b>.
    Both inputs are flattened before the dot products, so any matching
    shape is accepted. If the two tensors use different floating point
    dtypes they are promoted to a common dtype first, and the result is
    returned in that promoted dtype.

    Args:
        a: First tensor (the target of the fit)
        b: Second tensor (the tensor being scaled)
        eps: Lower bound applied to <b, b> for numerical stability

    Returns:
        Scalar (0-dim) tensor containing the scale factor

    Raises:
        ValueError: If tensors have mismatched shapes or devices
        TypeError: If tensors are not floating point
    """
    if a.shape != b.shape:
        raise ValueError(f"Shape mismatch: {a.shape} vs {b.shape}")
    if a.device != b.device:
        raise ValueError(f"Device mismatch: {a.device} vs {b.device}")
    if not a.is_floating_point() or not b.is_floating_point():
        raise TypeError("Tensors must be floating point type")

    # torch.dot refuses operands of different dtypes, so bring both inputs to
    # a common floating point dtype. `.to` is a no-op when the dtype already
    # matches, so the same-dtype path is unchanged.
    common_dtype = torch.promote_types(a.dtype, b.dtype)
    a_flat = a.reshape(-1).to(common_dtype)
    b_flat = b.reshape(-1).to(common_dtype)

    # Compute dot products for least squares solution
    num = torch.dot(a_flat, b_flat)
    # Clamp the denominator so an all-zero `b` yields 0 instead of NaN/inf.
    den = torch.dot(b_flat, b_flat).clamp_min(eps)
    return num / den
def compute_sky_mask(sky_prediction: torch.Tensor, threshold: float = 0.3) -> torch.Tensor:
    """
    Build the non-sky mask from a sky prediction.

    An element counts as non-sky when its sky prediction is strictly
    below `threshold`.

    Args:
        sky_prediction: Sky prediction tensor
        threshold: Values at or above this are treated as sky

    Returns:
        Element-wise comparison result, True where the region is non-sky
    """
    non_sky_mask = sky_prediction < threshold
    return non_sky_mask
def compute_alignment_mask(
    depth_conf: torch.Tensor,
    non_sky_mask: torch.Tensor,
    depth: torch.Tensor,
    metric_depth: torch.Tensor,
    median_conf: torch.Tensor,
    min_depth_threshold: float = 1e-3,
    min_metric_depth_threshold: float = 1e-2,
) -> torch.Tensor:
    """
    Select the elements that are usable for depth alignment.

    An element is kept only when all four conditions hold:
    its confidence is at least `median_conf`, it is flagged as non-sky,
    its metric depth exceeds `min_metric_depth_threshold`, and its
    predicted depth exceeds `min_depth_threshold`.

    Args:
        depth_conf: Depth confidence tensor
        non_sky_mask: Non-sky region mask
        depth: Predicted depth tensor
        metric_depth: Metric depth tensor
        median_conf: Confidence cutoff (inclusive)
        min_depth_threshold: Exclusive lower bound on predicted depth
        min_metric_depth_threshold: Exclusive lower bound on metric depth

    Returns:
        Mask that is True where every condition is satisfied
    """
    confident = depth_conf >= median_conf
    metric_valid = metric_depth > min_metric_depth_threshold
    depth_valid = depth > min_depth_threshold
    # Combined in the same left-to-right order as the individual criteria
    # are listed above.
    return confident & non_sky_mask & metric_valid & depth_valid
def sample_tensor_for_quantile(tensor: torch.Tensor, max_samples: int = 100000) -> torch.Tensor:
    """
    Randomly subsample a tensor so quantile computation uses less memory.

    Args:
        tensor: Input tensor to sample
        max_samples: Maximum number of samples to take

    Returns:
        The input tensor itself (shape untouched) when it holds at most
        `max_samples` elements; otherwise a 1-D tensor of `max_samples`
        elements picked without replacement from the flattened input
    """
    total = tensor.numel()
    if total > max_samples:
        # A random permutation of all flat indices, truncated to the
        # requested count, selects distinct elements.
        permutation = torch.randperm(total, device=tensor.device)
        chosen = permutation[:max_samples]
        flat = tensor.flatten()
        return flat[chosen]
    return tensor
def apply_metric_scaling(
    depth: torch.Tensor, intrinsics: torch.Tensor, scale_factor: float = 300.0
) -> torch.Tensor:
    """
    Scale depth by the mean focal length divided by `scale_factor`.

    The focal length is the average of `intrinsics[:, :, 0, 0]` and
    `intrinsics[:, :, 1, 1]`. The resulting per-(dim0, dim1) ratio is
    broadcast over two trailing axes of `depth`.
    NOTE(review): presumably intrinsics is (B, N, 3, 3) and depth is
    (B, N, H, W) — confirm against callers.

    Args:
        depth: Input depth tensor
        intrinsics: Camera intrinsics tensor
        scale_factor: Divisor applied to the focal length

    Returns:
        Scaled depth tensor
    """
    fx = intrinsics[:, :, 0, 0]
    fy = intrinsics[:, :, 1, 1]
    mean_focal = (fx + fy) / 2
    # Append two singleton axes so the ratio broadcasts against depth.
    ratio = mean_focal[:, :, None, None] / scale_factor
    return depth * ratio
def set_sky_regions_to_max_depth(
    depth: torch.Tensor,
    depth_conf: torch.Tensor,
    non_sky_mask: torch.Tensor,
    max_depth: float = 200.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Overwrite sky regions with a maximum depth and a confidence of 1.0.

    The inputs are not modified; the function works on clones.

    Args:
        depth: Depth tensor
        depth_conf: Depth confidence tensor
        non_sky_mask: Non-sky region mask (sky is its inverse)
        max_depth: Depth value written into sky regions

    Returns:
        Tuple of (updated_depth, updated_depth_conf)
    """
    # Sky is wherever the non-sky mask is not set.
    sky_mask = ~non_sky_mask

    updated_depth = depth.clone()
    updated_depth[sky_mask] = max_depth

    updated_conf = depth_conf.clone()
    updated_conf[sky_mask] = 1.0

    return updated_depth, updated_conf