|
|
|
|
|
import copy |
|
|
import os.path as osp |
|
|
import unittest |
|
|
|
|
|
import numpy as np |
|
|
import pytest |
|
|
from mmengine.structures import BaseDataElement |
|
|
|
|
|
from mmseg.datasets.transforms import PackSegInputs |
|
|
from mmseg.structures import SegDataSample |
|
|
|
|
|
|
|
|
class TestPackSegInputs(unittest.TestCase):
    """Unit tests for the ``PackSegInputs`` data transform."""

    def setUp(self):
        """Build the fake pipeline ``results`` dict and ``meta_keys``.

        TestCase calls functions in this order: setUp() -> testMethod() ->
        tearDown() -> cleanUp()
        """
        data_prefix = osp.join(osp.dirname(__file__), '../../data')
        img_path = osp.join(data_prefix, 'color.jpg')
        # Seeded generator so the fake image / seg map are reproducible.
        rng = np.random.RandomState(0)
        self.results = {
            'img_path': img_path,
            'ori_shape': (300, 400),
            'pad_shape': (600, 800),
            'img_shape': (600, 800),
            'scale_factor': 2.0,
            'flip': False,
            'flip_direction': 'horizontal',
            'img_norm_cfg': None,
            'img': rng.rand(300, 400),
            'gt_seg_map': rng.rand(300, 400),
        }
        # Keys that PackSegInputs is asked to copy into the sample's metainfo.
        self.meta_keys = ('img_path', 'ori_shape', 'img_shape', 'pad_shape',
                          'scale_factor', 'flip', 'flip_direction')

    def test_transform(self):
        """Pack a 2D seg map, then a 3D one (which must emit a warning)."""
        transform = PackSegInputs(meta_keys=self.meta_keys)

        # 2D ground-truth map: packed without warnings.
        results = transform(copy.deepcopy(self.results))
        self.assertIn('data_samples', results)
        self.assertIsInstance(results['data_samples'], SegDataSample)
        self.assertIsInstance(results['data_samples'].gt_sem_seg,
                              BaseDataElement)
        self.assertEqual(results['data_samples'].ori_shape,
                         results['data_samples'].gt_sem_seg.shape)

        # 3D ground-truth map: the transform warns that a 2D map is expected.
        # Use a seeded generator (not the global np.random state) so the
        # test is deterministic, matching setUp.
        rng = np.random.RandomState(0)
        results = copy.deepcopy(self.results)
        results['gt_seg_map'] = rng.rand(3, 300, 400)
        msg = 'the segmentation map is 2D'
        with pytest.warns(UserWarning, match=msg):
            results = transform(results)
        self.assertIsInstance(results['data_samples'], SegDataSample)
        self.assertEqual(results['data_samples'].ori_shape,
                         results['data_samples'].gt_sem_seg.shape)

    def test_repr(self):
        """The repr must show the configured ``meta_keys``."""
        transform = PackSegInputs(meta_keys=self.meta_keys)
        self.assertEqual(
            repr(transform), f'PackSegInputs(meta_keys={self.meta_keys})')
|
|
|