Spaces:
Running on Zero
Running on Zero
| # Author: Bingxin Ke | |
| # Last modified: 2024-01-11 | |
| import numpy as np | |
| import torch | |
| def align_depth_least_square_video( | |
| gt_arr: np.ndarray, | |
| pred_arr: np.ndarray, | |
| valid_mask_arr: np.ndarray, | |
| return_scale_shift=True, | |
| max_resolution=None, | |
| ): | |
| """ | |
| gt_arr, pred_arr, valid_mask_arr: shape can be (T, H, W) or (T, 1, H, W) | |
| """ | |
| ori_shape = pred_arr.shape | |
| squeeze = lambda x: x.squeeze() # handle (T,1,H,W) -> (T,H,W) | |
| gt = squeeze(gt_arr) | |
| pred = squeeze(pred_arr) | |
| valid_mask = squeeze(valid_mask_arr) | |
| # ----------------------------- | |
| # Optional downsampling (applied per-frame identically) | |
| # ----------------------------- | |
| if max_resolution is not None: | |
| H, W = gt.shape[-2:] | |
| scale_factor = np.min(max_resolution / np.array([H, W])) | |
| if scale_factor < 1: | |
| downscaler = torch.nn.Upsample(scale_factor=float(scale_factor), mode="nearest") | |
| gt = downscaler(torch.as_tensor(gt).unsqueeze(1)).squeeze(1).numpy() | |
| pred = downscaler(torch.as_tensor(pred).unsqueeze(1)).squeeze(1).numpy() | |
| valid_mask = ( | |
| downscaler(torch.as_tensor(valid_mask).unsqueeze(1).float()) | |
| .squeeze(1).bool().numpy() | |
| ) | |
| assert gt.shape == pred.shape == valid_mask.shape, f"{gt.shape}, {pred.shape}, {valid_mask.shape}" | |
| # ----------------------------- | |
| # Flatten ALL frames | |
| # ----------------------------- | |
| gt_masked = gt[valid_mask].reshape(-1, 1) # (N, 1) | |
| pred_masked = pred[valid_mask].reshape(-1, 1) # (N, 1) | |
| # ----------------------------- | |
| # Solve least squares over ALL pixels (T*H*W) | |
| # ----------------------------- | |
| _ones = np.ones_like(pred_masked) | |
| A = np.concatenate([pred_masked, _ones], axis=-1) # (N, 2) | |
| X = np.linalg.lstsq(A, gt_masked, rcond=None)[0] | |
| scale, shift = X | |
| # Apply to original resolution (not the downsampled) | |
| aligned_pred = pred_arr * scale + shift | |
| aligned_pred = aligned_pred.reshape(ori_shape) | |
| if return_scale_shift: | |
| return aligned_pred, scale, shift | |
| else: | |
| return aligned_pred | |
| def align_depth_least_square( | |
| gt_arr: np.ndarray, | |
| pred_arr: np.ndarray, | |
| valid_mask_arr: np.ndarray, | |
| return_scale_shift=True, | |
| max_resolution=None, | |
| ): | |
| ori_shape = pred_arr.shape # input shape | |
| gt = gt_arr.squeeze() # [H, W] | |
| pred = pred_arr.squeeze() | |
| valid_mask = valid_mask_arr.squeeze() | |
| # Downsample | |
| if max_resolution is not None: | |
| scale_factor = np.min(max_resolution / np.array(ori_shape[-2:])) | |
| if scale_factor < 1: | |
| downscaler = torch.nn.Upsample(scale_factor=scale_factor, mode="nearest") | |
| gt = downscaler(torch.as_tensor(gt).unsqueeze(0)).numpy() | |
| pred = downscaler(torch.as_tensor(pred).unsqueeze(0)).numpy() | |
| valid_mask = ( | |
| downscaler(torch.as_tensor(valid_mask).unsqueeze(0).float()) | |
| .bool() | |
| .numpy() | |
| ) | |
| assert ( | |
| gt.shape == pred.shape == valid_mask.shape | |
| ), f"{gt.shape}, {pred.shape}, {valid_mask.shape}" | |
| gt_masked = gt[valid_mask].reshape((-1, 1)) | |
| pred_masked = pred[valid_mask].reshape((-1, 1)) | |
| # numpy solver | |
| _ones = np.ones_like(pred_masked) | |
| A = np.concatenate([pred_masked, _ones], axis=-1) | |
| X = np.linalg.lstsq(A, gt_masked, rcond=None)[0] | |
| scale, shift = X | |
| aligned_pred = pred_arr * scale + shift | |
| # restore dimensions | |
| aligned_pred = aligned_pred.reshape(ori_shape) | |
| if return_scale_shift: | |
| return aligned_pred, scale, shift | |
| else: | |
| return aligned_pred | |
| # ******************** disparity space ******************** | |
| def depth2disparity(depth, return_mask=False): | |
| if isinstance(depth, torch.Tensor): | |
| disparity = torch.zeros_like(depth) | |
| elif isinstance(depth, np.ndarray): | |
| disparity = np.zeros_like(depth) | |
| non_negtive_mask = depth > 0 | |
| disparity[non_negtive_mask] = 1.0 / depth[non_negtive_mask] | |
| if return_mask: | |
| return disparity, non_negtive_mask | |
| else: | |
| return disparity | |
| def disparity2depth(disparity, **kwargs): | |
| return depth2disparity(disparity, **kwargs) | |