| |
| import os.path as osp |
| from unittest import TestCase |
|
|
| import mmcv |
| import pytest |
|
|
| from mmdet.datasets.transforms import * |
| from mmdet.registry import TRANSFORMS |
|
|
|
|
class TestMuitiScaleFlipAug(TestCase):
    """Tests for the ``TestTimeAug`` transform built from ``TRANSFORMS``.

    ``TestTimeAug`` takes a list of transform lists. One packed output is
    produced per combination of one transform from each inner list, with
    the last lists varying fastest (e.g. 3 scales x 2 flips -> 6 outputs
    ordered scale0/noflip, scale0/flip, scale1/noflip, ...).

    NOTE(review): the class name keeps its original "Muiti" typo so that
    existing pytest node IDs and ``-k`` selectors keep working.
    """

    @staticmethod
    def _pack_inputs(with_flip=False):
        """Return a fresh single-element ``PackDetInputs`` branch list.

        Args:
            with_flip (bool): Also pack the ``flip`` / ``flip_direction``
                meta keys, needed when a ``RandomFlip`` branch is used.

        Returns:
            list[dict]: Config list suitable as one ``transforms`` entry.
        """
        meta_keys = ('img_id', 'img_path', 'ori_shape', 'img_shape',
                     'scale_factor')
        if with_flip:
            meta_keys += ('flip', 'flip_direction')
        return [dict(type='mmdet.PackDetInputs', meta_keys=meta_keys)]

    @staticmethod
    def _flip_branches():
        """Return a fresh pair of deterministic horizontal-flip branches.

        ``prob=0.`` never flips and ``prob=1.`` always flips, so the packed
        outputs alternate not-flipped / flipped.
        """
        return [
            dict(type='RandomFlip', prob=0., direction='horizontal'),
            dict(type='RandomFlip', prob=1., direction='horizontal'),
        ]

    @staticmethod
    def _load_results():
        """Build the input results dict from the ``color.jpg`` test image."""
        img = mmcv.imread(
            osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
        return dict(
            img_id='1',
            img_path='data/color.jpg',
            img=img,
            ori_shape=img.shape,
            ori_height=img.shape[0],
            ori_width=img.shape[1],
            pad_shape=img.shape,
            scale_factor=1.0)

    @staticmethod
    def _run_tta(transforms, results):
        """Build a ``TestTimeAug`` from ``transforms`` and apply it.

        Args:
            transforms (list[list[dict]]): The nested transform configs.
            results (dict): Input results; a shallow copy is passed so
                top-level keys written by the pipeline do not leak into
                later scenarios.

        Returns:
            dict: The TTA output with ``inputs`` and ``data_samples``.
        """
        tta_module = TRANSFORMS.build(
            dict(type='TestTimeAug', transforms=transforms))
        return tta_module(results.copy())

    @staticmethod
    def _input_shapes(tta_results):
        """Return the shape of every packed input, in output order."""
        return [inp.shape for inp in tta_results['inputs']]

    @staticmethod
    def _flips(tta_results):
        """Return the ``flip`` meta value of every packed data sample."""
        return [
            data_sample.metainfo['flip']
            for data_sample in tta_results['data_samples']
        ]

    def test_exception(self):
        """A flat (non-nested) ``transforms`` list raises ``TypeError``."""
        tta_transform = dict(
            type='TestTimeAug',
            transforms=[dict(type='Resize', keep_ratio=False)],
        )
        # Keep only the build call inside ``raises`` so that an unrelated
        # TypeError elsewhere cannot make this test pass.
        with pytest.raises(TypeError):
            TRANSFORMS.build(tta_transform)

    def test_multi_scale_flip_aug(self):
        """Check output shapes and flip flags for several TTA pipelines."""
        results = self._load_results()
        scales = [(256, 256), (512, 512), (1024, 1024)]
        ratios = [0.5, 1.0, 2.0]

        # Fixed scales only: one output per scale.
        tta_results = self._run_tta([
            [
                dict(type='Resize', scale=scale, keep_ratio=False)
                for scale in scales
            ],
            self._pack_inputs(),
        ], results)
        assert self._input_shapes(tta_results) == [(3, 256, 256),
                                                   (3, 512, 512),
                                                   (3, 1024, 1024)]

        # Fixed scales x flips: the flip branch varies fastest.
        tta_results = self._run_tta([
            [
                dict(type='Resize', scale=scale, keep_ratio=False)
                for scale in scales
            ],
            self._flip_branches(),
            self._pack_inputs(with_flip=True),
        ], results)
        assert self._input_shapes(tta_results) == [(3, 256, 256),
                                                   (3, 256, 256),
                                                   (3, 512, 512),
                                                   (3, 512, 512),
                                                   (3, 1024, 1024),
                                                   (3, 1024, 1024)]
        assert self._flips(tta_results) == [False, True] * 3

        # A single scale must yield exactly one output. Comparing the whole
        # list (not just ``inputs[0]``) also catches spurious extra outputs.
        tta_results = self._run_tta([
            [dict(type='Resize', scale=(512, 512), keep_ratio=False)],
            self._pack_inputs(),
        ], results)
        assert self._input_shapes(tta_results) == [(3, 512, 512)]

        # A single scale x flips: two outputs, not-flipped then flipped.
        tta_results = self._run_tta([
            [dict(type='Resize', scale=(512, 512), keep_ratio=False)],
            self._flip_branches(),
            self._pack_inputs(with_flip=True),
        ], results)
        assert self._input_shapes(tta_results) == [(3, 512, 512),
                                                   (3, 512, 512)]
        assert self._flips(tta_results) == [False, True]

        # Scale factors without keeping the ratio. The expected shapes imply
        # a 288x512 (H x W) test image at ratio 1.0.
        tta_results = self._run_tta([
            [
                dict(type='Resize', scale_factor=r, keep_ratio=False)
                for r in ratios
            ],
            self._pack_inputs(),
        ], results)
        assert self._input_shapes(tta_results) == [(3, 144, 256),
                                                   (3, 288, 512),
                                                   (3, 576, 1024)]

        # Scale factors keeping the ratio x flips.
        tta_results = self._run_tta([
            [
                dict(type='Resize', scale_factor=r, keep_ratio=True)
                for r in ratios
            ],
            self._flip_branches(),
            self._pack_inputs(with_flip=True),
        ], results)
        assert self._input_shapes(tta_results) == [(3, 144, 256),
                                                   (3, 144, 256),
                                                   (3, 288, 512),
                                                   (3, 288, 512),
                                                   (3, 576, 1024),
                                                   (3, 576, 1024)]
        assert self._flips(tta_results) == [False, True] * 3
|
|