lihao57 committed on
Commit
e26dfd8
·
1 Parent(s): cc5da1d

add utils and support for visualizing Bezier curve

Browse files
Files changed (5) hide show
  1. .gitignore +1 -0
  2. app.py +212 -92
  3. requirements.txt +3 -1
  4. utils/bezier.py +122 -0
  5. utils/camera.py +592 -0
.gitignore CHANGED
@@ -1,2 +1,3 @@
1
  .vscode
2
  .gradio
 
 
1
  .vscode
2
  .gradio
3
+ **/__pycache__
app.py CHANGED
@@ -8,123 +8,245 @@
8
  @Contact : 2909171338@qq.com
9
  """
10
 
11
- import argparse
12
  import gradio as gr
13
- from PIL import Image, ImageDraw
14
  import io
 
15
  import matplotlib.pyplot as plt
16
  import numpy as np
17
- from datasets import load_dataset
 
 
18
 
19
- ds = None
20
- DATASET_NAME = None
21
- LOCAL = False
22
- FAST = False
23
- SPLIT = "all"
24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
- def get_dataset():
 
27
  """
28
- get dataset
29
 
30
  Args:
31
- None
32
 
33
  Returns:
34
- ds (datasets.Dataset): dataset
35
  """
36
- global ds
37
- if ds is None:
38
- if LOCAL:
39
- ds = load_dataset("imagefolder", data_dir=DATASET_NAME)
 
 
40
  else:
41
- ds = load_dataset(DATASET_NAME)
42
- return ds
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
 
45
- def selector_change_callback(value):
46
  """
47
- callback function for split selector
48
 
49
  Args:
50
- value (str): selected split, value must be one of ["train", "test"]
 
51
 
52
  Returns:
53
- slider_info (dict): updated slider info
54
- image (np.ndarray): updated image
 
55
  """
56
- ds = get_dataset()
57
- maximum = len(ds[value]) - 1
58
- slider_info = gr.update(minimum=0, maximum=maximum, value=0)
59
- image = show_image(split=value, index=0)
60
- return slider_info, image
61
 
 
 
 
 
 
62
 
63
- def draw_lines(image, lines):
 
64
  """
65
- draw lines on image
66
 
67
  Args:
68
  image (np.ndarray): input image
69
  lines (np.ndarray): list of lines, with shape [N, 2, 2]
 
 
 
70
 
71
  Returns:
72
- image (PIL.Image): drawn image
73
  """
74
- if FAST:
75
- image = Image.fromarray(image)
76
- draw = ImageDraw.Draw(image)
77
- for pts in lines:
78
- pts = pts - 0.5
79
- pts_list = [tuple(p) for p in pts]
80
- draw.line(pts_list, fill="orange", width=2)
81
- draw.circle(pts_list[0], 3, fill="#33FFFF")
82
- draw.circle(pts_list[-1], 3, fill="#33FFFF")
83
  else:
84
- height, width = image.shape[:2]
85
- fig = plt.figure()
86
- fig.set_size_inches(width / height, 1, forward=False)
87
- ax = plt.Axes(fig, [0.0, 0.0, 1.0, 1.0])
88
- ax.set_axis_off()
89
- fig.add_axes(ax)
90
- plt.xlim([-0.5, width - 0.5])
91
- plt.ylim([height - 0.5, -0.5])
92
- plt.imshow(image)
93
- for pts in lines:
94
- pts = pts - 0.5
95
- plt.plot(pts[:, 0], pts[:, 1], color="orange", linewidth=0.5)
96
- plt.scatter(pts[[0, -1], 0], pts[[0, -1], 1], color="#33FFFF", s=1.2, edgecolors="none", zorder=5)
97
-
98
- buf = io.BytesIO()
99
- fig.savefig(buf, format="png", dpi=height, bbox_inches=0)
100
- buf.seek(0)
101
- plt.close(fig)
102
- image = Image.open(buf)
 
 
 
 
 
 
 
 
103
  return image
104
 
105
 
106
- def show_image(split, index):
107
  """
108
- show image
109
 
110
  Args:
111
  split (str): split name, value must be one of ["train", "test"]
112
- index (int): index of the example
 
113
 
114
  Returns:
115
- image (PIL.Image): drawn image
 
116
  """
117
- ds = get_dataset()
118
- example = ds[split][index]
119
- image = np.array(example["image"])
120
- lines = np.array(example["lines"])
121
- image = draw_lines(image, lines)
122
- return image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
 
124
 
125
  def main():
126
  """
127
- main
128
 
129
  Args:
130
  None
@@ -133,30 +255,28 @@ def main():
133
  None
134
  """
135
  with gr.Blocks() as demo:
136
- if SPLIT == "all":
137
- choices = ["train", "test"]
138
- else:
139
- choices = [SPLIT]
140
- split_selector = gr.Dropdown(choices, label="Split", value=choices[0])
141
- index_slider = gr.Slider(0, 1, step=1, label="Index", value=0)
142
- output = gr.Image()
143
-
144
- split_selector.change(selector_change_callback, split_selector, [index_slider, output])
145
- index_slider.change(show_image, [split_selector, index_slider], output)
146
- demo.load(selector_change_callback, split_selector, [index_slider, output])
 
 
 
 
 
147
  demo.launch(share=False)
148
 
149
 
150
  if __name__ == "__main__":
151
- argparser = argparse.ArgumentParser()
152
- argparser.add_argument("-n", "--dataset_name", type=str, help="dataset name", default="lh9171338/Wireframe")
153
- argparser.add_argument("-l", "--local", type=bool, help="use local data or not", default=False)
154
- argparser.add_argument("-f", "--fast", type=bool, help="whether to use fast drawing method", default=False)
155
- argparser.add_argument("-s", "--split", type=str, help="split", default="all", choices=["all", "train", "test"])
156
- args = argparser.parse_args()
157
- print(args)
158
- DATASET_NAME = args.dataset_name
159
- LOCAL = args.local
160
- FAST = args.fast
161
- SPLIT = args.split
162
  main()
 
8
  @Contact : 2909171338@qq.com
9
  """
