zhouyik committed on
Commit
b7c995c
·
verified ·
1 Parent(s): 78bd293

Upload ./visualizer.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. visualizer.py +775 -0
visualizer.py ADDED
@@ -0,0 +1,775 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copied from https://huggingface.co/spaces/PolyU-ChenLab/UniPixel/blob/main/unipixel/utils/visualizer.py
2
+
3
+ import colorsys
4
+ import io
5
+ import math
6
+ import random
7
+ from enum import Enum, unique
8
+
9
+ import cv2
10
+ import imageio.v3 as iio
11
+ import matplotlib as mpl
12
+ import matplotlib.colors as mplc
13
+ import matplotlib.figure as mplfigure
14
+ import numpy as np
15
+ import pycocotools.mask as mask_util
16
+ import torch
17
+ from matplotlib.backends.backend_agg import FigureCanvasAgg
18
+
19
# Area threshold below which a rotated box is outlined with thinner lines
# (compared against w * h in Visualizer.draw_rotated_box_with_label).
_SMALL_OBJECT_AREA_THRESH = 1000
# Connected components with more pixels than this also receive a text label
# in Visualizer._draw_text_in_mask, in addition to the largest component.
_LARGE_MASK_AREA_THRESH = 120000

# Fixed color palette. The flat list is reshaped to (N, 3) float32 rows, one
# color per row with components in [0, 1]; callers treat the stored order as RGB.
_COLORS = np.array([
    0.000, 0.447, 0.741, 0.850, 0.325, 0.098, 0.929, 0.694, 0.125, 0.494, 0.184, 0.556, 0.466, 0.674, 0.188, 0.301,
    0.745, 0.933, 0.635, 0.078, 0.184, 0.300, 0.300, 0.300, 0.600, 0.600, 0.600, 1.000, 0.000, 0.000, 1.000, 0.500,
    0.000, 0.749, 0.749, 0.000, 0.000, 1.000, 0.000, 0.000, 0.000, 1.000, 0.667, 0.000, 1.000, 0.333, 0.333, 0.000,
    0.333, 0.667, 0.000, 0.333, 1.000, 0.000, 0.667, 0.333, 0.000, 0.667, 0.667, 0.000, 0.667, 1.000, 0.000, 1.000,
    0.333, 0.000, 1.000, 0.667, 0.000, 1.000, 1.000, 0.000, 0.000, 0.333, 0.500, 0.000, 0.667, 0.500, 0.000, 1.000,
    0.500, 0.333, 0.000, 0.500, 0.333, 0.333, 0.500, 0.333, 0.667, 0.500, 0.333, 1.000, 0.500, 0.667, 0.000, 0.500,
    0.667, 0.333, 0.500, 0.667, 0.667, 0.500, 0.667, 1.000, 0.500, 1.000, 0.000, 0.500, 1.000, 0.333, 0.500, 1.000,
    0.667, 0.500, 1.000, 1.000, 0.500, 0.000, 0.333, 1.000, 0.000, 0.667, 1.000, 0.000, 1.000, 1.000, 0.333, 0.000,
    1.000, 0.333, 0.333, 1.000, 0.333, 0.667, 1.000, 0.333, 1.000, 1.000, 0.667, 0.000, 1.000, 0.667, 0.333, 1.000,
    0.667, 0.667, 1.000, 0.667, 1.000, 1.000, 1.000, 0.000, 1.000, 1.000, 0.333, 1.000, 1.000, 0.667, 1.000, 0.333,
    0.000, 0.000, 0.500, 0.000, 0.000, 0.667, 0.000, 0.000, 0.833, 0.000, 0.000, 1.000, 0.000, 0.000, 0.000, 0.167,
    0.000, 0.000, 0.333, 0.000, 0.000, 0.500, 0.000, 0.000, 0.667, 0.000, 0.000, 0.833, 0.000, 0.000, 1.000, 0.000,
    0.000, 0.000, 0.167, 0.000, 0.000, 0.333, 0.000, 0.000, 0.500, 0.000, 0.000, 0.667, 0.000, 0.000, 0.833, 0.000,
    0.000, 1.000, 0.000, 0.000, 0.000, 0.143, 0.143, 0.143, 0.857, 0.857, 0.857, 1.000, 1.000, 1.000
]).astype(np.float32).reshape(-1, 3)
38
+
39
+
40
def random_color(rgb=False, maximum=1):
    """
    Pick one color at random from the module palette.

    Args:
        rgb (bool): when True the color is returned in the palette's stored
            (RGB) channel order; when False the three channels are reversed.
        maximum: factor every channel is multiplied by (palette values are in [0, 1]).

    Returns:
        ndarray: a float32 vector of 3 channel values.
    """
    picked = _COLORS[np.random.randint(0, len(_COLORS))] * maximum
    # Reversing a 1-D color vector swaps the first and last channels.
    return picked if rgb else picked[::-1]
