# Copyright (c) OpenMMLab. All rights reserved.
import bisect
import collections
import copy
import math
from collections import defaultdict

import numpy as np
from mmcv.utils import build_from_cfg, print_log
from torch.utils.data.dataset import ConcatDataset as _ConcatDataset

from .builder import DATASETS, PIPELINES
from .coco import CocoDataset

@DATASETS.register_module()
class ConcatDataset(_ConcatDataset):
    """A wrapper of concatenated dataset.

    Same as :obj:`torch.utils.data.dataset.ConcatDataset`, but
    concat the group flag for image aspect ratio.

    Args:
        datasets (list[:obj:`Dataset`]): A list of datasets.
        separate_eval (bool): Whether to evaluate the results
            separately if it is used as validation dataset.
            Defaults to True.
    """

    def __init__(self, datasets, separate_eval=True):
        super(ConcatDataset, self).__init__(datasets)
        self.CLASSES = datasets[0].CLASSES
        self.PALETTE = getattr(datasets[0], 'PALETTE', None)
        self.separate_eval = separate_eval
        if not separate_eval:
            if any([isinstance(ds, CocoDataset) for ds in datasets]):
                raise NotImplementedError(
                    'Evaluating concatenated CocoDataset as a whole is not'
                    ' supported! Please set "separate_eval=True"')
            elif len(set([type(ds) for ds in datasets])) != 1:
                raise NotImplementedError(
                    'All the datasets should have the same type')

        if hasattr(datasets[0], 'flag'):
            flags = []
            for i in range(0, len(datasets)):
                flags.append(datasets[i].flag)
            self.flag = np.concatenate(flags)
    def get_cat_ids(self, idx):
        """Get category ids of concatenated dataset by index.

        Args:
            idx (int): Index of data.

        Returns:
            list[int]: All categories in the image of specified index.
        """

        if idx < 0:
            if -idx > len(self):
                raise ValueError(
                    'absolute value of index should not exceed dataset length')
            idx = len(self) + idx
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        return self.datasets[dataset_idx].get_cat_ids(sample_idx)
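
    # A worked example of the index mapping above (hypothetical sizes): for
    # two datasets of lengths 100 and 150, self.cumulative_sizes is
    # [100, 250]. A global idx of 120 gives
    # bisect.bisect_right([100, 250], 120) == 1, so dataset_idx = 1 and
    # sample_idx = 120 - 100 = 20, i.e. the 21st image of the second dataset.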
    def get_ann_info(self, idx):
        """Get annotation of concatenated dataset by index.

        Args:
            idx (int): Index of data.

        Returns:
            dict: Annotation info of specified index.
        """

        if idx < 0:
            if -idx > len(self):
                raise ValueError(
                    'absolute value of index should not exceed dataset length')
            idx = len(self) + idx
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        return self.datasets[dataset_idx].get_ann_info(sample_idx)
    def evaluate(self, results, logger=None, **kwargs):
        """Evaluate the results.

        Args:
            results (list[list | tuple]): Testing results of the dataset.
            logger (logging.Logger | str | None): Logger used for printing
                related information during evaluation. Default: None.

        Returns:
            dict[str: float]: AP results of the total dataset or each separate
                dataset if `self.separate_eval=True`.
        """
        assert len(results) == self.cumulative_sizes[-1], \
            ('Dataset and results have different sizes: '
             f'{self.cumulative_sizes[-1]} vs. {len(results)}')

        # Check whether all the datasets support evaluation
        for dataset in self.datasets:
            assert hasattr(dataset, 'evaluate'), \
                f'{type(dataset)} does not implement evaluate function'

        if self.separate_eval:
            dataset_idx = -1
            total_eval_results = dict()
            for size, dataset in zip(self.cumulative_sizes, self.datasets):
                start_idx = 0 if dataset_idx == -1 else \
                    self.cumulative_sizes[dataset_idx]
                end_idx = self.cumulative_sizes[dataset_idx + 1]

                results_per_dataset = results[start_idx:end_idx]
                print_log(
                    f'\nEvaluating {dataset.ann_file} with '
                    f'{len(results_per_dataset)} images now',
                    logger=logger)

                eval_results_per_dataset = dataset.evaluate(
                    results_per_dataset, logger=logger, **kwargs)
                dataset_idx += 1
                for k, v in eval_results_per_dataset.items():
                    total_eval_results.update({f'{dataset_idx}_{k}': v})

            return total_eval_results
        elif any([isinstance(ds, CocoDataset) for ds in self.datasets]):
            raise NotImplementedError(
                'Evaluating concatenated CocoDataset as a whole is not'
                ' supported! Please set "separate_eval=True"')
        elif len(set([type(ds) for ds in self.datasets])) != 1:
            raise NotImplementedError(
                'All the datasets should have the same type')
        else:
            original_data_infos = self.datasets[0].data_infos
            self.datasets[0].data_infos = sum(
                [dataset.data_infos for dataset in self.datasets], [])
            eval_results = self.datasets[0].evaluate(
                results, logger=logger, **kwargs)
            self.datasets[0].data_infos = original_data_infos
            return eval_results

@DATASETS.register_module()
class RepeatDataset:
    """A wrapper of repeated dataset.

    The length of repeated dataset will be `times` larger than the original
    dataset. This is useful when the data loading time is long but the dataset
    is small. Using RepeatDataset can reduce the data loading time between
    epochs.

    Args:
        dataset (:obj:`Dataset`): The dataset to be repeated.
        times (int): Repeat times.
    """

    def __init__(self, dataset, times):
        self.dataset = dataset
        self.times = times
        self.CLASSES = dataset.CLASSES
        self.PALETTE = getattr(dataset, 'PALETTE', None)
        if hasattr(self.dataset, 'flag'):
            self.flag = np.tile(self.dataset.flag, times)

        self._ori_len = len(self.dataset)

    def __getitem__(self, idx):
        return self.dataset[idx % self._ori_len]
    def get_cat_ids(self, idx):
        """Get category ids of repeat dataset by index.

        Args:
            idx (int): Index of data.

        Returns:
            list[int]: All categories in the image of specified index.
        """

        return self.dataset.get_cat_ids(idx % self._ori_len)

    def get_ann_info(self, idx):
        """Get annotation of repeat dataset by index.

        Args:
            idx (int): Index of data.

        Returns:
            dict: Annotation info of specified index.
        """

        return self.dataset.get_ann_info(idx % self._ori_len)

    def __len__(self):
        """Length after repetition."""
        return self.times * self._ori_len
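
# A minimal usage sketch of RepeatDataset (hypothetical `small_ds` is any
# dataset instance):
#
#     repeated = RepeatDataset(small_ds, times=10)
#     assert len(repeated) == 10 * len(small_ds)
#     # indexing wraps around modulo the original length, so index 0 and
#     # index len(small_ds) return the same underlying sample
#     sample = repeated[len(small_ds)]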

# Modified from https://github.com/facebookresearch/detectron2/blob/41d475b75a230221e21d9cac5d69655e3415e3a4/detectron2/data/samplers/distributed_sampler.py#L57  # noqa
@DATASETS.register_module()
class ClassBalancedDataset:
    """A wrapper of repeated dataset with repeat factor.

    Suitable for training on class imbalanced datasets like LVIS. Following
    the sampling strategy in the `paper <https://arxiv.org/abs/1908.03195>`_,
    in each epoch, an image may appear multiple times based on its
    "repeat factor".

    The repeat factor for an image is a function of the frequency of the
    rarest category labeled in that image. The "frequency of category c" in
    [0, 1] is defined as the fraction of images in the training set (without
    repeats) in which category c appears.

    The dataset needs to implement :func:`self.get_cat_ids` to support
    ClassBalancedDataset.

    The repeat factor is computed as follows.

    1. For each category c, compute the fraction of images
       that contain it: :math:`f(c)`
    2. For each category c, compute the category-level repeat factor:
       :math:`r(c) = max(1, sqrt(t/f(c)))`
    3. For each image I, compute the image-level repeat factor:
       :math:`r(I) = max_{c in I} r(c)`

    Args:
        dataset (:obj:`CustomDataset`): The dataset to be repeated.
        oversample_thr (float): Frequency threshold below which data is
            repeated. For categories with ``f_c >= oversample_thr``, there is
            no oversampling. For categories with ``f_c < oversample_thr``, the
            degree of oversampling follows the square-root inverse frequency
            heuristic above.
        filter_empty_gt (bool, optional): If set True, images without bounding
            boxes will not be oversampled. Otherwise, they will be categorized
            as the pure background class and involved in the oversampling.
            Default: True.
    """

    def __init__(self, dataset, oversample_thr, filter_empty_gt=True):
        self.dataset = dataset
        self.oversample_thr = oversample_thr
        self.filter_empty_gt = filter_empty_gt
        self.CLASSES = dataset.CLASSES
        self.PALETTE = getattr(dataset, 'PALETTE', None)

        repeat_factors = self._get_repeat_factors(dataset, oversample_thr)
        repeat_indices = []
        for dataset_idx, repeat_factor in enumerate(repeat_factors):
            repeat_indices.extend([dataset_idx] * math.ceil(repeat_factor))
        self.repeat_indices = repeat_indices

        flags = []
        if hasattr(self.dataset, 'flag'):
            for flag, repeat_factor in zip(self.dataset.flag, repeat_factors):
                flags.extend([flag] * int(math.ceil(repeat_factor)))
            assert len(flags) == len(repeat_indices)
        self.flag = np.asarray(flags, dtype=np.uint8)
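
    # A worked example of the expansion above (hypothetical factors): for
    # repeat_factors = [1.0, 2.3, 1.0], math.ceil gives [1, 3, 1], so
    # repeat_indices becomes [0, 1, 1, 1, 2]: image 1 is sampled three times
    # per epoch while images 0 and 2 are sampled once each.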
    def _get_repeat_factors(self, dataset, repeat_thr):
        """Get repeat factor for each image in the dataset.

        Args:
            dataset (:obj:`CustomDataset`): The dataset.
            repeat_thr (float): The threshold of frequency. If an image
                contains categories whose frequency is below the threshold,
                it will be repeated.

        Returns:
            list[float]: The repeat factor for each image in the dataset.
        """

        # 1. For each category c, compute the fraction of images
        #    that contain it: f(c)
        category_freq = defaultdict(int)
        num_images = len(dataset)
        for idx in range(num_images):
            cat_ids = set(self.dataset.get_cat_ids(idx))
            if len(cat_ids) == 0 and not self.filter_empty_gt:
                cat_ids = set([len(self.CLASSES)])
            for cat_id in cat_ids:
                category_freq[cat_id] += 1
        for k, v in category_freq.items():
            category_freq[k] = v / num_images

        # 2. For each category c, compute the category-level repeat factor:
        #    r(c) = max(1, sqrt(t/f(c)))
        category_repeat = {
            cat_id: max(1.0, math.sqrt(repeat_thr / cat_freq))
            for cat_id, cat_freq in category_freq.items()
        }

        # 3. For each image I, compute the image-level repeat factor:
        #    r(I) = max_{c in I} r(c)
        repeat_factors = []
        for idx in range(num_images):
            cat_ids = set(self.dataset.get_cat_ids(idx))
            if len(cat_ids) == 0 and not self.filter_empty_gt:
                cat_ids = set([len(self.CLASSES)])
            repeat_factor = 1
            if len(cat_ids) > 0:
                repeat_factor = max(
                    {category_repeat[cat_id]
                     for cat_id in cat_ids})
            repeat_factors.append(repeat_factor)

        return repeat_factors
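
    # A worked example of the math above (hypothetical numbers): with
    # repeat_thr t = 0.001, a category appearing in 1 of 10,000 images has
    # f(c) = 0.0001, so r(c) = max(1, sqrt(0.001 / 0.0001)) = sqrt(10)
    # ~= 3.16, and every image containing it is repeated ceil(3.16) = 4
    # times. A category with f(c) >= 0.001 keeps r(c) = 1.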
    def __getitem__(self, idx):
        ori_index = self.repeat_indices[idx]
        return self.dataset[ori_index]

    def get_ann_info(self, idx):
        """Get annotation of dataset by index.

        Args:
            idx (int): Index of data.

        Returns:
            dict: Annotation info of specified index.
        """

        ori_index = self.repeat_indices[idx]
        return self.dataset.get_ann_info(ori_index)

    def __len__(self):
        """Length after repetition."""
        return len(self.repeat_indices)
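
# A minimal usage sketch of ClassBalancedDataset (hypothetical `lvis_ds` is
# any dataset implementing `get_cat_ids`):
#
#     balanced = ClassBalancedDataset(lvis_ds, oversample_thr=1e-3)
#     # rare-category images now appear multiple times per epoch, so the
#     # wrapper is at least as long as the wrapped dataset
#     assert len(balanced) >= len(lvis_ds)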

@DATASETS.register_module()
class MultiImageMixDataset:
    """A wrapper of multiple images mixed dataset.

    Suitable for training on multiple images mixed data augmentation like
    mosaic and mixup. For the augmentation pipeline of mixed image data,
    the `get_indexes` method needs to be provided to obtain the image
    indexes, and you can set `skip_type_keys` to change the pipeline running
    process.

    Args:
        dataset (:obj:`CustomDataset`): The dataset to be mixed.
        pipeline (Sequence[dict]): Sequence of transform objects or
            config dicts to be composed.
        dynamic_scale (tuple[int], optional): Deprecated. Passing a value
            raises an error; use a Resize pipeline instead. Default to None.
        skip_type_keys (list[str], optional): Sequence of transform type
            strings to be skipped in the pipeline. Default to None.
        max_refetch (int): The maximum number of retry iterations for getting
            valid results from the pipeline. If the number of iterations is
            greater than `max_refetch` but the results are still None, the
            iteration is terminated and an error is raised. Default: 15.
    """

    def __init__(self,
                 dataset,
                 pipeline,
                 dynamic_scale=None,
                 skip_type_keys=None,
                 max_refetch=15):
        if dynamic_scale is not None:
            raise RuntimeError(
                'dynamic_scale is deprecated. Please use Resize pipeline '
                'to achieve similar functions')
        assert isinstance(pipeline, collections.abc.Sequence)
        if skip_type_keys is not None:
            assert all([
                isinstance(skip_type_key, str)
                for skip_type_key in skip_type_keys
            ])
        self._skip_type_keys = skip_type_keys

        self.pipeline = []
        self.pipeline_types = []
        for transform in pipeline:
            if isinstance(transform, dict):
                self.pipeline_types.append(transform['type'])
                transform = build_from_cfg(transform, PIPELINES)
                self.pipeline.append(transform)
            else:
                raise TypeError('each transform in pipeline must be a dict')

        self.dataset = dataset
        self.CLASSES = dataset.CLASSES
        self.PALETTE = getattr(dataset, 'PALETTE', None)
        if hasattr(self.dataset, 'flag'):
            self.flag = dataset.flag
        self.num_samples = len(dataset)
        self.max_refetch = max_refetch
    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        results = copy.deepcopy(self.dataset[idx])
        for (transform, transform_type) in zip(self.pipeline,
                                               self.pipeline_types):
            if self._skip_type_keys is not None and \
                    transform_type in self._skip_type_keys:
                continue

            if hasattr(transform, 'get_indexes'):
                for i in range(self.max_refetch):
                    # Make sure the results passed through the loading
                    # pipeline of the original dataset are not None.
                    indexes = transform.get_indexes(self.dataset)
                    if not isinstance(indexes, collections.abc.Sequence):
                        indexes = [indexes]
                    mix_results = [
                        copy.deepcopy(self.dataset[index]) for index in indexes
                    ]
                    if None not in mix_results:
                        results['mix_results'] = mix_results
                        break
                else:
                    raise RuntimeError(
                        'The loading pipeline of the original dataset'
                        ' always returns None. Please check the correctness '
                        'of the dataset and its pipeline.')

            for i in range(self.max_refetch):
                # To confirm the results passed through the training
                # pipeline of the wrapper are not None.
                updated_results = transform(copy.deepcopy(results))
                if updated_results is not None:
                    results = updated_results
                    break
            else:
                raise RuntimeError(
                    'The training pipeline of the dataset wrapper'
                    ' always returns None. Please check the correctness '
                    'of the dataset and its pipeline.')

            if 'mix_results' in results:
                results.pop('mix_results')

        return results
    def update_skip_type_keys(self, skip_type_keys):
        """Update skip_type_keys. It is called by an external hook.

        Args:
            skip_type_keys (list[str], optional): Sequence of transform type
                strings to be skipped in the pipeline.
        """
        assert all([
            isinstance(skip_type_key, str) for skip_type_key in skip_type_keys
        ])
        self._skip_type_keys = skip_type_keys
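
# A minimal usage sketch of MultiImageMixDataset (hypothetical config;
# assumes a Mosaic-style transform registered in PIPELINES that implements
# `get_indexes`):
#
#     train_pipeline = [
#         dict(type='Mosaic', img_scale=(640, 640)),
#         dict(type='RandomFlip', flip_ratio=0.5),
#     ]
#     mixed = MultiImageMixDataset(dataset=base_ds, pipeline=train_pipeline)
#     sample = mixed[0]  # Mosaic fetches extra images via get_indexes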