|
|
import os |
|
|
from unittest import TestCase |
|
|
|
|
|
import cv2 |
|
|
import numpy as np |
|
|
import torch |
|
|
from mmengine.structures import InstanceData, PixelData |
|
|
|
|
|
from mmdet.evaluation import INSTANCE_OFFSET |
|
|
from mmdet.structures import DetDataSample |
|
|
from mmdet.visualization import DetLocalVisualizer, TrackLocalVisualizer |
|
|
|
|
|
|
|
|
def _rand_bboxes(num_boxes, h, w):
    """Sample random boxes that lie inside an ``h`` x ``w`` image.

    Centers and sizes are drawn uniformly from ``[0, 1)`` and scaled by the
    image size; the resulting corners are clamped to the image bounds.

    Args:
        num_boxes (int): Number of boxes to sample.
        h (int): Image height, used to scale and clamp y coordinates.
        w (int): Image width, used to scale and clamp x coordinates.

    Returns:
        torch.Tensor: Tensor of shape ``(num_boxes, 4)`` whose columns are
        ``(tl_x, tl_y, br_x, br_y)``.
    """
    # One random column each for center x, center y, box width, box height.
    cx, cy, bw, bh = torch.rand(num_boxes, 4).unbind(dim=1)

    center_x = cx * w
    center_y = cy * h
    half_w = w * bw / 2
    half_h = h * bh / 2

    corners = [
        (center_x - half_w).clamp(0, w),
        (center_y - half_h).clamp(0, h),
        (center_x + half_w).clamp(0, w),
        (center_y + half_h).clamp(0, h),
    ]
    # Stack as (4, num_boxes) and transpose to get one box per row.
    return torch.stack(corners, dim=0).t()
|
|
|
|
|
|
|
|
def _create_panoptic_data(num_boxes, h, w):
    """Build a random panoptic id map for an ``h`` x ``w`` image.

    Every pixel starts with the value 2. For the ``i``-th random box, a
    rectangular region is overwritten with
    ``(i + 1) * INSTANCE_OFFSET + label`` where ``label`` is drawn from
    ``{0, 1}``. Later boxes overwrite earlier ones where they overlap.

    Args:
        num_boxes (int): Number of random regions to paint.
        h (int): Map height.
        w (int): Map width.

    Returns:
        np.ndarray: ``int64`` array of shape ``(1, h, w)``.
    """
    panoptic_map = np.full((h, w), 2, dtype=np.int64)
    # Keep this sampling order (boxes first, then labels) so the random
    # stream is consumed exactly as before.
    boxes = _rand_bboxes(num_boxes, h, w).int()
    labels = torch.randint(2, (num_boxes, ))

    for idx, (box, label) in enumerate(zip(boxes, labels)):
        # NOTE(review): `_rand_bboxes` returns (tl_x, tl_y, br_x, br_y), yet
        # the last two values are used here as extents added to the first
        # two. Numpy slicing clips out-of-range ends, so this still yields a
        # valid region; confirm whether corner semantics were intended.
        left, top, span_x, span_y = box.tolist()
        panoptic_map[top:top + span_y, left:left + span_x] = (
            (idx + 1) * INSTANCE_OFFSET + label.item())

    # Add a leading axis of size one.
    return panoptic_map[None]
|
|
|
|
|
|
|
|
class TestDetLocalVisualizer(TestCase):
    """Tests for ``DetLocalVisualizer.add_datasample``."""

    def test_add_datasample(self):
        """Draw ground truth, predictions and panoptic maps to a file.

        Each saved image is read back and its shape is checked: a single
        panel keeps the input shape, while drawing both ground truth and
        predictions is expected to double the width.
        """
        h = 12
        w = 10
        num_class = 3
        num_bboxes = 5
        out_file = 'out_file.jpg'

        image = np.random.randint(0, 256, size=(h, w, 3)).astype('uint8')

        # Ground-truth instances only.
        gt_instances = InstanceData()
        gt_instances.bboxes = _rand_bboxes(num_bboxes, h, w)
        gt_instances.labels = torch.randint(0, num_class, (num_bboxes, ))
        det_data_sample = DetDataSample()
        det_data_sample.gt_instances = gt_instances

        det_local_visualizer = DetLocalVisualizer()
        # Smoke check without writing anything to disk.
        det_local_visualizer.add_datasample(
            'image', image, det_data_sample, draw_pred=False)

        # Same call, but saved to ``out_file``.
        det_local_visualizer.add_datasample(
            'image',
            image,
            det_data_sample,
            draw_pred=False,
            out_file=out_file)
        self._assert_image_and_shape(out_file, (h, w, 3))

        # Ground-truth and predicted instances together.
        pred_instances = InstanceData()
        pred_instances.bboxes = _rand_bboxes(num_bboxes, h, w)
        pred_instances.labels = torch.randint(0, num_class, (num_bboxes, ))
        pred_instances.scores = torch.rand((num_bboxes, ))
        det_data_sample.pred_instances = pred_instances

        det_local_visualizer.add_datasample(
            'image', image, det_data_sample, out_file=out_file)
        self._assert_image_and_shape(out_file, (h, w * 2, 3))

        det_local_visualizer.add_datasample(
            'image', image, det_data_sample, draw_gt=False, out_file=out_file)
        self._assert_image_and_shape(out_file, (h, w, 3))

        det_local_visualizer.add_datasample(
            'image',
            image,
            det_data_sample,
            draw_pred=False,
            out_file=out_file)
        self._assert_image_and_shape(out_file, (h, w, 3))

        # Ground-truth and predicted panoptic segmentation.
        det_local_visualizer.dataset_meta = dict(classes=('1', '2'))
        gt_sem_seg = _create_panoptic_data(num_bboxes, h, w)
        panoptic_seg = PixelData(sem_seg=gt_sem_seg)

        det_data_sample = DetDataSample()
        det_data_sample.gt_panoptic_seg = panoptic_seg

        pred_sem_seg = _create_panoptic_data(num_bboxes, h, w)
        panoptic_seg = PixelData(sem_seg=pred_sem_seg)
        det_data_sample.pred_panoptic_seg = panoptic_seg

        det_local_visualizer.add_datasample(
            'image', image, det_data_sample, out_file=out_file)
        self._assert_image_and_shape(out_file, (h, w * 2, 3))

        # With an empty ``dataset_meta`` the same call must raise an
        # AssertionError.
        det_local_visualizer.dataset_meta = {}
        with self.assertRaises(AssertionError):
            det_local_visualizer.add_datasample(
                'image', image, det_data_sample, out_file=out_file)

    def _assert_image_and_shape(self, out_file, out_shape):
        """Check that ``out_file`` holds an image of ``out_shape``.

        The file is removed afterwards even when a check fails, so a failing
        assertion does not leave the image behind for later runs.

        Args:
            out_file (str): Path of the image expected on disk.
            out_shape (tuple): Expected shape of the loaded image array.
        """
        assert os.path.exists(out_file)
        try:
            drawn_img = cv2.imread(out_file)
            # ``cv2.imread`` returns None instead of raising when the file
            # cannot be decoded; fail with a clear message in that case.
            assert drawn_img is not None, f'failed to read {out_file}'
            assert drawn_img.shape == out_shape
        finally:
            os.remove(out_file)