46
+
47
+
48
def sample_color(rgb=False, maximum=1):
    """
    Return the whole module palette in a random order.

    Args:
        rgb (bool): when True the colors keep the palette's stored (RGB)
            channel order; when False the channels of every color are reversed.
        maximum: factor every channel is multiplied by (palette values are in [0, 1]).

    Returns:
        ndarray: a float32 array of shape (N, 3), one shuffled palette color per row.
    """
    inds = list(range(len(_COLORS)))
    random.shuffle(inds)
    ret = _COLORS[inds] * maximum
    if not rgb:
        # Reverse the channel axis. Reversing axis 0 (as `ret[::-1]` would)
        # only flips the row order of the already shuffled colors and leaves
        # every color in RGB order.
        ret = ret[:, ::-1]
    return ret
55
+
56
+
57
@unique
class ColorMode(Enum):
    """
    Color strategies available when visualizing instances.
    """

    # Picks a random color for every instance and overlay segmentations with low opacity.
    IMAGE = 0

    # Let instances of the same category have similar colors
    # (from metadata.thing_colors), and overlay them with high opacity.
    # This provides more attention on the quality of segmentation.
    SEGMENTATION = 1

    # Same as IMAGE, but convert all areas without masks to gray-scale.
    # Only available for drawing per-instance mask predictions.
    IMAGE_BW = 2