10
 
11
+ import os
12
  import gradio as gr
13
+ from PIL import Image
14
  import io
15
+ import logging
16
  import matplotlib.pyplot as plt
17
  import numpy as np
18
+ from datasets import load_dataset, DatasetDict
19
+ import utils.camera as cam
20
+ import utils.bezier as bez
21
 
 
 
 
 
 
22
 
23
# Cache of loaded datasets, keyed by the name or path they were requested with.
dataset_dict = {}
# Dataset currently shown in the UI; None until one has been loaded.
dataset = None
# Initial (disabled) configuration of the split dropdown.
default_split_selector_info = {
    "choices": ["train", "test"],
    "label": "Split",
    "value": "train",
    "interactive": False,
}
# Initial (disabled) configuration of the sample index slider.
default_index_slider_info = {
    "minimum": 0,
    "maximum": 1,
    "step": 1,
    "label": "Index",
    "value": 0,
    "interactive": False,
}
# Initial (disabled) configuration of the Bezier order slider.
default_order_slider_info = {
    "minimum": 0,
    "maximum": 6,
    "step": 1,
    "label": "Order",
    "value": 0,
    "interactive": False,
}
# Description of the most recently rendered sample and its two rendered images.
sample_info = {
    "dataset": dataset,
    "split": "train",
    "index": 0,
    "order": 0,
    "image1": None,
    "image2": None,
}
55
 
56
+
57
def get_dataset(dataset_name):
    """
    Return the dataset for a name or path, loading it only on first use

    Args:
        dataset_name (str): dataset name or path

    Returns:
        dataset: object returned by `load_dataset` for this name or path
    """
    try:
        return dataset_dict[dataset_name]
    except KeyError:
        pass

    # An existing filesystem path is read as a local image folder;
    # anything else is passed to `load_dataset` as a dataset name.
    if os.path.exists(dataset_name):
        loaded = load_dataset("imagefolder", data_dir=dataset_name)
    else:
        loaded = load_dataset(dataset_name)
    dataset_dict[dataset_name] = loaded
    return loaded
77
+
78
+
79
def submit_callback(dataset_name, order):
    """
    Submit callback function

    Loads the requested dataset and shows its first sample. When loading
    fails, or the dataset has no splits, every widget is reset to its
    disabled default and both images are cleared.

    Args:
        dataset_name (str): dataset name or path
        order (int): order of the Bezier curve

    Returns:
        split_selector_info (dict): updated split selector info
        index_slider_info (dict): updated index slider info
        order_slider_info (dict): updated order slider info
        image1 (PIL.Image | None): updated image
        image2 (PIL.Image | np.ndarray | None): updated image
    """
    global dataset
    try:
        loaded = get_dataset(dataset_name)
        if not isinstance(loaded, DatasetDict):
            # Wrap a single split so the rest of the app can index by split name.
            loaded = {str(loaded.split): loaded}
        splits = list(loaded.keys())
        if not splits:
            raise ValueError("dataset has no splits")
    except Exception as e:  # UI boundary: report the failure and reset the widgets
        dataset = None
        logging.error(f"Load dataset failed: {e}")
        split_selector_info = gr.update(**default_split_selector_info)
        index_slider_info = gr.update(**default_index_slider_info)
        order_slider_info = gr.update(**default_order_slider_info)
        return split_selector_info, index_slider_info, order_slider_info, None, None

    dataset = loaded
    split = splits[0]
    # Keep the slider range valid even when the split is empty.
    maximum = max(len(dataset[split]) - 1, 0)
    index = 0
    split_selector_info = gr.update(choices=splits, value=split, interactive=True)
    index_slider_info = gr.update(minimum=0, maximum=maximum, value=index, interactive=True)
    order_slider_info = gr.update(interactive=True)
    image1, image2 = show_image(split=split, index=index, order=order)
    return split_selector_info, index_slider_info, order_slider_info, image1, image2
116
 
117
 
118
def selector_change_callback(split, order):
    """
    Selector change callback function

    Rewinds the index slider to the first sample of the chosen split and
    renders that sample.

    Args:
        split (str): selected split, a key of the loaded dataset
        order (int): order of the Bezier curve

    Returns:
        index_slider_info (dict): updated slider info
        image1 (PIL.Image | None): updated image
        image2 (PIL.Image | np.ndarray | None): updated image
    """
    if dataset is None:
        # Nothing loaded: restore the disabled slider and clear both images.
        return gr.update(**default_index_slider_info), None, None

    last_index = len(dataset[split]) - 1
    slider_update = gr.update(minimum=0, maximum=last_index, value=0)
    original_view, bezier_view = show_image(split=split, index=0, order=order)
    return slider_update, original_view, bezier_view
141
 
