| |
| import copy |
|
|
| from mmdet.datasets import DATASETS, ConcatDataset, build_dataset |
|
|
| from mmocr.utils import is_2dlist, is_type_list |
|
|
|
|
@DATASETS.register_module()
class UniformConcatDataset(ConcatDataset):
    """A wrapper of ConcatDataset which supports dataset pipeline assignment
    and replacement.

    Note:
        When ``pipeline`` is given, the dataset config dicts passed in
        ``datasets`` are modified in place (their ``'pipeline'`` entry is
        overwritten with a deep copy of the shared pipeline).

    Args:
        datasets (list[dict] | list[list[dict]]): A list of dataset cfgs.
        separate_eval (bool): Whether to evaluate the results
            separately if it is used as validation dataset.
            Defaults to True.
        pipeline (None | list[dict] | list[list[dict]]): If ``None``,
            each dataset in datasets uses its own pipeline;
            If ``list[dict]``, it will be assigned to every dataset whose
            pipeline is None (or missing) in datasets;
            If ``list[list[dict]]``, the pipeline of every dataset which is
            None (or missing) in datasets will be replaced by the
            corresponding pipeline in the list. In this case ``datasets``
            must be a 2D list with the same length as ``pipeline``.
        force_apply (bool): If True, apply pipeline above to each dataset
            even if it has its own pipeline. Default: False.
        **kwargs: Collected into a dict and passed as the second argument
            of ``build_dataset`` for every dataset cfg.
    """

    def __init__(self,
                 datasets,
                 separate_eval=True,
                 pipeline=None,
                 force_apply=False,
                 **kwargs):
        new_datasets = []
        if pipeline is not None:
            assert isinstance(
                pipeline,
                list), 'pipeline must be list[dict] or list[list[dict]].'
            flat_pipeline = is_type_list(pipeline, dict)
            # Enforce the documented shape up front: a list that matched
            # neither check previously left ``new_datasets`` empty and only
            # failed later with an unrelated error.
            assert flat_pipeline or is_2dlist(pipeline), \
                'pipeline must be list[dict] or list[list[dict]].'
            if flat_pipeline:
                # One shared pipeline for a flat list of dataset cfgs.
                self._apply_pipeline(datasets, pipeline, force_apply)
                new_datasets = datasets
            else:
                # One pipeline per group of dataset cfgs; the groups are
                # flattened into a single list after assignment.
                assert is_2dlist(datasets)
                assert len(datasets) == len(pipeline)
                for sub_datasets, tmp_pipeline in zip(datasets, pipeline):
                    self._apply_pipeline(sub_datasets, tmp_pipeline,
                                         force_apply)
                    new_datasets.extend(sub_datasets)
        elif is_2dlist(datasets):
            # No pipeline override: just flatten the grouped cfgs.
            for sub_datasets in datasets:
                new_datasets.extend(sub_datasets)
        else:
            new_datasets = datasets
        datasets = [build_dataset(c, kwargs) for c in new_datasets]
        super().__init__(datasets, separate_eval)

    @staticmethod
    def _apply_pipeline(datasets, pipeline, force_apply=False):
        """Assign ``pipeline`` to dataset cfgs that lack one, in place.

        Args:
            datasets (list[dict]): Dataset config dicts; modified in place.
            pipeline (list[dict]): Pipeline to assign. Every dataset that
                receives it gets its own deep copy.
            force_apply (bool): If True, overwrite the pipeline of every
                dataset, even one that already has a pipeline.
                Default: False.

        Raises:
            AssertionError: If any item of ``datasets`` or ``pipeline`` is
                not a dict.
        """
        from_cfg = all(isinstance(x, dict) for x in datasets)
        assert from_cfg, 'datasets should be config dicts'
        assert all(isinstance(x, dict) for x in pipeline)
        for dataset in datasets:
            # ``.get`` treats a cfg without a 'pipeline' key the same as one
            # whose pipeline is None, instead of raising KeyError.
            if dataset.get('pipeline') is None or force_apply:
                # Deep copy so that datasets never share pipeline objects.
                dataset['pipeline'] = copy.deepcopy(pipeline)
|
|