78
+
79
+
80
class GenericMask:
    """
    A single instance mask that converts lazily between a binary mask and polygons.

    The constructor accepts a COCO-style RLE dict, a list of flat polygons or a
    binary ndarray; the other representation is derived on first access.

    Attribute:
        polygons (list[ndarray]): polygons for this mask.
            Each ndarray has format [x, y, x, y, ...]
        mask (ndarray): a binary mask
    """

    def __init__(self, mask_or_polygons, height, width):
        """
        Args:
            mask_or_polygons: an RLE dict (with "counts" and "size"), a list of
                polygons, or an ndarray of shape (height, width).
            height (int): mask height.
            width (int): mask width.

        Raises:
            ValueError: if the input is none of the supported types.
        """
        self._mask = None
        self._polygons = None
        self._has_holes = None
        self.height = height
        self.width = width

        source = mask_or_polygons
        if isinstance(source, dict):
            # RLE input; uncompressed RLEs (list counts) are converted first.
            assert "counts" in source and "size" in source
            if isinstance(source["counts"], list):
                rle_h, rle_w = source["size"]
                assert rle_h == height and rle_w == width
                source = mask_util.frPyObjects(source, rle_h, rle_w)
            self._mask = mask_util.decode(source)[:, :]
        elif isinstance(source, list):
            # list[ndarray]: flatten every polygon to [x, y, x, y, ...]
            self._polygons = [np.asarray(poly).reshape(-1) for poly in source]
        elif isinstance(source, np.ndarray):
            # assumed to be a binary mask
            assert source.shape[1] != 2, source.shape
            expected_shape = (height, width)
            assert source.shape == expected_shape, f"mask shape: {source.shape}, target dims: {height}, {width}"
            self._mask = source.astype("uint8")
        else:
            raise ValueError("GenericMask cannot handle object {} of type '{}'".format(source, type(source)))

    @property
    def mask(self):
        """Binary mask, rasterized from the polygons on first access if needed."""
        if self._mask is None:
            self._mask = self.polygons_to_mask(self._polygons)
        return self._mask

    @property
    def polygons(self):
        """Polygons, traced from the binary mask on first access if needed."""
        if self._polygons is None:
            self._polygons, self._has_holes = self.mask_to_polygons(self._mask)
        return self._polygons

    @property
    def has_holes(self):
        """Whether the mask contains holes; always False for polygon input."""
        if self._has_holes is not None:
            return self._has_holes
        if self._mask is None:
            # if original format is polygon, does not have holes
            self._has_holes = False
        else:
            self._polygons, self._has_holes = self.mask_to_polygons(self._mask)
        return self._has_holes

    def mask_to_polygons(self, mask):
        """
        Trace the contours of a binary mask.

        Returns:
            tuple: (list of flat [x, y, ...] polygons, whether any contour is a hole).
        """
        # cv2.RETR_CCOMP arranges contours in a 2-level hierarchy: external
        # boundaries on level 1 and internal contours (holes) on level 2.
        # cv2.CHAIN_APPROX_NONE keeps every contour vertex.
        # Some versions of cv2 do not support non-contiguous arrays.
        contiguous = np.ascontiguousarray(mask)
        found = cv2.findContours(contiguous.astype("uint8"), cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE)
        hierarchy = found[-1]
        if hierarchy is None:  # empty mask
            return [], False
        # A non-negative parent index marks a contour nested in another one.
        parent_index = hierarchy.reshape(-1, 4)[:, 3]
        has_holes = (parent_index >= 0).sum() > 0
        flat_contours = [contour.flatten() for contour in found[-2]]
        # OpenCV coordinates are integers in [0, W-1 or H-1]; adding 0.5 moves
        # them into real-valued coordinate space. Contours with fewer than
        # three points are dropped.
        polygons = [flat + 0.5 for flat in flat_contours if len(flat) >= 6]
        return polygons, has_holes

    def polygons_to_mask(self, polygons):
        """Rasterize polygons into one merged binary mask of size (height, width)."""
        merged = mask_util.merge(mask_util.frPyObjects(polygons, self.height, self.width))
        return mask_util.decode(merged)[:, :]

    def area(self):
        """Sum of the binary mask values."""
        return self.mask.sum()

    def bbox(self):
        """Bounding box of the merged polygons as [x0, y0, x1, y1]."""
        merged = mask_util.merge(mask_util.frPyObjects(self.polygons, self.height, self.width))
        box = mask_util.toBbox(merged)
        # toBbox yields [x, y, w, h]; turn width/height into the far corner.
        box[2] += box[0]
        box[3] += box[1]
        return box
175
+
176
+
177
class VisImage:
    """An image rendered on an off-screen matplotlib figure that can be drawn over."""

    def __init__(self, img, scale=1.0):
        """
        Args:
            img (ndarray): an RGB image of shape (H, W, 3) in range [0, 255].
            scale (float): scale the input image
        """
        self.img = img
        self.scale = scale
        self.height, self.width = img.shape[0], img.shape[1]
        self._setup_figure(img)

    def _setup_figure(self, img):
        """
        Create the frameless figure, its Agg canvas and a full-figure axes, store
        them as ``self.fig``, ``self.canvas`` and ``self.ax``, then show ``img``.

        Args:
            Same as in :meth:`__init__()`.
        """
        figure = mplfigure.Figure(frameon=False)
        self.dpi = figure.get_dpi()
        # add a small 1e-2 to avoid precision lost due to matplotlib's truncation
        # (https://github.com/matplotlib/matplotlib/issues/15363)
        width_inches = (self.width * self.scale + 1e-2) / self.dpi
        height_inches = (self.height * self.scale + 1e-2) / self.dpi
        figure.set_size_inches(width_inches, height_inches)
        self.canvas = FigureCanvasAgg(figure)
        axes = figure.add_axes([0.0, 0.0, 1.0, 1.0])
        axes.axis("off")
        self.fig, self.ax = figure, axes
        self.reset_image(img)

    def reset_image(self, img):
        """
        Args:
            img: same as in __init__
        """
        pixels = img.astype("uint8")
        # The extent maps data coordinates to pixels with the origin at the top left.
        self.ax.imshow(pixels, extent=(0, self.width, self.height, 0), interpolation="nearest")

    def save(self, filepath, fig_format=None):
        """
        Args:
            filepath (str): a string that contains the absolute path, including the file name, where
                the visualized image will be saved.
            fig_format (str, optional): format forwarded to ``savefig``; when None,
                ``savefig`` is called without a format argument.
        """
        if fig_format is None:
            self.fig.savefig(filepath)
            return
        self.fig.savefig(filepath, format=fig_format)

    def get_image(self):
        """
        Returns:
            ndarray:
                the visualized image of shape (H, W, 3) (RGB) in uint8 type.
                The shape is scaled w.r.t the input image using the given `scale` argument.
        """
        raw, (buf_width, buf_height) = self.canvas.print_to_buffer()
        rgba = np.frombuffer(raw, dtype="uint8").reshape(buf_height, buf_width, 4)
        # Drop the alpha channel; astype returns a fresh array.
        return rgba[:, :, :3].astype("uint8")
