| | |
| | import copy |
| | from typing import Callable, Dict, List, Optional, Union |
| |
|
| | import numpy as np |
| | from mmcv.transforms import BaseTransform, Compose |
| | from mmcv.transforms.utils import cache_random_params, cache_randomness |
| |
|
| | from mmdet.registry import TRANSFORMS |
| |
|
| |
|
@TRANSFORMS.register_module()
class MultiBranch(BaseTransform):
    r"""Multiple branch pipeline wrapper.

    Generate multiple data-augmented versions of the same image.
    `MultiBranch` needs to know the names of all branches used by the
    dataset's pipelines. It runs the augmentation pipeline of each branch it
    was given and fills every other listed branch with ``None``. This keeps
    the return format consistent across different samples.

    Args:
        branch_field (list): List of branch names.
        branch_pipelines (dict): Dict of different pipeline configs
            to be composed, keyed by branch name.

    Examples:
        >>> branch_field = ['sup', 'unsup_teacher', 'unsup_student']
        >>> sup_pipeline = [
        >>>     dict(type='LoadImageFromFile'),
        >>>     dict(type='LoadAnnotations', with_bbox=True),
        >>>     dict(type='Resize', scale=(1333, 800), keep_ratio=True),
        >>>     dict(type='RandomFlip', prob=0.5),
        >>>     dict(
        >>>         type='MultiBranch',
        >>>         branch_field=branch_field,
        >>>         sup=dict(type='PackDetInputs'))
        >>> ]
        >>> # The teacher/student pipelines run *inside* the unsup
        >>> # MultiBranch below, on data that is already loaded, so they only
        >>> # augment and pack.
        >>> weak_pipeline = [
        >>>     dict(type='Resize', scale=(1333, 800), keep_ratio=True),
        >>>     dict(type='RandomFlip', prob=0.0),
        >>>     dict(type='PackDetInputs')
        >>> ]
        >>> strong_pipeline = [
        >>>     dict(type='Resize', scale=(1333, 800), keep_ratio=True),
        >>>     dict(type='RandomFlip', prob=1.0),
        >>>     dict(type='PackDetInputs')
        >>> ]
        >>> unsup_pipeline = [
        >>>     dict(type='LoadImageFromFile'),
        >>>     dict(type='LoadEmptyAnnotations'),
        >>>     dict(
        >>>         type='MultiBranch',
        >>>         branch_field=branch_field,
        >>>         unsup_teacher=weak_pipeline,
        >>>         unsup_student=strong_pipeline)
        >>> ]
        >>> from mmcv.transforms import Compose
        >>> sup_branch = Compose(sup_pipeline)
        >>> unsup_branch = Compose(unsup_pipeline)
        >>> print(sup_branch)
        >>> Compose(
        >>>     LoadImageFromFile(ignore_empty=False, to_float32=False, color_type='color', imdecode_backend='cv2') # noqa
        >>>     LoadAnnotations(with_bbox=True, with_label=True, with_mask=False, with_seg=False, poly2mask=True, imdecode_backend='cv2') # noqa
        >>>     Resize(scale=(1333, 800), scale_factor=None, keep_ratio=True, clip_object_border=True), backend=cv2), interpolation=bilinear) # noqa
        >>>     RandomFlip(prob=0.5, direction=horizontal)
        >>>     MultiBranch(branch_pipelines=['sup'])
        >>> )
        >>> print(unsup_branch)
        >>> Compose(
        >>>     LoadImageFromFile(ignore_empty=False, to_float32=False, color_type='color', imdecode_backend='cv2') # noqa
        >>>     LoadEmptyAnnotations(with_bbox=True, with_label=True, with_mask=False, with_seg=False, seg_ignore_label=255) # noqa
        >>>     MultiBranch(branch_pipelines=['unsup_teacher', 'unsup_student'])
        >>> )
    """

    def __init__(self, branch_field: List[str],
                 **branch_pipelines: dict) -> None:
        self.branch_field = branch_field
        # Every keyword argument other than ``branch_field`` is a branch
        # pipeline config; compose each one into a callable pipeline.
        self.branch_pipelines = {
            branch: Compose(pipeline)
            for branch, pipeline in branch_pipelines.items()
        }

    def transform(self, results: dict) -> Optional[dict]:
        """Apply every branch pipeline to its own copy of ``results``.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict or None: ``None`` if any branch pipeline returned ``None``.
            Otherwise, the branch outputs regrouped first by key and then by
            branch name, e.g.:

            - 'inputs' (Dict[str, obj:`torch.Tensor`]): The forward data of
              models from different branches.
            - 'data_samples' (Dict[str, obj:`DetDataSample`]): The annotation
              info of the sample from different branches.

            Branches listed in ``branch_field`` that have no pipeline are
            mapped to ``None`` under both 'inputs' and 'data_samples'.
        """
        # Placeholders keep the output structure identical for every sample,
        # whichever subset of branches actually has a pipeline here.
        multi_results = {
            branch: {'inputs': None, 'data_samples': None}
            for branch in self.branch_field
        }
        for branch, pipeline in self.branch_pipelines.items():
            # Each branch gets its own deep copy so that augmentations of one
            # branch cannot leak into another.
            branch_results = pipeline(copy.deepcopy(results))
            # One branch rejecting the sample rejects it for all branches.
            if branch_results is None:
                return None
            multi_results[branch] = branch_results

        # Transpose {branch: {key: value}} into {key: {branch: value}}.
        format_results = {}
        for branch, branch_results in multi_results.items():
            for key, value in branch_results.items():
                format_results.setdefault(key, {})[branch] = value
        return format_results

    def __repr__(self) -> str:
        repr_str = self.__class__.__name__
        repr_str += f'(branch_pipelines={list(self.branch_pipelines.keys())})'
        return repr_str
