File size: 6,511 Bytes
66003a2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import numpy as np
from typing import Union

# Array types accepted by the helpers in this module: a NumPy array or a torch tensor.
# Inputs are converted to torch via _ensure_torch before any math is done.
ArrayLike = Union[np.ndarray, torch.Tensor]


def _is_numpy(x: ArrayLike) -> bool:
    """Tell whether ``x`` is a NumPy ``ndarray`` (subclasses count as well)."""
    ndarray_type = np.ndarray
    return isinstance(x, ndarray_type)


def _is_torch(x: ArrayLike) -> bool:
    """Tell whether ``x`` is a ``torch.Tensor`` (subclasses count as well)."""
    tensor_type = torch.Tensor
    return isinstance(x, tensor_type)


def _ensure_torch(x: ArrayLike) -> torch.Tensor:
    """
    Coerce ``x`` into a ``torch.Tensor``.

    * A tensor is handed back unchanged (the very same object).
    * A NumPy array is wrapped with ``torch.from_numpy``, so the result shares
      memory with the array.
    * Anything else is copied through ``torch.tensor``.
    """
    if _is_torch(x):
        return x
    if _is_numpy(x):
        # Zero-copy wrap: in-place writes to the result are visible in ``x``.
        return torch.from_numpy(x)
    return torch.tensor(x)


def single_undistortion(params, tracks_normalized):
    """
    Evaluate ``apply_distortion`` once on a batch of normalized tracks.

    This performs a single forward pass of the distortion model at the given
    points; it does not iterate or invert the model (``iterative_undistortion``
    does that with Newton's method).
    NOTE(review): presumably callers rely on one pass being a good enough
    approximation for their parameters — confirm at the call sites.

    Args:
        params (torch.Tensor or numpy.ndarray): Distortion parameters of shape BxN
            (N must be 1, 2 or 4, as required by ``apply_distortion``).
        tracks_normalized (torch.Tensor or numpy.ndarray): Normalized tracks of shape
            [batch_size, num_tracks, 2].

    Returns:
        torch.Tensor: Result of the single pass, stacked back into [..., 2] with the
        leading dimensions of ``tracks_normalized``.
    """
    params = _ensure_torch(params)
    tracks_normalized = _ensure_torch(tracks_normalized)

    # Separate x / y planes; the clones keep the caller's data untouched.
    x_coords = tracks_normalized[..., 0].clone()
    y_coords = tracks_normalized[..., 1].clone()

    out_x, out_y = apply_distortion(params, x_coords, y_coords)
    return torch.stack((out_x, out_y), dim=-1)


def iterative_undistortion(params, tracks_normalized, max_iterations=100, max_step_norm=1e-10, rel_step_size=1e-6):
    """
    Iteratively undistort normalized tracks by inverting ``apply_distortion`` with Newton's method.

    For every track this solves ``apply_distortion(params, x) == tracks_normalized`` for ``x``,
    starting from ``x = tracks_normalized`` and using a central-difference Jacobian, following
    COLMAP's ``IterativeUndistortion``.

    Args:
        params (torch.Tensor or numpy.ndarray): Distortion parameters of shape BxN (N in {1, 2, 4}).
        tracks_normalized (torch.Tensor or numpy.ndarray): Normalized tracks tensor of shape
            [batch_size, num_tracks, 2].
        max_iterations (int): Maximum number of Newton iterations.
        max_step_norm (float): Convergence threshold. Iteration stops once the largest
            *squared* L2 norm of any per-track Newton step falls below this value.
        rel_step_size (float): Relative step size for numerical differentiation (the absolute
            step is ``|coord| * rel_step_size``, clamped below by the dtype's machine epsilon).

    Returns:
        torch.Tensor: Undistorted normalized tracks of shape [batch_size, num_tracks, 2].

    Raises:
        ValueError: If ``tracks_normalized`` is not a 3-D array.
    """
    params = _ensure_torch(params)
    tracks_normalized = _ensure_torch(tracks_normalized)

    if tracks_normalized.dim() != 3:
        raise ValueError(
            "tracks_normalized must have shape [batch_size, num_tracks, 2], "
            f"got {tuple(tracks_normalized.shape)}"
        )

    # Targets the distortion model has to reproduce. Clones decouple the solver state from the
    # caller's data (from_numpy shares memory and u/v are updated in place below).
    target_u = tracks_normalized[..., 0].clone()
    target_v = tracks_normalized[..., 1].clone()
    u, v = target_u.clone(), target_v.clone()

    # Nothing to solve; also avoids torch.max() on an empty tensor below.
    if u.numel() == 0:
        return torch.stack([u, v], dim=-1)

    eps = torch.finfo(u.dtype).eps
    for _ in range(max_iterations):
        dist_u, dist_v = apply_distortion(params, u, v)
        residual = torch.stack([target_u - dist_u, target_v - dist_v], dim=-1)

        step_u = torch.clamp(torch.abs(u) * rel_step_size, min=eps)
        step_v = torch.clamp(torch.abs(v) * rel_step_size, min=eps)

        # Evaluate each perturbed point once and reuse both output coordinates.
        f_u_plus = apply_distortion(params, u + step_u, v)
        f_u_minus = apply_distortion(params, u - step_u, v)
        f_v_plus = apply_distortion(params, u, v + step_v)
        f_v_minus = apply_distortion(params, u, v - step_v)

        # apply_distortion returns the full distorted coordinate (x + dx), so these central
        # differences already ARE the complete Jacobian I + d(dx)/dx. COLMAP adds the identity
        # explicitly only because its Distortion() returns the offset dx alone; adding it again
        # here would double-count it and degrade Newton into a damped, linearly converging step.
        J_00 = (f_u_plus[0] - f_u_minus[0]) / (2 * step_u)
        J_10 = (f_u_plus[1] - f_u_minus[1]) / (2 * step_u)
        J_01 = (f_v_plus[0] - f_v_minus[0]) / (2 * step_v)
        J_11 = (f_v_plus[1] - f_v_minus[1]) / (2 * step_v)

        J = torch.stack([torch.stack([J_00, J_01], dim=-1), torch.stack([J_10, J_11], dim=-1)], dim=-2)

        # Batched 2x2 solve: one Newton step per track.
        delta = torch.linalg.solve(J, residual)

        u += delta[..., 0]
        v += delta[..., 1]

        # Squared norm of the largest per-track step, matching COLMAP's squaredNorm() test.
        if torch.max((delta**2).sum(dim=-1)) < max_step_norm:
            break

    return torch.stack([u, v], dim=-1)