142
+
143
def draw_lines(image, lines, camera_type="pinhole", camera_coeff=None, order=None):
    """
    Draw lines on image

    Args:
        image (np.ndarray): input image
        lines (np.ndarray): list of lines, with shape [N, 2, 2]
        camera_type (str): camera type, value must be one of ["pinhole", "fisheye", "spherical"]
        camera_coeff (dict | None): dict of camera coefficients
        order (int | None): order of the Bezier curve. None draws the lines
            as interpolated by the camera model; 0 returns the input image
            without drawing anything

    Returns:
        image (PIL.Image | np.ndarray): drawn image, or the input array when order is 0

    Raises:
        ValueError: if camera_type is not one of the supported values
    """
    if order == 0:  # Show original image
        return image

    if camera_type not in ("pinhole", "fisheye", "spherical"):
        raise ValueError(f"Unsupported camera type: {camera_type!r}")
    height, width = image.shape[:2]
    if camera_type == "pinhole":
        camera = cam.Pinhole(coeff=camera_coeff)
    elif camera_type == "fisheye":
        camera = cam.Fisheye(coeff=camera_coeff)
    else:
        camera = cam.Spherical(image_size=(width, height), coeff=camera_coeff)

    # Compute all geometry before creating the figure, so a failure here
    # cannot leave an open figure behind.
    lines = camera.truncate_line(lines)
    pts_list = camera.interp_line(lines)
    if order is not None:  # Draw Bezier curve
        bezier = bez.Bezier(order=int(order))
        lines, t_list = bezier.fit_line(pts_list)
        # Nothing could be fitted (no lines, or only single-point polylines).
        pts_list = bezier.interp_line(lines, t_list) if len(lines) else []

    fig = plt.figure()
    try:
        fig.set_size_inches(width / height, 1, forward=False)
        # A single axes filling the whole figure, without ticks or frame.
        ax = plt.Axes(fig, [0.0, 0.0, 1.0, 1.0])
        ax.set_axis_off()
        fig.add_axes(ax)
        # Draw on this axes explicitly instead of pyplot's implicit current axes.
        ax.set_xlim([-0.5, width - 0.5])
        ax.set_ylim([height - 0.5, -0.5])
        ax.imshow(image)
        for pts in pts_list:
            pts = pts - 0.5
            ax.plot(pts[:, 0], pts[:, 1], color="orange", linewidth=0.5)
            ax.scatter(pts[[0, -1], 0], pts[[0, -1], 1], color="#33FFFF", s=1.2, edgecolors="none", zorder=5)

        buf = io.BytesIO()
        # The figure is 1 inch tall, so dpi=height renders `height` pixel rows.
        fig.savefig(buf, format="png", dpi=height, bbox_inches=0)
    finally:
        plt.close(fig)

    buf.seek(0)
    image = Image.open(buf)
    return image
195
 
196
 
197
def show_image(split, index, order):
    """
    Show image

    Renders the selected sample twice: with its original line labels and
    with the Bezier curves of the requested order. The last rendering is
    kept in `sample_info`, so an unchanged request is served from it and
    a change of order alone only redraws the Bezier image.

    Args:
        split (str): split name, a key of the loaded dataset
        index (int | float): index of the sample
        order (int | float): order of the Bezier curve

    Returns:
        image1 (PIL.Image | None): image with the original labels
        image2 (PIL.Image | np.ndarray | None): image with the Bezier curves
    """
    if dataset is None:
        return None, None

    # NOTE(review): Gradio documents slider values as floats; cast before
    # using them as a dataset index and a Bezier order.
    index, order = int(index), int(order)

    same_sample = (
        sample_info["dataset"] == dataset
        and sample_info["split"] == split
        and sample_info["index"] == index
    )
    if same_sample and sample_info["order"] == order:  # No need to update
        logging.info("No need to update")
        return sample_info["image1"], sample_info["image2"]

    sample = dataset[split][index]
    image = np.array(sample["image"])
    lines = np.array(sample["lines"])
    # Fall back to a pinhole camera when the column is missing or holds None.
    camera_type = sample.get("camera_type") or "pinhole"
    camera_coeff = sample.get("camera_coeff")
    if same_sample:  # No need to update origin label
        image1 = sample_info["image1"]
        logging.info("Only update Bezier curve")
    else:
        image1 = draw_lines(image, lines, camera_type, camera_coeff)
    image2 = draw_lines(image, lines, camera_type, camera_coeff, order)

    sample_info.update(
        dataset=dataset,
        split=split,
        index=index,
        order=order,
        image1=image1,
        image2=image2,
    )
    logging.info("Update")
    return image1, image2
245
 
246
 
247
def main():
    """
    Build the Gradio interface, wire its events and launch it

    Args:
        None

    Returns:
        None
    """
    with gr.Blocks() as demo:
        dataset_textbox = gr.Textbox(label="Dataset name or path")
        split_selector = gr.Dropdown(**default_split_selector_info)
        index_slider = gr.Slider(**default_index_slider_info)
        order_slider = gr.Slider(**default_order_slider_info)
        with gr.Row():
            image1 = gr.Image(label="Original Label")
            image2 = gr.Image(label="Bezier Curve")

        images = [image1, image2]
        view_inputs = [split_selector, index_slider, order_slider]

        # Submitting a dataset name reloads everything.
        dataset_textbox.submit(
            submit_callback,
            [dataset_textbox, order_slider],
            [split_selector, index_slider, order_slider, *images],
        )
        # Changing the split rewinds the index slider and redraws.
        split_selector.change(selector_change_callback, [split_selector, order_slider], [index_slider, *images])
        # Either slider redraws the current sample.
        for slider in (index_slider, order_slider):
            slider.change(show_image, view_inputs, images)
    demo.launch(share=False)