| |
|
| |
|
@TRANSFORMS.register_module()
class RandomOrder(Compose):
    """``Compose`` variant that runs its wrapped transforms in random order.

    A new permutation of ``self.transforms`` is drawn on every call to
    :meth:`transform`. The draw lives in a method decorated with
    ``cache_randomness``, so mmcv's ``cache_random_params`` context can
    reuse the same order across calls.
    """

    @cache_randomness
    def _random_permutation(self):
        # Visiting order drawn from NumPy's global random state.
        return np.random.permutation(len(self.transforms))

    def transform(self, results: Dict) -> Optional[Dict]:
        """Run each wrapped transform exactly once, in a shuffled order.

        Args:
            results (dict): A result dict contains the results to transform.

        Returns:
            dict or None: The transformed results. Returns ``None`` as soon
            as any transform returns ``None``; the remaining transforms are
            then skipped.
        """
        for position in self._random_permutation():
            results = self.transforms[position](results)
            if results is None:
                break
        return results

    def __repr__(self):
        """Return the class name followed by the wrapped transforms' names."""
        names = ''.join(f'{t.__class__.__name__}, ' for t in self.transforms)
        return f'{self.__class__.__name__}({names})'
| |
|
| |
|
@TRANSFORMS.register_module()
class ProposalBroadcaster(BaseTransform):
    """A transform wrapper that applies the wrapped transforms to both
    `gt_bboxes` and `proposals` without any extra code. It performs the
    following steps:

    1. Scatter the broadcasting targets to a list of inputs of the wrapped
       transforms. The list has type list[dict, dict]: the first item is the
       original inputs, and the second is a copy whose `gt_bboxes` is
       overwritten by the `proposals`.
    2. Apply ``self.transforms`` to both items with the same random
       parameters, which are shared through a context manager. The outputs
       have type list[dict, dict].
    3. Gather the outputs: update the `proposals` in the first item of
       the outputs with the `gt_bboxes` of the second.

    Args:
        transforms (list, optional): Sequence of transform
            object or config dict to be wrapped. Defaults to None, which
            wraps an empty pipeline.

    Note: The `TransformBroadcaster` in MMCV can achieve the same operation as
    `ProposalBroadcaster`, but requires more complex parameters.

    Examples:
        >>> pipeline = [
        >>>     dict(type='LoadImageFromFile'),
        >>>     dict(type='LoadProposals', num_max_proposals=2000),
        >>>     dict(type='LoadAnnotations', with_bbox=True),
        >>>     dict(
        >>>         type='ProposalBroadcaster',
        >>>         transforms=[
        >>>             dict(type='Resize', scale=(1333, 800),
        >>>                  keep_ratio=True),
        >>>             dict(type='RandomFlip', prob=0.5),
        >>>         ]),
        >>>     dict(type='PackDetInputs')]
    """

    def __init__(
            self,
            transforms: Optional[List[Union[dict, Callable]]] = None
    ) -> None:
        # ``None`` replaces the former shared mutable ``[]`` default.
        self.transforms = Compose([] if transforms is None else transforms)

    def transform(self, results: dict) -> Optional[dict]:
        """Apply wrapped transform functions to process both `gt_bboxes` and
        `proposals`.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict or None: Updated result dict, or ``None`` if the wrapped
            pipeline rejected the sample.
        """
        assert results.get('proposals', None) is not None, \
            '`proposals` should be in the results, please delete ' \
            '`ProposalBroadcaster` in your configs, or check whether ' \
            'you have load proposals successfully.'

        inputs = self._process_input(results)
        outputs = self._apply_transforms(inputs)
        return self._process_output(outputs)

    def _process_input(self, data: dict) -> list:
        """Scatter the broadcasting targets to a list of inputs of the wrapped
        transforms.

        Args:
            data (dict): The original input data. It is passed through as the
                first item, not copied.

        Returns:
            list[dict]: ``[data, copy]``, where ``copy`` is a deep copy of
            ``data`` whose `gt_bboxes` are replaced by its `proposals`.
        """
        cp_data = copy.deepcopy(data)
        cp_data['gt_bboxes'] = cp_data['proposals']
        return [data, cp_data]

    def _apply_transforms(self, inputs: list) -> list:
        """Apply ``self.transforms`` to both inputs with shared randomness.

        Args:
            inputs (list[dict, dict]): list of input data.

        Returns:
            list[dict or None]: The output of the wrapped pipeline for each
            input.
        """
        assert len(inputs) == 2
        # Inside this context the random parameters drawn for the first input
        # are reused for the second, so boxes and proposals get the same
        # geometric augmentation.
        with cache_random_params(self.transforms):
            return [self.transforms(_input) for _input in inputs]

    def _process_output(self, output_scatters: list) -> Optional[dict]:
        """Gathering and renaming data items.

        Args:
            output_scatters (list[dict, dict]): The output of the wrapped
                pipeline.

        Returns:
            dict or None: Updated result dict, or ``None`` if the wrapped
            pipeline returned ``None`` for either input.
        """
        assert isinstance(output_scatters, list) and \
            len(output_scatters) == 2
        # The wrapped Compose yields None when one of its transforms rejects
        # the sample; propagate that instead of failing on the checks below.
        if any(output is None for output in output_scatters):
            return None
        assert isinstance(output_scatters[0], dict)
        outputs = output_scatters[0]
        outputs['proposals'] = output_scatters[1]['gt_bboxes']
        return outputs
| |
|