|
|
|
|
|
|
|
|
class TestTrackLocalVisualizer(TestCase):
    """Tests for ``TrackLocalVisualizer.add_datasample``."""

    @staticmethod
    def _get_gt_instances():
        """Return two ground-truth instances for a 1080 x 1920 frame.

        Each instance has a box, a boolean mask that is True inside that
        box, an instance id and a label.
        """
        bboxes = np.array([[912, 484, 1009, 593], [1338, 418, 1505, 797]])
        masks = np.zeros((2, 1080, 1920), dtype=np.bool_)
        # Boxes are indexed as (x1, y1, x2, y2): rows are sliced with the y
        # values and columns with the x values.
        for mask, (x1, y1, x2, y2) in zip(masks, bboxes):
            mask[y1:y2, x1:x2] = True
        return InstanceData(
            bboxes=torch.tensor(bboxes),
            masks=masks,
            instances_id=torch.tensor([1, 2]),
            labels=torch.tensor([0, 1]))

    @staticmethod
    def _get_pred_instances():
        """Return two predicted instances with boxes, ids, labels, scores."""
        return InstanceData(
            bboxes=torch.tensor([[900, 500, 1000, 600],
                                 [1300, 400, 1500, 800]]),
            instances_id=torch.tensor([1, 2]),
            labels=torch.tensor([0, 1]),
            scores=torch.tensor([0.955, 0.876]))

    @staticmethod
    def _assert_image_and_shape(out_file, out_shape):
        """Check that ``out_file`` holds an image of ``out_shape``.

        The file is removed afterwards even when a check fails, so a failing
        assertion does not leave the image behind for later runs.

        Args:
            out_file (str): Path of the image expected on disk.
            out_shape (tuple): Expected shape of the loaded image array.
        """
        assert os.path.exists(out_file)
        try:
            drawn_img = cv2.imread(out_file)
            # ``cv2.imread`` returns None instead of raising when the file
            # cannot be decoded; fail with a clear message in that case.
            assert drawn_img is not None, f'failed to read {out_file}'
            assert drawn_img.shape == out_shape
        finally:
            os.remove(out_file)

    def test_add_datasample(self):
        """Draw ground-truth and tracked instances and check saved images.

        A single panel is expected to keep the input shape, while drawing
        both ground truth and predictions is expected to double the width.
        """
        out_file = 'out_file.jpg'
        h, w = 1080, 1920
        image = np.random.randint(0, 256, size=(h, w, 3)).astype('uint8')
        gt_instances = self._get_gt_instances()
        pred_instances = self._get_pred_instances()
        image_data_sample = DetDataSample()
        image_data_sample.gt_instances = gt_instances
        image_data_sample.pred_track_instances = pred_instances

        track_local_visualizer = TrackLocalVisualizer(alpha=0.2)
        track_local_visualizer.dataset_meta = dict(
            classes=['pedestrian', 'vehicle'])

        # Smoke check without writing anything to disk.
        # NOTE(review): the fourth positional argument is passed as None --
        # presumably ``draw_gt``; verify against
        # ``TrackLocalVisualizer.add_datasample``.
        track_local_visualizer.add_datasample('image', image,
                                              image_data_sample, None)

        # Same call, but saved to ``out_file``.
        track_local_visualizer.add_datasample(
            'image', image, image_data_sample, None, out_file=out_file)
        self._assert_image_and_shape(out_file, (h, w, 3))

        # Ground truth and predictions together.
        track_local_visualizer.add_datasample(
            'image', image, image_data_sample, out_file=out_file)
        self._assert_image_and_shape(out_file, (h, 2 * w, 3))

        track_local_visualizer.add_datasample(
            'image',
            image,
            image_data_sample,
            draw_gt=False,
            out_file=out_file)
        self._assert_image_and_shape(out_file, (h, w, 3))

        track_local_visualizer.add_datasample(
            'image',
            image,
            image_data_sample,
            draw_pred=False,
            out_file=out_file)
        self._assert_image_and_shape(out_file, (h, w, 3))
|
|
|