275
 
276
 
277
if __name__ == "__main__":
    # Configure root logging (INFO level) before the app starts.
    logging.basicConfig(
        format="[%(asctime)s - %(levelname)s - %(filename)s:%(lineno)s] %(message)s",
        level=logging.INFO,
    )

    main()
requirements.txt CHANGED
@@ -1,4 +1,6 @@
1
  datasets
2
  matplotlib
3
- numpy
 
4
  pillow
 
 
1
  datasets
2
  matplotlib
3
+ numpy<2
4
+ opencv-python
5
  pillow
6
+ scipy
utils/bezier.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- encoding: utf-8 -*-
2
+
3
+ """
4
+ @File : bezier.py
5
+ @Time : 2025/9/3 15:25:00
6
+ @Author : lh9171338
7
+ @Version : 1.0
8
+ @Contact : 2909171338@qq.com
9
+ """
10
+
11
+ import numpy as np
12
+ import cv2
13
+ from scipy.special import comb
14
+
15
+
16
class Bezier:
    """
    Bezier curve of a fixed order

    A curve is described by `order + 1` points lying on it, sampled at
    evenly spaced parameters in [0, 1]. Control points are recovered from
    them with the inverse Bernstein matrix computed in `set_order`.

    Args:
        order (int): order of the curve
        **kwargs: keyword arguments (ignored)
    """

    def __init__(self, order=1, **kwargs):
        self.set_order(order)

    def set_order(self, order):
        """
        Set order

        Args:
            order (int): order, must be non-negative

        Returns:
            None

        Raises:
            ValueError: if order is negative
        """
        order = int(order)
        if order < 0:
            raise ValueError(f"order must be non-negative, got {order}")

        k = np.arange(order + 1)
        p = comb(order, k)  # binomial coefficients C(order, k)
        t = np.linspace(0, 1, order + 1)[:, None]
        # Bernstein basis evaluated at the evenly spaced parameters; row i
        # maps control points to the curve point at t[i].
        coeff_matrix = p * (t**k) * ((1 - t) ** (order - k))

        self.order = order
        self.p = p
        self.k = k
        self.inv_coeff_matrix = np.linalg.inv(coeff_matrix)

    def fit_line(self, pts_list):
        """
        Fit line

        For every polyline, picks the `order + 1` samples whose normalised
        arc-length parameter is closest to the evenly spaced parameters.
        Polylines with fewer than two points are skipped.

        Args:
            pts_list (list): list of pts, each an array of shape [K, 2]

        Returns:
            lines (np.ndarray): lines, shape [N, O + 1, 2]; [0, O + 1, 2] when nothing was fitted
            t_list (list): list of t, the arc-length parameter of every input point
        """
        lines, t_list = [], []
        t0 = np.linspace(0, 1, self.order + 1)
        for pts in pts_list:
            pts = np.asarray(pts)
            if len(pts) < 2:
                continue
            dists = np.cumsum(np.linalg.norm(pts[1:] - pts[:-1], axis=-1))
            total = dists[-1]
            if total > 0:
                t = np.concatenate((np.zeros(1), dists / total))
            else:
                # All points coincide: arc length carries no information, so
                # spread the parameters evenly instead of dividing by zero.
                t = np.linspace(0, 1, len(pts))
            indices = np.argmin(np.abs(t[None, :] - t0[:, None]), axis=1)
            lines.append(pts[indices])
            t_list.append(t)

        if lines:
            lines = np.asarray(lines)
        else:
            # Keep the documented rank so callers can still read shape[1].
            lines = np.zeros((0, self.order + 1, 2))
        return lines, t_list

    def interp_line(self, lines, t_list=None, num=None, resolution=0.1):
        """
        Interpolate line

        Args:
            lines (np.ndarray): lines, shape [N, O + 1, 2]
            t_list (list | None): list of t; when None, parameters are spaced
                evenly, with a count derived from `num` or `resolution`
            num (int | None): number of points to interpolate
            resolution (float): resolution

        Returns:
            pts_list (list): list of interpolated points

        Raises:
            ValueError: if the lines do not hold `order + 1` points each
        """
        lines = np.asarray(lines)
        if lines.size == 0:
            return []
        if lines.ndim < 2 or lines.shape[1] != self.order + 1:
            raise ValueError(f"expected lines of shape [N, {self.order + 1}, 2], got {lines.shape}")

        if t_list is None:
            t_list = []
            for line in lines:
                K = num or int(round(np.max(np.abs(line[-1] - line[0])) / resolution)) + 1
                t_list.append(np.linspace(0, 1, K))

        pts_list = []
        for line, t in zip(lines, t_list):
            control_points = np.matmul(self.inv_coeff_matrix, line)
            t = np.asarray(t)[:, None]
            coeff_matrix = self.p * (t**self.k) * ((1 - t) ** (self.order - self.k))
            pts_list.append(np.matmul(coeff_matrix, control_points))

        return pts_list

    def insert_line(self, image, lines, color, thickness=1):
        """
        Insert line

        Args:
            image (np.ndarray): image, drawn on in place
            lines (np.ndarray): lines, shape [N, O + 1, 2]
            color (tuple): color
            thickness (int): thickness

        Returns:
            image (np.ndarray): image
        """
        for pts in self.interp_line(lines):
            pts = np.round(pts).astype(np.int32)
            cv2.polylines(image, [pts], isClosed=False, color=color, thickness=thickness)

        return image