def apply_distortion(extra_params, u, v):
    """
    Apply a radial or OpenCV distortion model to normalized 2D points.

    The model is selected by the number of parameters N per batch element,
    with ``r2 = u**2 + v**2``:

    * N == 1, ``[k]`` (simple radial):       ``radial = k * r2``
    * N == 2, ``[k1, k2]`` (radial):         ``radial = k1 * r2 + k2 * r2**2``
    * N == 4, ``[k1, k2, p1, p2]`` (OpenCV): radial as for N == 2, plus the
      tangential terms ``2*p1*u*v + p2*(r2 + 2*u**2)`` on x and
      ``2*p2*u*v + p1*(r2 + 2*v**2)`` on y.

    The distorted coordinates are ``u + u*radial (+ tangential x)`` and
    ``v + v*radial (+ tangential y)``.

    Args:
        extra_params (torch.Tensor or numpy.ndarray): Distortion parameters of shape BxN, where N can be 1, 2, or 4.
        u (torch.Tensor or numpy.ndarray): Normalized x coordinates of shape Bxnum_tracks.
        v (torch.Tensor or numpy.ndarray): Normalized y coordinates of shape Bxnum_tracks.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: The distorted ``(u, v)`` coordinates as two
        separate tensors (Bxnum_tracks for the documented input shapes). New tensors
        are returned; the inputs are not modified in place.

    Raises:
        ValueError: If N is not 1, 2, or 4.
    """
    extra_params = _ensure_torch(extra_params)
    u = _ensure_torch(u)
    v = _ensure_torch(v)

    num_params = extra_params.shape[1]
    if num_params not in (1, 2, 4):
        raise ValueError("Unsupported number of distortion parameters")

    # Terms shared by every supported model.
    u2 = u * u
    v2 = v * v
    r2 = u2 + v2

    # Per-camera coefficients are broadcast over the track dimension via [:, None].
    k1 = extra_params[:, 0][:, None]
    if num_params == 1:
        # Simple radial distortion: a single radial coefficient.
        radial = k1 * r2
    else:
        # RadialCameraModel / OpenCVCameraModel: two radial coefficients.
        k2 = extra_params[:, 1][:, None]
        radial = k1 * r2 + k2 * r2 * r2

    du = u * radial
    dv = v * radial

    if num_params == 4:
        # OpenCVCameraModel: add the tangential terms.
        p1 = extra_params[:, 2][:, None]
        p2 = extra_params[:, 3][:, None]
        uv = u * v
        du = du + 2 * p1 * uv + p2 * (r2 + 2 * u2)
        dv = dv + 2 * p2 * uv + p1 * (r2 + 2 * v2)

    # Out-of-place additions: fresh tensors, inputs stay untouched.
    return u + du, v + dv


if __name__ == "__main__":
    # Self-test: compare iterative_undistortion against pycolmap's SIMPLE_RADIAL camera.
    # With f=1 and cx=cy=0, pycolmap's image coordinates coincide with normalized coordinates.
    import random

    import pycolmap

    num_trials = 1000
    max_diff = 0.0
    for _ in range(num_trials):
        # Random batch of B SIMPLE_RADIAL cameras with one distortion parameter (k) each.
        B = random.randint(1, 500)
        track_num = random.randint(100, 1000)
        params = torch.rand((B, 1), dtype=torch.float32)  # B cameras, 1 parameter each
        tracks_normalized = torch.rand((B, track_num, 2), dtype=torch.float32)  # B x track_num points

        # Undistort the tracks
        undistorted_tracks = iterative_undistortion(params, tracks_normalized)

        for b in range(B):
            # SIMPLE_RADIAL parameter layout: [f, cx, cy, k].
            pycolmap_intri = np.array([1, 0, 0, params[b].item()])
            pycam = pycolmap.Camera(model="SIMPLE_RADIAL", width=1, height=1, params=pycolmap_intri, camera_id=0)

            undistorted_tracks_pycolmap = torch.from_numpy(
                np.asarray(pycam.cam_from_img(tracks_normalized[b].numpy()), dtype=np.float64)
            )
            # Compare in double precision and keep plain Python floats for reporting.
            diff = (undistorted_tracks[b].double() - undistorted_tracks_pycolmap).abs().median().item()
            max_diff = max(max_diff, diff)
            print(f"diff: {diff}, max_diff: {max_diff}")