Spaces:
Runtime error
Runtime error
| import warnings | |
| import mmcv | |
| from ..builder import PIPELINES | |
| from .compose import Compose | |
class MultiScaleFlipAug(object):
    """Test-time augmentation with multiple scales and flipping.

    An example configuration is as follows:

    .. code-block::

        img_scale=(2048, 1024),
        img_ratios=[0.5, 1.0],
        flip=True,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='Pad', size_divisor=32),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ]

    After MultiScaleFlipAug with the above configuration, the results are
    wrapped into lists of the same length as follows:

    .. code-block::

        dict(
            img=[...],
            img_shape=[...],
            scale=[(1024, 512), (1024, 512), (2048, 1024), (2048, 1024)]
            flip=[False, True, False, True]
            ...
        )

    The scales used at call time are selected in one of three modes:

    1. ``img_scale=None`` with ``img_ratios`` given: each ratio is applied
       to the ``(w, h)`` of ``results['img']`` when the transform is called.
    2. ``img_scale`` is a 2-tuple with ``img_ratios`` given: each ratio is
       applied to that base scale once, at construction time.
    3. Otherwise ``img_scale`` (a tuple or a list of tuples) is used as-is.

    Args:
        transforms (list[dict]): Transforms to apply in each augmentation.
        img_scale (None | tuple | list[tuple]): Images scales for resizing.
        img_ratios (float | list[float]): Image ratios for resizing.
        flip (bool): Whether apply flip augmentation. Default: False.
        flip_direction (str | list[str]): Flip augmentation directions,
            options are "horizontal" and "vertical". If flip_direction is
            list, multiple flip augmentations will be applied.
            It has no effect when flip == False. Default: "horizontal".
    """

    def __init__(self,
                 transforms,
                 img_scale,
                 img_ratios=None,
                 flip=False,
                 flip_direction='horizontal'):
        """Validate the configuration and build the transform pipeline.

        Raises:
            AssertionError: If ``img_ratios`` is not a float / list of
                floats, if ``img_scale`` is None without ``img_ratios``,
                if a base ``img_scale`` tuple does not have two elements,
                or if ``flip_direction`` is not a str / list of str.
        """
        self.transforms = Compose(transforms)

        # Normalise a scalar ratio to a one-element list before validating.
        if img_ratios is not None:
            if not isinstance(img_ratios, list):
                img_ratios = [img_ratios]
            assert mmcv.is_list_of(img_ratios, float)

        if img_scale is None:
            # mode 1: given img_scale=None and a range of image ratio
            self.img_scale = None
            assert mmcv.is_list_of(img_ratios, float)
        elif isinstance(img_scale, tuple) and mmcv.is_list_of(
                img_ratios, float):
            assert len(img_scale) == 2
            # mode 2: given a scale and a range of image ratio
            self.img_scale = [(int(img_scale[0] * ratio),
                               int(img_scale[1] * ratio))
                              for ratio in img_ratios]
        else:
            # mode 3: given multiple scales
            self.img_scale = img_scale if isinstance(img_scale,
                                                     list) else [img_scale]
        assert mmcv.is_list_of(self.img_scale, tuple) or self.img_scale is None

        self.flip = flip
        self.img_ratios = img_ratios
        self.flip_direction = flip_direction if isinstance(
            flip_direction, list) else [flip_direction]
        assert mmcv.is_list_of(self.flip_direction, str)

        if not self.flip and self.flip_direction != ['horizontal']:
            warnings.warn(
                'flip_direction has no effect when flip is set to False')
        if (self.flip
                and not any(t['type'] == 'RandomFlip' for t in transforms)):
            warnings.warn(
                'flip has no effect when RandomFlip is not in transforms')

    def __call__(self, results):
        """Call function to apply test time augment transforms on results.

        Args:
            results (dict): Result dict contains the data to transform.

        Returns:
            dict[str: list]: The augmented data, where each value is wrapped
                into a list.
        """
        aug_data = []
        if self.img_scale is None and mmcv.is_list_of(self.img_ratios, float):
            # mode 1: derive the (w, h) scales from the input image itself.
            h, w = results['img'].shape[:2]
            img_scale = [(int(w * ratio), int(h * ratio))
                         for ratio in self.img_ratios]
        else:
            img_scale = self.img_scale

        flip_aug = [False, True] if self.flip else [False]
        # NOTE(review): one output is produced per (scale, flip, direction)
        # combination, so several flip directions also repeat the
        # un-flipped variant once per direction. Kept as-is because callers
        # may rely on the resulting list length.
        for scale in img_scale:
            for flip in flip_aug:
                for direction in self.flip_direction:
                    # Shallow copy so the keys set below do not leak into
                    # the caller's dict or into the other augmentations.
                    _results = results.copy()
                    _results['scale'] = scale
                    _results['flip'] = flip
                    _results['flip_direction'] = direction
                    data = self.transforms(_results)
                    aug_data.append(data)

        # list of dict to dict of list
        aug_data_dict = {key: [] for key in aug_data[0]}
        for data in aug_data:
            for key, val in data.items():
                aug_data_dict[key].append(val)
        return aug_data_dict

    def __repr__(self):
        """Return a constructor-like summary of the configuration."""
        # All fields sit inside one balanced pair of parentheses, separated
        # by ', '.
        return (f'{self.__class__.__name__}('
                f'transforms={self.transforms}, '
                f'img_scale={self.img_scale}, '
                f'flip={self.flip}, '
                f'flip_direction={self.flip_direction})')