utils/camera.py ADDED
@@ -0,0 +1,592 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- encoding: utf-8 -*-
2
+
3
+ """
4
+ @File : camera.py
5
+ @Time : 2025/9/3 15:25:00
6
+ @Author : lh9171338
7
+ @Version : 1.0
8
+ @Contact : 2909171338@qq.com
9
+ """
10
+
11
+ import numpy as np
12
+ import cv2
13
+
14
+
15
class Camera:
    """
    Base Camera

    Holds the camera coefficients and the operations shared by all camera
    models. Subclasses provide `distort_point`, `undistort_point` and
    `interp_line`.

    Args:
        coeff (dict | None): camera coefficients
        **kwargs: keyword arguments (ignored)
    """

    def __init__(self, coeff=None, **kwargs):
        self.coeff = coeff
        self.format_coeff()

    def format_coeff(self):
        """
        Format coeff

        Converts every coefficient value to a numpy array. Missing or
        empty coefficients are left as they are.

        Args:
            None

        Returns:
            None
        """
        if self.coeff:
            self.coeff = {k: np.array(v) for k, v in self.coeff.items()}

    def load_coeff(self, filename):
        """
        Load coeff

        Reads the "K" and "D" matrices from an OpenCV file storage.

        Args:
            filename (str): filename

        Returns:
            None

        Raises:
            OSError: if the file cannot be opened
        """
        fs = cv2.FileStorage(filename, cv2.FileStorage_READ)
        try:
            if not fs.isOpened():
                # Fail here instead of storing None matrices that only
                # break later, far from the cause.
                raise OSError(f"Cannot open camera coefficient file: {filename}")
            K = fs.getNode("K").mat()
            D = fs.getNode("D").mat()
        finally:
            fs.release()
        self.coeff = {"K": K, "D": D}

    def save_coeff(self, filename):
        """
        Save coeff

        Writes the "K" and "D" matrices to an OpenCV file storage.

        Args:
            filename (str): filename

        Returns:
            None
        """
        fs = cv2.FileStorage(filename, cv2.FileStorage_WRITE)
        try:
            fs.write("K", self.coeff["K"])
            fs.write("D", self.coeff["D"])
        finally:
            fs.release()

    def distort_point(self, undistorted):
        """
        Distort point

        Args:
            undistorted (np.ndarray): undistorted points, shape [N, 2]

        Returns:
            distorted (np.ndarray): distorted points, shape [N, 2]
        """
        raise NotImplementedError

    def undistort_point(self, distorted):
        """
        Undistort point

        Args:
            distorted (np.ndarray): distorted points, shape [N, 2]

        Returns:
            undistorted (np.ndarray): undistorted points, shape [N, 2]
        """
        raise NotImplementedError

    @staticmethod
    def _pixel_grid(width, height):
        """Return the [x, y] coordinates of every pixel, row by row, shape [H * W, 2]."""
        return np.mgrid[0:width, 0:height].T.reshape(-1, 2).astype(np.float64)

    def distort_image(self, image, transform=None):
        """
        Distort image

        Args:
            image (np.ndarray): image
            transform (list | None): transform, [tx, ty, sx, sy]; defaults to
                no offset and unit scale

        Returns:
            image (np.ndarray): distorted image
        """
        if transform is None:
            transform = [0.0, 0.0, 1.0, 1.0]
        tx, ty, sx, sy = transform[0], transform[1], transform[2], transform[3]

        height, width = image.shape[0], image.shape[1]

        # For each output pixel, look up where it comes from in the input image.
        distorted = self._pixel_grid(width, height)
        undistorted = self.undistort_point(distorted).reshape(height, width, 2)
        map1 = (undistorted[:, :, 0].astype(np.float32) - tx) / sx
        map2 = (undistorted[:, :, 1].astype(np.float32) - ty) / sy

        return cv2.remap(image, map1, map2, cv2.INTER_CUBIC)

    def undistort_image(self, image, transform=None):
        """
        Undistort image

        Args:
            image (np.ndarray): image
            transform (list | None): transform, [tx, ty, sx, sy]; defaults to
                no offset and unit scale

        Returns:
            image (np.ndarray): undistorted image
        """
        if transform is None:
            transform = [0.0, 0.0, 1.0, 1.0]
        tx, ty, sx, sy = transform[0], transform[1], transform[2], transform[3]

        height, width = image.shape[0], image.shape[1]

        undistorted = self._pixel_grid(width, height)
        undistorted[:, 0] = undistorted[:, 0] * sx + tx
        undistorted[:, 1] = undistorted[:, 1] * sy + ty
        distorted = self.distort_point(undistorted).reshape(height, width, 2)
        map1 = distorted[:, :, 0].astype(np.float32)
        map2 = distorted[:, :, 1].astype(np.float32)

        return cv2.remap(image, map1, map2, cv2.INTER_CUBIC)

    def interp_line(self, lines, num=None, resolution=1.0):
        """
        Interpolate line

        Args:
            lines (np.ndarray): lines, shape [N, 2, 2]
            num (int | None): number of interpolated points per line
            resolution (float): resolution of interpolation

        Returns:
            pts_list (list): list of interpolated points
        """
        raise NotImplementedError

    def interp_arc(self, arcs, num=None, resolution=0.01):
        """
        Interpolate arc

        Samples the arc between two direction vectors: the plane they span
        is rotated onto the xy-plane, the angle is sampled there, and the
        samples are rotated back.

        Args:
            arcs (np.ndarray): arcs, shape [N, 2, 3], two 3D end points per arc
            num (int | None): number of interpolated points per arc
            resolution (float): resolution of interpolation; multiplied by
                pi / 180 before it is used as an angular step

        Returns:
            pts_list (list): list of interpolated points, each of shape [K, 3]
        """
        resolution *= np.pi / 180.0

        pts_list = []
        for arc in arcs:
            pt1, pt2 = arc[0], arc[1]
            normal = np.cross(pt1, pt2)
            norm = np.linalg.norm(normal)
            if not norm > 0:
                # The end points are parallel (e.g. a zero-length line), so
                # they do not define a plane. Return the start point instead
                # of dividing by zero and failing on int(round(nan)).
                K = 1 if num is None else num
                pts_list.append(np.repeat(np.asarray(pt1, dtype=np.float64)[None, :], K, axis=0))
                continue

            normal = normal / norm
            # Clip guards arccos against rounding slightly outside [-1, 1].
            angle = np.arccos(np.clip(normal[2], -1.0, 1.0))
            axes = np.array([-normal[1], normal[0], 0])
            axes /= max(np.linalg.norm(axes), np.finfo(np.float64).eps)
            rotation_vector = angle * axes
            rotation_matrix, _ = cv2.Rodrigues(rotation_vector)
            pt1 = np.matmul(rotation_matrix.T, pt1[:, None]).flatten()
            pt2 = np.matmul(rotation_matrix.T, pt2[:, None]).flatten()
            min_angle = np.arctan2(pt1[1], pt1[0])
            max_angle = np.arctan2(pt2[1], pt2[0])
            if max_angle < min_angle:
                max_angle += 2 * np.pi

            K = int(round((max_angle - min_angle) / resolution) + 1) if num is None else num
            angles = np.linspace(min_angle, max_angle, K)
            pts = np.hstack((np.cos(angles)[:, None], np.sin(angles)[:, None], np.zeros((K, 1))))
            pts = np.matmul(rotation_matrix, pts.T).T
            pts_list.append(pts)

        return pts_list

    def insert_line(self, image, pts_list, color, thickness=1):
        """
        Insert line

        Args:
            image (np.ndarray): image, drawn on in place
            pts_list (list): list of points
            color (tuple): color
            thickness (int): thickness

        Returns:
            image (np.ndarray): image
        """
        for pts in pts_list:
            pts = np.round(pts).astype(np.int32)
            cv2.polylines(image, [pts], isClosed=False, color=color, thickness=thickness)

        return image

    def truncate_line(self, lines):
        """
        Truncate line

        The base implementation returns the lines unchanged.

        Args:
            lines (np.ndarray): lines, shape [N, 2, 2]

        Returns:
            lines (np.ndarray): truncated lines, shape [M, 2, 2]
        """
        return lines