252
+
253
+
254
class Visualizer:
    """
    Visualizer that draws data about detection/segmentation on images.

    It contains primitive drawing methods
    (`draw_{text,box,rotated_box_with_label,circle,line,polygon}`) and mask-level
    helpers (`draw_binary_mask`, `draw_binary_mask_with_number`). Every drawing
    method adds to the same underlying :class:`VisImage` and returns it.

    This visualizer focuses on high rendering quality rather than performance. It is not
    designed to be used for real-time applications.
    """

    def __init__(self, img_rgb, scale=1.0, instance_mode=ColorMode.IMAGE):
        """
        Args:
            img_rgb: a numpy array of shape (H, W, C), where H and W correspond to
                the height and width of the image respectively. C is the number of
                color channels. The image is required to be in RGB format since that
                is a requirement of the Matplotlib library. Values are clipped
                to the range [0, 255] and converted to uint8.
            scale (float): scale factor applied to the rendered output.
            instance_mode (ColorMode): defines one of the pre-defined style for drawing
                instances on an image.
        """
        self.img = np.asarray(img_rgb).clip(0, 255).astype(np.uint8)
        self.output = VisImage(self.img, scale=scale)
        self.cpu_device = torch.device("cpu")

        # Fixed font size. The former image-size based heuristic was always
        # overwritten by this constant, so only the constant is kept.
        self._default_font_size = 18
        self._instance_mode = instance_mode

        # Candidate colors (RGB lists in [0, 1]) used when a caller passes no color.
        self.color_proposals = [list(mplc.hex2color(hex_value)) for hex_value in mplc.CSS4_COLORS.values()]

    def draw_text(
        self,
        text,
        position,
        *,
        font_size=None,
        color="g",
        horizontal_alignment="center",
        rotation=0,
    ):
        """
        Args:
            text (str): class label
            position (tuple): a tuple of the x and y coordinates to place text on image.
            font_size (int, optional): font of the text. If not provided (or 0),
                the visualizer's default font size is used.
            color: color of the text. Refer to `matplotlib.colors` for full list
                of formats that are accepted.
            horizontal_alignment (str): see `matplotlib.text.Text`
            rotation: rotation angle in degrees CCW
        Returns:
            output (VisImage): image object with text drawn.
        """
        if not font_size:
            font_size = self._default_font_size

        # Brighten the text color: every channel is at least 0.15 and the
        # strongest channel at least 0.8.
        color = np.maximum(list(mplc.to_rgb(color)), 0.15)
        color[np.argmax(color)] = max(0.8, np.max(color))

        # Pick the label background that contrasts with the text color, based
        # on its luma on a 0-255 scale.
        red, green, blue = color * 255
        luma = 0.299 * red + 0.587 * green + 0.114 * blue
        bbox_background = 'black' if luma > 128 else 'white'

        x, y = position
        self.output.ax.text(
            x,
            y,
            text,
            size=font_size * self.output.scale,
            family="sans-serif",
            bbox={
                "facecolor": bbox_background,
                "alpha": 0.8,
                "pad": 0.7,
                "edgecolor": "none"
            },
            verticalalignment="top",
            horizontalalignment=horizontal_alignment,
            color=color,
            zorder=10,
            rotation=rotation,
        )
        return self.output

    def draw_box(self, box_coord, alpha=0.5, edge_color="g", line_style="-"):
        """
        Args:
            box_coord (tuple): a tuple containing x0, y0, x1, y1 coordinates, where x0 and y0
                are the coordinates of the box's top left corner. x1 and y1 are the
                coordinates of the box's bottom right corner.
            alpha (float): blending efficient. Smaller values lead to more transparent boxes.
            edge_color: color of the outline of the box. Refer to `matplotlib.colors`
                for full list of formats that are accepted.
            line_style (string): the string to use to create the outline of the boxes.
        Returns:
            output (VisImage): image object with box drawn.
        """
        x0, y0, x1, y1 = box_coord
        width = x1 - x0
        height = y1 - y0

        linewidth = max(self._default_font_size / 12, 1)

        self.output.ax.add_patch(
            mpl.patches.Rectangle(
                (x0, y0),
                width,
                height,
                fill=False,
                edgecolor=edge_color,
                linewidth=linewidth * self.output.scale,
                alpha=alpha,
                linestyle=line_style,
            ))
        return self.output

    def draw_rotated_box_with_label(self, rotated_box, alpha=0.5, edge_color="g", line_style="-", label=None):
        """
        Draw a rotated box with label on its top-left corner.
        Args:
            rotated_box (tuple): a tuple containing (cnt_x, cnt_y, w, h, angle),
                where cnt_x and cnt_y are the center coordinates of the box.
                w and h are the width and height of the box. angle represents how
                many degrees the box is rotated CCW with regard to the 0-degree box.
            alpha (float): accepted for signature parity with `draw_box`; the edges
                are drawn through `draw_line`, which takes no alpha, so it is unused.
            edge_color: color of the outline of the box. Refer to `matplotlib.colors`
                for full list of formats that are accepted.
            line_style (string): the string to use to create the outline of the boxes.
            label (string): label for rotated box. It will not be rendered when set to None.
        Returns:
            output (VisImage): image object with box drawn.
        """
        cnt_x, cnt_y, w, h, angle = rotated_box
        area = w * h
        # use thinner lines when the box is small
        linewidth = self._default_font_size / (6 if area < _SMALL_OBJECT_AREA_THRESH * self.output.scale else 3)

        theta = angle * math.pi / 180.0
        c = math.cos(theta)
        s = math.sin(theta)
        rect = [(-w / 2, h / 2), (-w / 2, -h / 2), (w / 2, -h / 2), (w / 2, h / 2)]
        # x: left->right ; y: top->down
        rotated_rect = [(s * yy + c * xx + cnt_x, c * yy - s * xx + cnt_y) for (xx, yy) in rect]
        for k in range(4):
            j = (k + 1) % 4
            self.draw_line(
                [rotated_rect[k][0], rotated_rect[j][0]],
                [rotated_rect[k][1], rotated_rect[j][1]],
                color=edge_color,
                # the edge starting at the top-left corner is always dashed
                linestyle="--" if k == 1 else line_style,
                linewidth=linewidth,
            )

        if label is not None:
            text_pos = rotated_rect[1]  # topleft corner

            height_ratio = h / np.sqrt(self.output.height * self.output.width)
            label_color = self._change_color_brightness(edge_color, brightness_factor=0.7)
            font_size = (np.clip((height_ratio - 0.02) / 0.08 + 1, 1.2, 2) * 0.5 * self._default_font_size)
            self.draw_text(label, text_pos, color=label_color, font_size=font_size, rotation=angle)

        return self.output

    def draw_circle(self, circle_coord, color, radius=3):
        """
        Args:
            circle_coord (list(int) or tuple(int)): contains the x and y coordinates
                of the center of the circle.
            color: color of the circle. Refer to `matplotlib.colors` for a full list of
                formats that are accepted.
            radius (int): radius of the circle.
        Returns:
            output (VisImage): image object with circle drawn.
        """
        x, y = circle_coord
        self.output.ax.add_patch(mpl.patches.Circle((x, y), radius=radius, fill=True, color=color))
        return self.output

    def draw_line(self, x_data, y_data, color, linestyle="-", linewidth=None):
        """
        Args:
            x_data (list[int]): a list containing x values of all the points being drawn.
                Length of list should match the length of y_data.
            y_data (list[int]): a list containing y values of all the points being drawn.
                Length of list should match the length of x_data.
            color: color of the line. Refer to `matplotlib.colors` for a full list of
                formats that are accepted.
            linestyle: style of the line. Refer to `matplotlib.lines.Line2D`
                for a full list of formats that are accepted.
            linewidth (float or None): width of the line. When it's None,
                a default value will be computed and used. Widths below 1 are raised to 1.
        Returns:
            output (VisImage): image object with line drawn.
        """
        if linewidth is None:
            linewidth = self._default_font_size / 3
        linewidth = max(linewidth, 1)
        self.output.ax.add_line(
            mpl.lines.Line2D(
                x_data,
                y_data,
                linewidth=linewidth * self.output.scale,
                color=color,
                linestyle=linestyle,
            ))
        return self.output

    def _draw_mask_polygons(self, mask, color, edge_color, alpha, area_threshold):
        """
        Draw every polygon of ``mask`` whose area reaches ``area_threshold``.

        NOTE(review): contours of holes are part of ``mask.polygons`` and are
        drawn as filled polygons as well, so holes are not left transparent.

        Returns:
            bool: True if at least one polygon was drawn.
        """
        min_area = area_threshold or 0
        drew_any = False
        for segment in mask.polygons:
            area = mask_util.area(mask_util.frPyObjects([segment], mask.height, mask.width))
            if area < min_area:
                continue
            drew_any = True
            self.draw_polygon(segment.reshape(-1, 2), color=color, edge_color=edge_color, alpha=alpha)
        return drew_any

    def draw_binary_mask(self, binary_mask, color=None, *, edge_color=None, text=None, alpha=0.7, area_threshold=10):
        """
        Args:
            binary_mask (ndarray): numpy array of shape (H, W), where H is the image height and
                W is the image width. Each value in the array is either a 0 or 1 value of uint8
                type.
            color: color of the mask. Refer to `matplotlib.colors` for a full list of
                formats that are accepted. If None, will pick a random color.
            edge_color: color of the polygon edges. Refer to `matplotlib.colors` for a
                full list of formats that are accepted.
            text (str): if not None, will be drawn on the object (only when at
                least one polygon was drawn).
            alpha (float): blending efficient. Smaller values lead to more transparent masks.
            area_threshold (float): a connected component smaller than this area will not be shown.
        Returns:
            output (VisImage): image object with mask drawn.
        """
        if color is None:
            color = random_color(rgb=True, maximum=1)
        color = mplc.to_rgb(color)

        binary_mask = binary_mask.astype("uint8")  # opencv needs uint8
        mask = GenericMask(binary_mask, self.output.height, self.output.width)
        has_valid_segment = self._draw_mask_polygons(mask, color, edge_color, alpha, area_threshold)

        if text is not None and has_valid_segment:
            lighter_color = self._change_color_brightness(color, brightness_factor=0.7)
            self._draw_text_in_mask(binary_mask, text, lighter_color)
        return self.output

    def _draw_number_in_mask(self, binary_mask, text, color, label_mode='1'):
        """
        Draw ``text`` at a pixel of the mask that is farthest from the mask boundary.

        Args:
            binary_mask (ndarray): uint8 mask of shape (H, W).
            text: the label; with ``label_mode='a'`` it must be convertible to int.
            color: color of the text.
            label_mode (str): 'a' converts the number to letters
                (1 -> 'a', 26 -> 'z', 27 -> 'aa'); any other value draws ``text`` as given.
        """

        def number_to_string(n):
            chars = []
            while n:
                n, remainder = divmod(n - 1, 26)
                chars.append(chr(97 + remainder))
            return ''.join(reversed(chars))

        # Pad with a zero border so that pixels on the image edge count as
        # boundary pixels in the distance transform, then crop the border again.
        binary_mask = np.pad(binary_mask, ((1, 1), (1, 1)), 'constant')
        mask_dt = cv2.distanceTransform(binary_mask, cv2.DIST_L2, 0)
        mask_dt = mask_dt[1:-1, 1:-1]
        max_dist = np.max(mask_dt)
        coords_y, coords_x = np.where(mask_dt == max_dist)  # coords is [y, x]

        if label_mode == 'a':
            text = number_to_string(int(text))

        # Among all maximal-distance pixels take the middle one, with a small fixed offset.
        self.draw_text(text, (coords_x[len(coords_x) // 2] + 2, coords_y[len(coords_y) // 2] - 6), color=color)

    def draw_binary_mask_with_number(self,
                                     binary_mask,
                                     color=None,
                                     *,
                                     edge_color=None,
                                     text=None,
                                     label_mode='1',
                                     alpha=0.1,
                                     anno_mode=('Mask', ),
                                     area_threshold=10):
        """
        Args:
            binary_mask (ndarray): numpy array of shape (H, W), where H is the image height and
                W is the image width. Each value in the array is either a 0 or 1 value of uint8
                type.
            color: color of the mask. Refer to `matplotlib.colors` for a full list of
                formats that are accepted. If None, will pick a random color.
            edge_color: color of the polygon edges. Refer to `matplotlib.colors` for a
                full list of formats that are accepted.
            text (str): label drawn inside the mask; only drawn when it is not None
                and 'Mark' is in `anno_mode`.
            label_mode (str): '1' draws `text` as given, 'a' converts a numeric
                `text` to letters.
            alpha (float): blending efficient. Smaller values lead to more transparent masks.
            anno_mode (collection of str): which annotations to draw: 'Mask' for the
                filled polygons, 'Box' for the bounding box, 'Mark' for the text label.
            area_threshold (float): a connected component smaller than this area will not be shown.
        Returns:
            output (VisImage): image object with mask drawn.
        """
        if color is None:
            randint = random.randint(0, len(self.color_proposals) - 1)
            color = self.color_proposals[randint]
        color = mplc.to_rgb(color)

        binary_mask = binary_mask.astype("uint8")  # opencv needs uint8
        mask = GenericMask(binary_mask, self.output.height, self.output.width)

        if 'Mask' in anno_mode:
            self._draw_mask_polygons(mask, color, edge_color, alpha, area_threshold)

        if 'Box' in anno_mode:
            self.draw_box(mask.bbox(), edge_color=color, alpha=0.75)

        # The label depends only on 'Mark' being requested, not on whether any
        # polygon passed the area threshold; it is always drawn in white.
        if text is not None and 'Mark' in anno_mode:
            self._draw_number_in_mask(binary_mask, text, [1, 1, 1], label_mode)
        return self.output

    def draw_polygon(self, segment, color, edge_color=None, alpha=0.5):
        """
        Args:
            segment: numpy array of shape Nx2, containing all the points in the polygon.
            color: color of the polygon. Refer to `matplotlib.colors` for a full list of
                formats that are accepted.
            edge_color: color of the polygon edges. Refer to `matplotlib.colors` for a
                full list of formats that are accepted. If not provided, the polygon
                color is used, darkened when alpha is above 0.8.
            alpha (float): blending efficient. Smaller values lead to more transparent masks.
        Returns:
            output (VisImage): image object with polygon drawn.
        """
        if edge_color is None:
            # make edge color darker than the polygon color
            if alpha > 0.8:
                edge_color = self._change_color_brightness(color, brightness_factor=-0.7)
            else:
                edge_color = color
        # edges are always fully opaque
        edge_color = mplc.to_rgb(edge_color) + (1, )

        polygon = mpl.patches.Polygon(
            segment,
            fill=True,
            facecolor=mplc.to_rgb(color) + (alpha, ),
            edgecolor=edge_color,
            linewidth=1,
        )
        self.output.ax.add_patch(polygon)
        return self.output

    # Internal methods:

    def _jitter(self, color):
        """
        Randomly modifies given color to produce a slightly different color than the color given.
        Args:
            color (tuple[double]): a tuple of 3 elements, containing the RGB values of the color
                picked. The values in the list are in the [0.0, 1.0] range.
        Returns:
            jittered_color (tuple[double]): a tuple of 3 elements, containing the RGB values of the
                color after being jittered. The values in the list are in the [0.0, 1.0] range.
        """
        color = mplc.to_rgb(color)
        vec = np.random.rand(3)
        # better to do it in another color space
        vec = vec / np.linalg.norm(vec) * 0.5
        res = np.clip(vec + color, 0, 1)
        return tuple(res)

    def _create_grayscale_image(self, mask=None):
        """
        Create a grayscale version of the original image.
        The colors in masked area, if given, will be kept.
        """
        img_bw = self.img.astype("f4").mean(axis=2)
        img_bw = np.stack([img_bw] * 3, axis=2)
        if mask is not None:
            img_bw[mask] = self.img[mask]
        return img_bw

    def _change_color_brightness(self, color, brightness_factor):
        """
        Depending on the brightness_factor, gives a lighter or darker color i.e. a color with
        more or less lightness (in HLS space) than the original color.
        Args:
            color: color of the polygon. Refer to `matplotlib.colors` for a full list of
                formats that are accepted.
            brightness_factor (float): a value in [-1.0, 1.0] range. A lightness factor of
                0 will correspond to no change, a factor in [-1.0, 0) range will result in
                a darker color and a factor in (0, 1.0] range will result in a lighter color.
        Returns:
            modified_color (tuple[double]): a tuple containing the RGB values of the
                modified color. Each value in the tuple is in the [0.0, 1.0] range.
        """
        assert brightness_factor >= -1.0 and brightness_factor <= 1.0
        hue, lightness, saturation = colorsys.rgb_to_hls(*mplc.to_rgb(color))
        # Scale the lightness relative to its current value and clamp to [0, 1].
        modified_lightness = min(1.0, max(0.0, lightness + brightness_factor * lightness))
        return colorsys.hls_to_rgb(hue, modified_lightness, saturation)

    def _draw_text_in_mask(self, binary_mask, text, color):
        """
        Find proper places to draw text given a binary mask.
        """
        # sometimes drawn on wrong objects. the heuristics here can improve.
        _num_cc, cc_labels, stats, centroids = cv2.connectedComponentsWithStats(binary_mask, 8)
        if stats[1:, -1].size == 0:
            # only the background component exists
            return
        largest_component_id = np.argmax(stats[1:, -1]) + 1

        # draw text on the largest component, as well as other very large components.
        for cid in range(1, _num_cc):
            if cid == largest_component_id or stats[cid, -1] > _LARGE_MASK_AREA_THRESH:
                # median is more stable than centroid
                component_pixels = (cc_labels == cid).nonzero()
                center = np.median(component_pixels, axis=1)[::-1]
                bottom = np.max(component_pixels, axis=1)[::-1]
                # place the label just below the component
                center[1] = bottom[1] + 2
                self.draw_text(text, center, color=color)

    def get_output(self):
        """
        Returns:
            output (VisImage): the image output containing the visualizations added
            to the image.
        """
        return self.output
756
+
757
+
758
def draw_mask(frames, masks, colors=None):
    """
    Overlay every mask on every frame and return the rendered frames.

    Args:
        frames: frame stack indexed along its first dimension; ``frames.size(0)``
            and ``frames[i].numpy()`` are used, so a torch tensor is expected
            (presumably (T, H, W, 3) RGB; confirm against the caller).
        masks: sequence of per-object masks; ``masks[j][0, i]`` is used as the
            mask of object ``j`` in frame ``i``.
        colors (optional): one color per mask. When None, a random palette color
            is picked for every mask and reused across all frames.

    Returns:
        list[ndarray]: one image per frame, decoded from the figure saved in
        matplotlib's default ``savefig`` format.
    """
    if colors is None:
        colors = [random_color(rgb=True, maximum=1) for _ in range(len(masks))]

    imgs = []
    for i in range(frames.size(0)):
        vis = Visualizer(frames[i].numpy())

        for j, mask in enumerate(masks):
            vis.draw_binary_mask_with_number(mask[0, i].bool().numpy(), color=colors[j], alpha=0.3)

        # Save through the visualizer's own output so that frames are still
        # rendered when `masks` is empty (no draw call has returned a figure).
        buffer = io.BytesIO()
        vis.get_output().save(buffer)
        buffer.seek(0)
        imgs.append(iio.imread(buffer))

    return imgs