|
|
|
|
|
import os |
|
|
import os.path as osp |
|
|
import tempfile |
|
|
from unittest import TestCase |
|
|
|
|
|
import cv2 |
|
|
import mmcv |
|
|
import numpy as np |
|
|
import torch |
|
|
from mmengine.structures import PixelData |
|
|
|
|
|
from mmseg.structures import SegDataSample |
|
|
from mmseg.visualization import SegLocalVisualizer |
|
|
|
|
|
|
|
|
class TestSegLocalVisualizer(TestCase): |
|
|
|
|
|
def test_add_datasample(self): |
|
|
h = 10 |
|
|
w = 12 |
|
|
num_class = 2 |
|
|
out_file = 'out_file' |
|
|
|
|
|
image = np.random.randint(0, 256, size=(h, w, 3)).astype('uint8') |
|
|
|
|
|
|
|
|
gt_sem_seg_data = dict(data=torch.randint(0, num_class, (1, h, w))) |
|
|
gt_sem_seg = PixelData(**gt_sem_seg_data) |
|
|
|
|
|
def test_add_datasample_forward(gt_sem_seg): |
|
|
data_sample = SegDataSample() |
|
|
data_sample.gt_sem_seg = gt_sem_seg |
|
|
|
|
|
with tempfile.TemporaryDirectory() as tmp_dir: |
|
|
seg_local_visualizer = SegLocalVisualizer( |
|
|
vis_backends=[dict(type='LocalVisBackend')], |
|
|
save_dir=tmp_dir) |
|
|
seg_local_visualizer.dataset_meta = dict( |
|
|
classes=('background', 'foreground'), |
|
|
palette=[[120, 120, 120], [6, 230, 230]]) |
|
|
|
|
|
|
|
|
seg_local_visualizer.add_datasample(out_file, image, |
|
|
data_sample) |
|
|
|
|
|
assert os.path.exists( |
|
|
osp.join(tmp_dir, 'vis_data', 'vis_image', |
|
|
out_file + '_0.png')) |
|
|
drawn_img = cv2.imread( |
|
|
osp.join(tmp_dir, 'vis_data', 'vis_image', |
|
|
out_file + '_0.png')) |
|
|
assert drawn_img.shape == (h, w, 3) |
|
|
|
|
|
|
|
|
pred_sem_seg_data = dict( |
|
|
data=torch.randint(0, num_class, (1, h, w))) |
|
|
pred_sem_seg = PixelData(**pred_sem_seg_data) |
|
|
|
|
|
data_sample.pred_sem_seg = pred_sem_seg |
|
|
|
|
|
seg_local_visualizer.add_datasample(out_file, image, |
|
|
data_sample) |
|
|
self._assert_image_and_shape( |
|
|
osp.join(tmp_dir, 'vis_data', 'vis_image', |
|
|
out_file + '_0.png'), (h, w * 2, 3)) |
|
|
|
|
|
seg_local_visualizer.add_datasample( |
|
|
out_file, image, data_sample, draw_gt=False) |
|
|
self._assert_image_and_shape( |
|
|
osp.join(tmp_dir, 'vis_data', 'vis_image', |
|
|
out_file + '_0.png'), (h, w, 3)) |
|
|
|
|
|
if torch.cuda.is_available(): |
|
|
test_add_datasample_forward(gt_sem_seg.cuda()) |
|
|
test_add_datasample_forward(gt_sem_seg) |
|
|
|
|
|
def test_cityscapes_add_datasample(self): |
|
|
h = 128 |
|
|
w = 256 |
|
|
num_class = 19 |
|
|
out_file = 'out_file_cityscapes' |
|
|
|
|
|
image = mmcv.imread( |
|
|
osp.join( |
|
|
osp.dirname(__file__), |
|
|
'../data/pseudo_cityscapes_dataset/leftImg8bit/val/frankfurt/frankfurt_000000_000294_leftImg8bit.png' |
|
|
), |
|
|
'color') |
|
|
sem_seg = mmcv.imread( |
|
|
osp.join( |
|
|
osp.dirname(__file__), |
|
|
'../data/pseudo_cityscapes_dataset/gtFine/val/frankfurt/frankfurt_000000_000294_gtFine_labelTrainIds.png' |
|
|
), |
|
|
'unchanged') |
|
|
sem_seg = torch.unsqueeze(torch.from_numpy(sem_seg), 0) |
|
|
gt_sem_seg_data = dict(data=sem_seg) |
|
|
gt_sem_seg = PixelData(**gt_sem_seg_data) |
|
|
|
|
|
def test_cityscapes_add_datasample_forward(gt_sem_seg): |
|
|
data_sample = SegDataSample() |
|
|
data_sample.gt_sem_seg = gt_sem_seg |
|
|
|
|
|
with tempfile.TemporaryDirectory() as tmp_dir: |
|
|
seg_local_visualizer = SegLocalVisualizer( |
|
|
vis_backends=[dict(type='LocalVisBackend')], |
|
|
save_dir=tmp_dir) |
|
|
seg_local_visualizer.dataset_meta = dict( |
|
|
classes=('road', 'sidewalk', 'building', 'wall', 'fence', |
|
|
'pole', 'traffic light', 'traffic sign', |
|
|
'vegetation', 'terrain', 'sky', 'person', 'rider', |
|
|
'car', 'truck', 'bus', 'train', 'motorcycle', |
|
|
'bicycle'), |
|
|
palette=[[128, 64, 128], [244, 35, 232], [70, 70, 70], |
|
|
[102, 102, 156], [190, 153, 153], [153, 153, 153], |
|
|
[250, 170, 30], [220, 220, 0], [107, 142, 35], |
|
|
[152, 251, 152], [70, 130, 180], [220, 20, 60], |
|
|
[255, 0, 0], [0, 0, 142], [0, 0, 70], |
|
|
[0, 60, 100], [0, 80, 100], [0, 0, 230], |
|
|
[119, 11, 32]]) |
|
|
|
|
|
seg_local_visualizer.add_datasample( |
|
|
out_file, |
|
|
image, |
|
|
data_sample, |
|
|
out_file=osp.join(tmp_dir, 'test.png')) |
|
|
self._assert_image_and_shape( |
|
|
osp.join(tmp_dir, 'test.png'), (h, w, 3)) |
|
|
|
|
|
|
|
|
pred_sem_seg_data = dict( |
|
|
data=torch.randint(0, num_class, (1, h, w))) |
|
|
pred_sem_seg = PixelData(**pred_sem_seg_data) |
|
|
|
|
|
data_sample.pred_sem_seg = pred_sem_seg |
|
|
|
|
|
|
|
|
seg_local_visualizer.add_datasample(out_file, image, |
|
|
data_sample) |
|
|
self._assert_image_and_shape( |
|
|
osp.join(tmp_dir, 'vis_data', 'vis_image', |
|
|
out_file + '_0.png'), (h, w * 2, 3)) |
|
|
|
|
|
seg_local_visualizer.add_datasample( |
|
|
out_file, image, data_sample, draw_gt=False) |
|
|
self._assert_image_and_shape( |
|
|
osp.join(tmp_dir, 'vis_data', 'vis_image', |
|
|
out_file + '_0.png'), (h, w, 3)) |
|
|
|
|
|
if torch.cuda.is_available(): |
|
|
test_cityscapes_add_datasample_forward(gt_sem_seg.cuda()) |
|
|
test_cityscapes_add_datasample_forward(gt_sem_seg) |
|
|
|
|
|
def _assert_image_and_shape(self, out_file, out_shape): |
|
|
assert os.path.exists(out_file) |
|
|
drawn_img = cv2.imread(out_file) |
|
|
assert drawn_img.shape == out_shape |
|
|
|