235
+
236
+
237
class Pinhole(Camera):
    """
    Pinhole camera

    Without coefficients, points are passed through unchanged.
    """

    def distort_point(self, undistorted):
        """
        Distort point

        Args:
            undistorted (np.ndarray): undistorted points, shape [N, 2]

        Returns:
            distorted (np.ndarray): distorted points, shape [N, 2]
        """
        # Missing or empty coefficients mean an ideal camera: nothing to do.
        if not self.coeff:
            return undistorted
        if len(undistorted) == 0:
            return np.zeros((0, 2), np.float64)

        K, D = self.coeff["K"], self.coeff["D"]
        fx, fy = K[0, 0], K[1, 1]
        cx, cy = K[0, 2], K[1, 2]

        # Convert pixels to normalised coordinates on the z = 1 plane, then
        # let OpenCV project them again with distortion applied.
        normalized = np.array(undistorted, dtype=np.float64)
        normalized[:, 0] = (normalized[:, 0] - cx) / fx
        normalized[:, 1] = (normalized[:, 1] - cy) / fy
        normalized = np.hstack((normalized, np.ones((normalized.shape[0], 1), np.float64)))
        distorted = cv2.projectPoints(normalized.reshape(1, -1, 3), (0, 0, 0), (0, 0, 0), K, D)[0].reshape(-1, 2)

        return distorted

    def undistort_point(self, distorted):
        """
        Undistort point

        Args:
            distorted (np.ndarray): distorted points, shape [N, 2]

        Returns:
            undistorted (np.ndarray): undistorted points, shape [N, 2]
        """
        if not self.coeff:
            return distorted
        if len(distorted) == 0:
            return np.zeros((0, 2), np.float64)

        K, D = self.coeff["K"], self.coeff["D"]
        distorted = np.array(distorted, dtype=np.float64)
        # P=K keeps the result in pixel coordinates.
        undistorted = cv2.undistortPoints(distorted.reshape(1, -1, 2), K, D, R=None, P=K).reshape(-1, 2)

        return undistorted

    def interp_line(self, lines, num=None, resolution=0.1):
        """
        Interpolate line

        Each line is sampled as a straight segment between its undistorted
        end points, and the samples are then distorted again.

        Args:
            lines (np.ndarray): lines, shape [N, 2, 2]
            num (int | None): number of interpolated points per line
            resolution (float): resolution of interpolation

        Returns:
            pts_list (list): list of interpolated points
        """
        lines = np.asarray(lines)
        if lines.size == 0:
            return []

        undistorted = self.undistort_point(lines.reshape(-1, 2))
        segments = undistorted.reshape(-1, 2, 2)

        pts_list = []
        for start, end in segments:
            K = num or int(round(np.max(np.abs(end - start)) / resolution)) + 1
            lambda_ = np.linspace(0, 1, K)[:, None]
            pts = end * lambda_ + start * (1 - lambda_)
            pts_list.append(self.distort_point(pts))

        return pts_list

    def insert_line(self, image, lines, color, thickness=1):
        """
        Insert line

        Args:
            image (np.ndarray): image, drawn on in place
            lines (np.ndarray): lines, shape [N, 2, 2]
            color (tuple): color
            thickness (int): thickness

        Returns:
            image (np.ndarray): image
        """
        pts_list = self.interp_line(lines)
        super().insert_line(image, pts_list, color, thickness)

        return image
