File size: 7,076 Bytes
2402804 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 |
import os
from unittest import TestCase
import cv2
import numpy as np
import torch
from mmengine.structures import InstanceData, PixelData
from mmdet.evaluation import INSTANCE_OFFSET
from mmdet.structures import DetDataSample
from mmdet.visualization import DetLocalVisualizer, TrackLocalVisualizer
def _rand_bboxes(num_boxes, h, w):
    """Generate random boxes clamped to an image of height ``h``, width ``w``.

    Each box is drawn as a random centre plus a random width/height
    (all sampled uniformly in [0, 1) and scaled by the image size), then
    converted to corner form and clamped to the image bounds.

    Args:
        num_boxes (int): Number of boxes to generate.
        h (int): Image height used to scale and clamp the y coordinates.
        w (int): Image width used to scale and clamp the x coordinates.

    Returns:
        torch.Tensor: Float tensor of shape ``(num_boxes, 4)`` whose columns
        are ``(tl_x, tl_y, br_x, br_y)``.
    """
    # One row per box; transposing lets the four columns unpack by name.
    cx, cy, bw, bh = torch.rand(num_boxes, 4).T

    half_w = w * bw / 2
    half_h = h * bh / 2

    tl_x = ((cx * w) - half_w).clamp(0, w)
    tl_y = ((cy * h) - half_h).clamp(0, h)
    br_x = ((cx * w) + half_w).clamp(0, w)
    br_y = ((cy * h) + half_h).clamp(0, h)

    # Stack along dim=1 so the result is already (num_boxes, 4).
    return torch.stack([tl_x, tl_y, br_x, br_y], dim=1)
def _create_panoptic_data(num_boxes, h, w):
    """Build a random panoptic id map for visualizer tests.

    Every pixel starts with the value 2; each random box region is then
    overwritten with ``(i + 1) * INSTANCE_OFFSET + label`` where ``i`` is the
    box index and ``label`` is a random integer in ``{0, 1}``.

    Args:
        num_boxes (int): Number of random box regions to paint.
        h (int): Height of the map.
        w (int): Width of the map.

    Returns:
        np.ndarray: ``int64`` array of shape ``(1, h, w)``.
    """
    sem_seg = np.zeros((h, w), dtype=np.int64) + 2
    bboxes = _rand_bboxes(num_boxes, h, w).int()
    labels = torch.randint(2, (num_boxes, ))
    for i in range(num_boxes):
        # _rand_bboxes yields corner coordinates (tl_x, tl_y, br_x, br_y),
        # so the region is sliced corner-to-corner. Distinct local names
        # also avoid shadowing the ``h`` / ``w`` parameters.
        x1, y1, x2, y2 = bboxes[i].tolist()
        sem_seg[y1:y2, x1:x2] = (i + 1) * INSTANCE_OFFSET + int(labels[i])
    # Prepend a leading axis of size 1.
    return sem_seg[None]
class TestDetLocalVisualizer(TestCase):
    """Tests for ``DetLocalVisualizer.add_datasample``."""

    def test_add_datasample(self):
        """Draw instance and panoptic data samples and check the shape of
        every image written to ``out_file``."""
        h = 12
        w = 10
        num_class = 3
        num_bboxes = 5
        out_file = 'out_file.jpg'
        image = np.random.randint(0, 256, size=(h, w, 3)).astype('uint8')

        # test gt_instances
        gt_instances = InstanceData()
        gt_instances.bboxes = _rand_bboxes(num_bboxes, h, w)
        gt_instances.labels = torch.randint(0, num_class, (num_bboxes, ))
        det_data_sample = DetDataSample()
        det_data_sample.gt_instances = gt_instances

        det_local_visualizer = DetLocalVisualizer()
        det_local_visualizer.add_datasample(
            'image', image, det_data_sample, draw_pred=False)

        # test out_file
        det_local_visualizer.add_datasample(
            'image',
            image,
            det_data_sample,
            draw_pred=False,
            out_file=out_file)
        # Same check as everywhere else in this test; the helper also
        # guarantees the file is removed if an assertion fails.
        self._assert_image_and_shape(out_file, (h, w, 3))

        # test gt_instances and pred_instances
        pred_instances = InstanceData()
        pred_instances.bboxes = _rand_bboxes(num_bboxes, h, w)
        pred_instances.labels = torch.randint(0, num_class, (num_bboxes, ))
        pred_instances.scores = torch.rand((num_bboxes, ))
        det_data_sample.pred_instances = pred_instances

        # With both GT and predictions enabled the saved image is expected
        # to be twice as wide as the input.
        det_local_visualizer.add_datasample(
            'image', image, det_data_sample, out_file=out_file)
        self._assert_image_and_shape(out_file, (h, w * 2, 3))

        det_local_visualizer.add_datasample(
            'image', image, det_data_sample, draw_gt=False, out_file=out_file)
        self._assert_image_and_shape(out_file, (h, w, 3))

        det_local_visualizer.add_datasample(
            'image',
            image,
            det_data_sample,
            draw_pred=False,
            out_file=out_file)
        self._assert_image_and_shape(out_file, (h, w, 3))

        # test gt_panoptic_seg and pred_panoptic_seg
        det_local_visualizer.dataset_meta = dict(classes=('1', '2'))
        gt_sem_seg = _create_panoptic_data(num_bboxes, h, w)
        panoptic_seg = PixelData(sem_seg=gt_sem_seg)

        det_data_sample = DetDataSample()
        det_data_sample.gt_panoptic_seg = panoptic_seg

        pred_sem_seg = _create_panoptic_data(num_bboxes, h, w)
        panoptic_seg = PixelData(sem_seg=pred_sem_seg)
        det_data_sample.pred_panoptic_seg = panoptic_seg

        det_local_visualizer.add_datasample(
            'image', image, det_data_sample, out_file=out_file)
        self._assert_image_and_shape(out_file, (h, w * 2, 3))

        # class information must be provided
        det_local_visualizer.dataset_meta = {}
        with self.assertRaises(AssertionError):
            det_local_visualizer.add_datasample(
                'image', image, det_data_sample, out_file=out_file)

    def _assert_image_and_shape(self, out_file, out_shape):
        """Check that ``out_file`` exists and decodes to ``out_shape``,
        then delete it.

        Args:
            out_file (str): Path of the image expected on disk.
            out_shape (tuple): Expected ``shape`` of the decoded image.

        Raises:
            AssertionError: If the file is missing, cannot be decoded, or
                has a different shape.
        """
        try:
            assert os.path.exists(out_file)
            drawn_img = cv2.imread(out_file)
            # cv2.imread signals failure by returning None, not by raising.
            assert drawn_img is not None
            assert drawn_img.shape == out_shape
        finally:
            # Always clean up so a stale file from a failed check cannot
            # satisfy the existence assertion of a later check.
            if os.path.exists(out_file):
                os.remove(out_file)
