File size: 4,224 Bytes
4b35c4e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
# Author: Bingxin Ke
# Last modified: 2024-01-11

import numpy as np
import torch
def align_depth_least_square_video(
    gt_arr: np.ndarray,
    pred_arr: np.ndarray,
    valid_mask_arr: np.ndarray,
    return_scale_shift=True,
    max_resolution=None,
):
    """Align a predicted video to ground truth with one global scale and shift.

    A single least-squares problem ``gt ~= scale * pred + shift`` is solved
    over the valid pixels of all frames together, and the solution is applied
    to ``pred_arr`` at its original resolution.

    Args:
        gt_arr: Ground-truth values, shape (T, H, W) or (T, 1, H, W).
        pred_arr: Predicted values, same layout as ``gt_arr``.
        valid_mask_arr: Mask selecting the pixels used in the fit, same
            layout as ``gt_arr``. Non-zero entries are treated as valid.
        return_scale_shift: If true, also return the fitted scale and shift.
        max_resolution: If given and the longer spatial side exceeds it, the
            fit is computed on a nearest-neighbour downsampled copy of the
            inputs. The returned prediction is never downsampled.

    Returns:
        ``aligned_pred`` with the shape of ``pred_arr``, or the tuple
        ``(aligned_pred, scale, shift)`` when ``return_scale_shift`` is true.
        ``scale`` and ``shift`` are length-1 arrays taken from the
        ``np.linalg.lstsq`` solution.
    """
    ori_shape = pred_arr.shape

    def _as_frames(arr):
        # Merge all leading axes into a single frame axis -> (N, H, W).
        # Unlike a blanket squeeze(), this keeps a one-frame clip 3-D and
        # never drops a spatial axis of size 1.
        return arr.reshape((-1,) + tuple(arr.shape[-2:]))

    gt = _as_frames(gt_arr)
    pred = _as_frames(pred_arr)
    # Always index with a boolean mask; an integer 0/1 mask would otherwise
    # trigger integer fancy-indexing instead of masking.
    valid_mask = _as_frames(valid_mask_arr).astype(bool)

    # -----------------------------
    # Optional downsampling (applied per-frame identically)
    # -----------------------------
    if max_resolution is not None:
        height, width = gt.shape[-2:]
        scale_factor = float(np.min(max_resolution / np.array([height, width])))
        if scale_factor < 1:
            downscaler = torch.nn.Upsample(scale_factor=scale_factor, mode="nearest")

            def _downsample(tensor):
                # (N, H, W) -> (N, 1, H, W): a 4-D input makes Upsample
                # resample both spatial axes.
                return downscaler(tensor.unsqueeze(1)).squeeze(1)

            gt = _downsample(torch.as_tensor(gt)).numpy()
            pred = _downsample(torch.as_tensor(pred)).numpy()
            valid_mask = _downsample(torch.as_tensor(valid_mask).float()).bool().numpy()

    assert gt.shape == pred.shape == valid_mask.shape, f"{gt.shape}, {pred.shape}, {valid_mask.shape}"

    # -----------------------------
    # Gather the valid pixels of ALL frames as column vectors
    # -----------------------------
    gt_masked = gt[valid_mask].reshape(-1, 1)        # (N, 1)
    pred_masked = pred[valid_mask].reshape(-1, 1)    # (N, 1)

    # -----------------------------
    # Solve [pred, 1] @ [scale, shift]^T = gt in the least-squares sense
    # -----------------------------
    _ones = np.ones_like(pred_masked)
    A = np.concatenate([pred_masked, _ones], axis=-1)   # (N, 2)

    X = np.linalg.lstsq(A, gt_masked, rcond=None)[0]
    scale, shift = X

    # Apply to the original-resolution prediction (not the downsampled copy)
    aligned_pred = pred_arr * scale + shift
    aligned_pred = aligned_pred.reshape(ori_shape)

    if return_scale_shift:
        return aligned_pred, scale, shift
    else:
        return aligned_pred
    