329
+
330
+
331
class Fisheye(Camera):
    """
    Fisheye camera

    Requires the "K" and "D" coefficients for every point operation.
    """

    def _require_coeff(self):
        """Return (K, D), raising a clear error when the coefficients are missing."""
        if not self.coeff:
            raise ValueError("Fisheye camera requires coefficients 'K' and 'D'")
        return self.coeff["K"], self.coeff["D"]

    def distort_point(self, undistorted):
        """
        Distort point

        Args:
            undistorted (np.ndarray): undistorted points, shape [N, 2]

        Returns:
            distorted (np.ndarray): distorted points, shape [N, 2]

        Raises:
            ValueError: if the camera has no coefficients
        """
        K, D = self._require_coeff()
        undistorted = np.array(undistorted, dtype=np.float64)
        if len(undistorted) == 0:
            return np.zeros((0, 2), np.float64)

        fx, fy = K[0, 0], K[1, 1]
        cx, cy = K[0, 2], K[1, 2]

        # OpenCV's fisheye model takes normalised coordinates, not pixels.
        undistorted[:, 0] = (undistorted[:, 0] - cx) / fx
        undistorted[:, 1] = (undistorted[:, 1] - cy) / fy
        distorted = cv2.fisheye.distortPoints(undistorted.reshape(1, -1, 2), K, D).reshape(-1, 2)

        return distorted

    def undistort_point(self, distorted):
        """
        Undistort point

        Args:
            distorted (np.ndarray): distorted points, shape [N, 2]

        Returns:
            undistorted (np.ndarray): undistorted points, shape [N, 2]

        Raises:
            ValueError: if the camera has no coefficients
        """
        K, D = self._require_coeff()
        distorted = np.array(distorted, dtype=np.float64)
        if len(distorted) == 0:
            return np.zeros((0, 2), np.float64)

        # P=K keeps the result in pixel coordinates.
        undistorted = cv2.fisheye.undistortPoints(distorted.reshape(1, -1, 2), K, D, P=K).reshape(-1, 2)

        return undistorted

    def interp_line(self, lines, num=None, resolution=0.1):
        """
        Interpolate line

        The undistorted end points are lifted to unit vectors, the arc
        between them is sampled with `interp_arc`, and the samples are
        projected back and distorted.

        Args:
            lines (np.ndarray): lines, shape [N, 2, 2]
            num (int | None): number of interpolated points per line
            resolution (float): resolution of interpolation

        Returns:
            pts_list (list): list of interpolated points

        Raises:
            ValueError: if the camera has no coefficients
        """
        lines = np.asarray(lines)
        if lines.size == 0:
            return []

        undistorted = self.undistort_point(lines.reshape(-1, 2))
        undistorted = np.hstack((undistorted, np.ones((undistorted.shape[0], 1), np.float64)))
        undistorted = undistorted / np.linalg.norm(undistorted, axis=1, keepdims=True)

        arcs = undistorted.reshape(-1, 2, 3)
        eps = np.finfo(np.float64).eps
        distorted_list = []
        for pts in self.interp_arc(arcs, num, resolution):
            # Divide by z to get back to the plane; eps avoids division by zero.
            pts = pts / (pts[:, 2:] + eps)
            distorted_list.append(self.distort_point(pts[:, :2]))

        return distorted_list

    def insert_line(self, image, lines, color, thickness=1):
        """
        Insert line

        Args:
            image (np.ndarray): image, drawn on in place
            lines (np.ndarray): lines, shape [N, 2, 2]
            color (tuple): color
            thickness (int): thickness

        Returns:
            image (np.ndarray): image
        """
        pts_list = self.interp_line(lines)
        super().insert_line(image, pts_list, color, thickness)

        return image