class TestTrackLocalVisualizer(TestCase):
    """Tests for ``TrackLocalVisualizer.add_datasample``."""

    @staticmethod
    def _get_gt_instances():
        """Return an ``InstanceData`` with two fixed boxes, rectangular
        boolean masks on a 1080x1920 canvas, instance ids and labels."""
        bboxes = np.array([[912, 484, 1009, 593], [1338, 418, 1505, 797]])
        masks = np.zeros((2, 1080, 1920), dtype=np.bool_)
        for i, bbox in enumerate(bboxes):
            # Fill the rows bbox[1]:bbox[3] and columns bbox[0]:bbox[2].
            masks[i, bbox[1]:bbox[3], bbox[0]:bbox[2]] = True
        instances_data = dict(
            bboxes=torch.tensor(bboxes),
            masks=masks,
            instances_id=torch.tensor([1, 2]),
            labels=torch.tensor([0, 1]))
        instances = InstanceData(**instances_data)
        return instances

    @staticmethod
    def _get_pred_instances():
        """Return an ``InstanceData`` with two fixed boxes, instance ids,
        labels and scores."""
        instances_data = dict(
            bboxes=torch.tensor([[900, 500, 1000, 600],
                                 [1300, 400, 1500, 800]]),
            instances_id=torch.tensor([1, 2]),
            labels=torch.tensor([0, 1]),
            scores=torch.tensor([0.955, 0.876]))
        instances = InstanceData(**instances_data)
        return instances

    @staticmethod
    def _assert_image_and_shape(out_file, out_shape):
        """Check that ``out_file`` exists and decodes to ``out_shape``,
        then delete it.

        Args:
            out_file (str): Path of the image expected on disk.
            out_shape (tuple): Expected ``shape`` of the decoded image.

        Raises:
            AssertionError: If the file is missing, cannot be decoded, or
                has a different shape.
        """
        try:
            assert os.path.exists(out_file)
            drawn_img = cv2.imread(out_file)
            # cv2.imread signals failure by returning None, not by raising.
            assert drawn_img is not None
            assert drawn_img.shape == out_shape
        finally:
            # Always clean up so a stale file from a failed check cannot
            # satisfy the existence assertion of a later check.
            if os.path.exists(out_file):
                os.remove(out_file)

    def test_add_datasample(self):
        """Draw a tracking data sample and check the shape of every image
        written to ``out_file``."""
        out_file = 'out_file.jpg'
        h, w = 1080, 1920
        image = np.random.randint(0, 256, size=(h, w, 3)).astype('uint8')
        gt_instances = self._get_gt_instances()
        pred_instances = self._get_pred_instances()
        image_data_sample = DetDataSample()
        image_data_sample.gt_instances = gt_instances
        image_data_sample.pred_track_instances = pred_instances

        track_local_visualizer = TrackLocalVisualizer(alpha=0.2)
        track_local_visualizer.dataset_meta = dict(
            classes=['pedestrian', 'vehicle'])

        # test gt_instances
        # NOTE(review): the trailing positional ``None`` presumably binds to
        # the parameter after ``data_sample`` (likely ``draw_gt``) — confirm
        # against the visualizer signature.
        track_local_visualizer.add_datasample('image', image,
                                              image_data_sample, None)

        # test out_file
        track_local_visualizer.add_datasample(
            'image', image, image_data_sample, None, out_file=out_file)
        self._assert_image_and_shape(out_file, (h, w, 3))

        # test gt_instances and pred_instances
        # With the default flags the saved image is expected to be twice as
        # wide as the input.
        track_local_visualizer.add_datasample(
            'image', image, image_data_sample, out_file=out_file)
        self._assert_image_and_shape(out_file, (h, 2 * w, 3))

        track_local_visualizer.add_datasample(
            'image',
            image,
            image_data_sample,
            draw_gt=False,
            out_file=out_file)
        self._assert_image_and_shape(out_file, (h, w, 3))

        track_local_visualizer.add_datasample(
            'image',
            image,
            image_data_sample,
            draw_pred=False,
            out_file=out_file)
        self._assert_image_and_shape(out_file, (h, w, 3))
|