def align_depth_least_square(
    gt_arr: np.ndarray,
    pred_arr: np.ndarray,
    valid_mask_arr: np.ndarray,
    return_scale_shift=True,
    max_resolution=None,
):
    """Align a prediction to ground truth with a least-squares scale and shift.

    Solves ``gt ~= scale * pred + shift`` over the valid pixels and applies
    the solution to ``pred_arr`` at its original resolution.

    Args:
        gt_arr: Ground-truth values; the last two axes are treated as (H, W).
            Leading axes (e.g. singleton batch/channel axes) are merged.
        pred_arr: Predicted values, same layout as ``gt_arr``.
        valid_mask_arr: Mask selecting the pixels used in the fit, same
            layout as ``gt_arr``. Non-zero entries are treated as valid.
        return_scale_shift: If true, also return the fitted scale and shift.
        max_resolution: If given and the longer spatial side of ``pred_arr``
            exceeds it, the fit is computed on a nearest-neighbour
            downsampled copy of the inputs. The returned prediction is never
            downsampled.

    Returns:
        ``aligned_pred`` with the shape of ``pred_arr``, or the tuple
        ``(aligned_pred, scale, shift)`` when ``return_scale_shift`` is true.
        ``scale`` and ``shift`` are length-1 arrays taken from the
        ``np.linalg.lstsq`` solution.
    """
    ori_shape = pred_arr.shape  # input shape, restored on the output

    def _as_images(arr):
        # Merge all leading axes -> (N, H, W); unlike squeeze() this never
        # drops a spatial axis of size 1.
        return arr.reshape((-1,) + tuple(arr.shape[-2:]))

    gt = _as_images(gt_arr)
    pred = _as_images(pred_arr)
    # Always index with a boolean mask; an integer 0/1 mask would otherwise
    # trigger integer fancy-indexing instead of masking.
    valid_mask = _as_images(valid_mask_arr).astype(bool)

    # Downsample
    if max_resolution is not None:
        scale_factor = float(np.min(max_resolution / np.array(ori_shape[-2:])))
        if scale_factor < 1:
            downscaler = torch.nn.Upsample(scale_factor=scale_factor, mode="nearest")

            def _downsample(tensor):
                # Upsample resamples both H and W only for 4-D input; a 3-D
                # tensor would be treated as (N, C, L) and shrink W alone.
                return downscaler(tensor.unsqueeze(1)).squeeze(1)

            gt = _downsample(torch.as_tensor(gt)).numpy()
            pred = _downsample(torch.as_tensor(pred)).numpy()
            valid_mask = _downsample(torch.as_tensor(valid_mask).float()).bool().numpy()

    assert (
        gt.shape == pred.shape == valid_mask.shape
    ), f"{gt.shape}, {pred.shape}, {valid_mask.shape}"

    gt_masked = gt[valid_mask].reshape((-1, 1))
    pred_masked = pred[valid_mask].reshape((-1, 1))

    # numpy solver: [pred, 1] @ [scale, shift]^T = gt
    _ones = np.ones_like(pred_masked)
    A = np.concatenate([pred_masked, _ones], axis=-1)
    X = np.linalg.lstsq(A, gt_masked, rcond=None)[0]
    scale, shift = X

    aligned_pred = pred_arr * scale + shift

    # restore dimensions
    aligned_pred = aligned_pred.reshape(ori_shape)

    if return_scale_shift:
        return aligned_pred, scale, shift
    else:
        return aligned_pred


# ******************** disparity space ********************
def depth2disparity(depth, return_mask=False):
    """Convert depth to disparity by taking the element-wise reciprocal.

    Only strictly positive entries are inverted; every other entry (zero,
    negative, NaN) is left at 0 in the output.

    Args:
        depth: A ``torch.Tensor`` or ``np.ndarray``.
        return_mask: If true, also return the boolean mask of the entries
            that were inverted (``depth > 0``).

    Returns:
        The disparity, of the same container type and shape as ``depth``, or
        the tuple ``(disparity, mask)`` when ``return_mask`` is true.
        Floating-point inputs keep their dtype; other dtypes produce a
        floating-point result so the reciprocals are not truncated.

    Raises:
        TypeError: If ``depth`` is neither a ``torch.Tensor`` nor an
            ``np.ndarray``.
    """
    if isinstance(depth, torch.Tensor):
        if depth.is_floating_point():
            disparity = torch.zeros_like(depth)
        else:
            # An integer buffer cannot hold the fractional reciprocals.
            disparity = torch.zeros_like(depth, dtype=torch.get_default_dtype())
    elif isinstance(depth, np.ndarray):
        if np.issubdtype(depth.dtype, np.floating):
            disparity = np.zeros_like(depth)
        else:
            # An integer buffer would silently truncate 1/depth to 0.
            disparity = np.zeros_like(depth, dtype=np.float64)
    else:
        raise TypeError(
            f"depth must be a torch.Tensor or np.ndarray, got {type(depth).__name__}"
        )

    # Guard against division by zero and meaningless negative depths.
    positive_mask = depth > 0
    disparity[positive_mask] = 1.0 / depth[positive_mask]

    if return_mask:
        return disparity, positive_mask
    return disparity


def disparity2depth(disparity, **kwargs):
    """Convert disparity back to depth.

    This forwards ``disparity`` and every keyword argument unchanged to
    ``depth2disparity`` and returns its result as-is.
    """
    depth = depth2disparity(disparity, **kwargs)
    return depth