420
+
421
+
422
class Spherical(Camera):
    """
    Spherical camera

    Maps between image pixels (longitude along the width, latitude along
    the height) and 3D unit vectors. When coefficients are given, an extra
    fisheye distortion step is applied per hemisphere.

    Args:
        image_size (tuple): image size, [width, height]
        **kwargs: keyword arguments
    """

    def __init__(self, image_size, **kwargs):
        super().__init__(**kwargs)

        self.image_size = image_size

    def distort_point(self, undistorted):
        """
        Distort point

        Args:
            undistorted (np.ndarray): undistorted points, shape [N, 3]

        Returns:
            distorted (np.ndarray): distorted points, shape [N, 2]
        """
        undistorted = np.array(undistorted, dtype=np.float64)
        width, height = self.image_size

        if self.coeff:
            K, D = self.coeff["K"], self.coeff["D"]
            cx = cy = (height - 1.0) / 2.0

            # Mirror the z < 0 points to the front so one fisheye model
            # serves both halves; they are mirrored back further down.
            mask = undistorted[:, 2] < 0
            undistorted[mask, 0] = -undistorted[mask, 0]
            undistorted[mask, 2] = -undistorted[mask, 2]
            undistorted = undistorted / (undistorted[:, 2:] + np.finfo(np.float64).eps)
            undistorted = undistorted[:, :2]
            distorted = cv2.fisheye.distortPoints(undistorted.reshape(1, -1, 2), K, D).reshape(-1, 2)
            x = (distorted[:, 0] - cx) / cx
            y = (distorted[:, 1] - cy) / cy
            theta = np.arctan2(y, x)
            phi = np.sqrt(x**2 + y**2) * np.pi / 2.0
            x = np.sin(phi) * np.cos(theta)
            y = np.sin(phi) * np.sin(theta)
            z = np.cos(phi)
            undistorted = np.hstack((x[:, None], y[:, None], z[:, None]))
            undistorted[mask, 0] = -undistorted[mask, 0]
            undistorted[mask, 2] = -undistorted[mask, 2]

        x, y, z = undistorted[:, 0], undistorted[:, 1], undistorted[:, 2]
        # Clip guards arccos against rounding slightly outside [-1, 1].
        lat = np.pi - np.arccos(np.clip(y, -1.0, 1.0))
        lon = np.pi - np.arctan2(z, x)
        u = width * lon / (2 * np.pi)
        v = height * lat / np.pi
        u = np.mod(u, width)
        v = np.mod(v, height)
        distorted = np.stack([u, v], axis=-1)

        return distorted

    def undistort_point(self, distorted):
        """
        Undistort point

        Args:
            distorted (np.ndarray): distorted points, shape [N, 2]

        Returns:
            undistorted (np.ndarray): undistorted points, shape [N, 3]
        """
        distorted = np.array(distorted, dtype=np.float64)
        width, height = self.image_size

        u, v = distorted[:, 0], distorted[:, 1]
        lon = np.pi - u / width * 2 * np.pi
        lat = np.pi - v / height * np.pi
        y = np.cos(lat)
        x = np.sin(lat) * np.cos(lon)
        z = np.sin(lat) * np.sin(lon)
        undistorted = np.stack([x, y, z], axis=-1)

        if self.coeff:
            K, D = self.coeff["K"], self.coeff["D"]
            cx = cy = (height - 1.0) / 2.0

            # Same mirroring of the z < 0 points as in `distort_point`.
            mask = undistorted[:, 2] < 0
            undistorted[mask, 0] = -undistorted[mask, 0]
            undistorted[mask, 2] = -undistorted[mask, 2]
            x, y, z = undistorted[:, 0], undistorted[:, 1], undistorted[:, 2]
            theta = np.arctan2(y, x)
            phi = np.arccos(np.clip(z, -1.0, 1.0))
            r = phi * 2.0 / np.pi
            x = r * np.cos(theta) * cx + cx
            y = r * np.sin(theta) * cy + cy
            distorted = np.hstack((x[:, None], y[:, None]))
            undistorted = cv2.fisheye.undistortPoints(distorted.reshape(1, -1, 2), K, D).reshape(-1, 2)
            undistorted = np.hstack((undistorted, np.ones((undistorted.shape[0], 1), np.float64)))
            undistorted = undistorted / np.linalg.norm(undistorted, axis=1, keepdims=True)
            undistorted[mask, 0] = -undistorted[mask, 0]
            undistorted[mask, 2] = -undistorted[mask, 2]

        return undistorted

    def interp_line(self, lines, num=None, resolution=0.01):
        """
        Interpolate line

        Args:
            lines (np.ndarray): lines, shape [N, 2, 2]
            num (int | None): number of interpolated points per line
            resolution (float): resolution of interpolation

        Returns:
            pts_list (list): list of interpolated points, each of shape [K, 2]
        """
        lines = np.asarray(lines)
        if lines.size == 0:
            return []

        undistorted = self.undistort_point(lines.reshape(-1, 2))
        arcs = undistorted.reshape(-1, 2, 3)
        return [self.distort_point(pts) for pts in self.interp_arc(arcs, num, resolution)]

    def insert_line(self, image, lines, color, thickness=1):
        """
        Insert line

        Args:
            image (np.ndarray): image, drawn on in place
            lines (np.ndarray): lines, shape [N, 2, 2]
            color (tuple): color
            thickness (int): thickness

        Returns:
            image (np.ndarray): image
        """
        pts_list = self.interp_line(lines)
        super().insert_line(image, pts_list, color, thickness)

        return image

    def truncate_line(self, lines):
        """
        Truncate line

        Splits every line wherever two consecutive interpolated points are
        more than half the image width apart horizontally, i.e. where the
        line jumps from one image border to the other.

        Args:
            lines (np.ndarray): lines, shape [N, 2, 2]

        Returns:
            lines (np.ndarray): truncated lines, shape [M, 2, 2]
        """
        width = self.image_size[0]
        segments = []
        for pts in self.interp_line(lines):
            dx = np.abs(np.diff(pts[:, 0]))
            start = 0
            # Cut at every jump, not only the first one.
            for ind in np.flatnonzero(dx > width / 2.0):
                segments.append([pts[start], pts[ind]])
                start = ind + 1
            segments.append([pts[start], pts[-1]])

        if not segments:
            # Keep the documented rank for an empty result.
            return np.zeros((0, 2, 2), np.float64)
        return np.asarray(segments)