# Authors: Bingxin Ke, Haodong Li
# Last modified: 2025-05-25
# Note: Add align_depth_median for scale-invariant depth (or distance) evaluation.

import numpy as np
import torch


def align_depth_median(
    gt_arr: np.ndarray,
    pred_arr: np.ndarray,
    valid_mask_arr: np.ndarray,
    return_scale_shift: bool = True,
):
    """Scale a prediction so that its median matches the ground-truth median.

    Only a multiplicative factor is estimated; the shift is always 0.

    Args:
        gt_arr: Ground-truth array.
        pred_arr: Predicted array to be aligned.
        valid_mask_arr: Mask selecting the pixels used for both medians.
            Non-boolean masks are interpreted as "non-zero means valid".
        return_scale_shift: If True, also return the scale and shift.

    Returns:
        ``aligned_pred`` if ``return_scale_shift`` is False, otherwise the
        tuple ``(aligned_pred, scale, shift)`` with ``shift == 0``.
    """
    # Coerce to bool so that a 0/1 integer mask is used as a mask and not as
    # integer (fancy) indices, which would silently select the wrong pixels.
    valid_mask = np.asarray(valid_mask_arr, dtype=bool)

    gt_median = np.median(gt_arr[valid_mask])
    pred_median = np.median(pred_arr[valid_mask])

    scale = gt_median / pred_median
    shift = 0
    aligned_pred = pred_arr * scale + shift

    if return_scale_shift:
        return aligned_pred, scale, shift
    else:
        return aligned_pred


def align_depth_least_square(
    gt_arr: np.ndarray,
    pred_arr: np.ndarray,
    valid_mask_arr: np.ndarray,
    return_scale_shift: bool = True,
    max_resolution=None,
):
    """Align a prediction to the ground truth with a least-squares scale and shift.

    Solves ``min_{scale, shift} || pred * scale + shift - gt ||^2`` over the
    valid pixels and applies the result to the full-resolution ``pred_arr``.

    Args:
        gt_arr: Ground-truth array.
        pred_arr: Predicted array to be aligned.
        valid_mask_arr: Mask selecting the pixels that enter the fit.
            Non-boolean masks are interpreted as "non-zero means valid".
        return_scale_shift: If True, also return the scale and shift.
        max_resolution: If given and the last two dimensions of ``pred_arr``
            exceed it, the fit is computed on nearest-neighbour downsampled
            copies of the inputs. The returned prediction always has the
            original shape.

    Returns:
        ``aligned_pred`` if ``return_scale_shift`` is False, otherwise the
        tuple ``(aligned_pred, scale, shift)``. ``scale`` and ``shift`` are
        1-element arrays (rows of the ``(2, 1)`` least-squares solution).
    """
    ori_shape = pred_arr.shape

    gt = gt_arr.squeeze()
    pred = pred_arr.squeeze()
    # Coerce to bool up front: the downsample branch below already treats the
    # mask as boolean, so the full-resolution path must do the same instead of
    # falling back to integer (fancy) indexing for 0/1 integer masks.
    valid_mask = np.asarray(valid_mask_arr, dtype=bool).squeeze()

    # downsample
    if max_resolution is not None:
        scale_factor = np.min(max_resolution / np.array(ori_shape[-2:]))
        if scale_factor < 1:
            # NOTE(review): for 2-D inputs, unsqueeze(0) yields a 3-D tensor,
            # which Upsample treats as (N, C, L), so only the last axis is
            # resized. Left unchanged to keep results stable -- confirm
            # whether resizing both spatial axes was intended.
            downscaler = torch.nn.Upsample(scale_factor=scale_factor, mode="nearest")
            gt = downscaler(torch.as_tensor(gt).unsqueeze(0)).numpy()
            pred = downscaler(torch.as_tensor(pred).unsqueeze(0)).numpy()
            valid_mask = (
                downscaler(torch.as_tensor(valid_mask).unsqueeze(0).float())
                .bool()
                .numpy()
            )

    assert (
        gt.shape == pred.shape == valid_mask.shape
    ), f"{gt.shape}, {pred.shape}, {valid_mask.shape}"

    gt_masked = gt[valid_mask].reshape((-1, 1))
    pred_masked = pred[valid_mask].reshape((-1, 1))

    # numpy solver: fit gt ~= [pred, 1] @ [scale, shift]^T
    _ones = np.ones_like(pred_masked)
    A = np.concatenate([pred_masked, _ones], axis=-1)
    X = np.linalg.lstsq(A, gt_masked, rcond=None)[0]
    scale, shift = X

    # Apply the fitted transform to the original (not downsampled) prediction.
    aligned_pred = pred_arr * scale + shift

    # restore dimensions
    aligned_pred = aligned_pred.reshape(ori_shape)

    if return_scale_shift:
        return aligned_pred, scale, shift
    else:
        return aligned_pred