# phiph's picture
# Upload folder using huggingface_hub
# 7382c66 verified
# Authors: Bingxin Ke, Haodong Li
# Last modified: 2025-05-25
# Note: Add align_depth_median for scale-invariant depth (or distance) evaluation.
import numpy as np
import torch
def align_depth_median(
    gt_arr: np.ndarray,
    pred_arr: np.ndarray,
    valid_mask_arr: np.ndarray,
    return_scale_shift=True,
):
    """Rescale a prediction so its median matches the ground-truth median.

    Both medians are taken over the entries selected by ``valid_mask_arr``.
    Only a multiplicative factor is estimated; the shift is fixed to ``0``.

    Args:
        gt_arr: Ground-truth values.
        pred_arr: Predicted values to be aligned.
        valid_mask_arr: Mask used to index both ``gt_arr`` and ``pred_arr``.
        return_scale_shift: When true, also return the scale and the shift.

    Returns:
        ``(aligned_pred, scale, shift)`` if ``return_scale_shift`` is true,
        otherwise only ``aligned_pred``.
    """
    # Median alignment has no additive term; it is kept so the return
    # signature mirrors the scale-and-shift aligners.
    shift = 0
    scale = np.median(gt_arr[valid_mask_arr]) / np.median(pred_arr[valid_mask_arr])

    aligned_pred = pred_arr * scale + shift

    if not return_scale_shift:
        return aligned_pred
    return aligned_pred, scale, shift
def _downsample_last_two_axes(tensor, scale_factor):
    """Nearest-neighbour resize of the last two axes of ``tensor``.

    ``torch.nn.Upsample`` decides which axes are spatial from the input rank:
    a 3D input is read as (N, C, L) and only its last axis is resized. The
    input is therefore viewed as (1, C, H, W) first, so that both trailing
    axes are always resized. The result is returned in that 4D layout.
    """
    as_4d = tensor.reshape((1, -1) + tuple(tensor.shape[-2:]))
    resize = torch.nn.Upsample(scale_factor=scale_factor, mode="nearest")
    return resize(as_4d)


def align_depth_least_square(
    gt_arr: np.ndarray,
    pred_arr: np.ndarray,
    valid_mask_arr: np.ndarray,
    return_scale_shift=True,
    max_resolution=None,
):
    """Align a prediction to ground truth with a least-squares scale and shift.

    Solves ``min ||scale * pred + shift - gt||^2`` over the entries selected
    by the valid mask, then applies the fitted affine map to the full
    ``pred_arr``.

    Args:
        gt_arr: Ground-truth values.
        pred_arr: Predicted values to be aligned.
        valid_mask_arr: Boolean mask selecting the entries used for the fit.
            After ``squeeze()`` it must have the same shape as the squeezed
            ``gt_arr`` and ``pred_arr``.
        return_scale_shift: When true, also return the fitted scale and shift.
        max_resolution: Optional cap on the size of the last two axes used
            for the fit. If the input is larger, the fit is computed on a
            nearest-neighbour downsampled copy; the returned prediction
            always keeps the original resolution.

    Returns:
        ``(aligned_pred, scale, shift)`` if ``return_scale_shift`` is true,
        otherwise only ``aligned_pred``. ``aligned_pred`` has the shape of
        ``pred_arr``; ``scale`` and ``shift`` are length-1 arrays.

    Raises:
        AssertionError: If the (squeezed, optionally downsampled) inputs do
            not share one shape.
    """
    ori_shape = pred_arr.shape

    gt = gt_arr.squeeze()
    pred = pred_arr.squeeze()
    valid_mask = valid_mask_arr.squeeze()

    # Optionally fit on a lower-resolution copy to bound the solver cost.
    if max_resolution is not None:
        scale_factor = float(np.min(max_resolution / np.array(ori_shape[-2:])))
        if scale_factor < 1:
            gt = _downsample_last_two_axes(torch.as_tensor(gt), scale_factor).numpy()
            pred = _downsample_last_two_axes(
                torch.as_tensor(pred), scale_factor
            ).numpy()
            # Interpolation needs a floating dtype; convert back to bool after.
            valid_mask = (
                _downsample_last_two_axes(
                    torch.as_tensor(valid_mask).float(), scale_factor
                )
                .bool()
                .numpy()
            )

    assert (
        gt.shape == pred.shape == valid_mask.shape
    ), f"{gt.shape}, {pred.shape}, {valid_mask.shape}"

    gt_masked = gt[valid_mask].reshape((-1, 1))
    pred_masked = pred[valid_mask].reshape((-1, 1))

    # Design matrix [pred, 1] so the solution vector is (scale, shift).
    design = np.concatenate([pred_masked, np.ones_like(pred_masked)], axis=-1)
    solution = np.linalg.lstsq(design, gt_masked, rcond=None)[0]
    scale, shift = solution  # each has shape (1,)

    # Apply the fit to the full-resolution prediction and restore its shape.
    aligned_pred = (pred_arr * scale + shift).reshape(ori_shape)

    if return_scale_shift:
        return aligned_pred, scale, shift
    return aligned_pred