| | |
| | import copy |
| | import os.path as osp |
| | from typing import Callable, Dict, List, Optional, Sequence, Union |
| |
|
| | import mmengine |
| | import mmengine.fileio as fileio |
| | import numpy as np |
| | from mmengine.dataset import BaseDataset, Compose |
| |
|
| | from mmdet.registry import DATASETS |
| |
|
| |
|
@DATASETS.register_module()
class BaseSegDataset(BaseDataset):
    """Custom dataset for semantic segmentation.

    An example of the expected file structure is as follows.

    .. code-block:: none

        ├── data
        │   ├── my_dataset
        │   │   ├── img_dir
        │   │   │   ├── train
        │   │   │   │   ├── xxx{img_suffix}
        │   │   │   │   ├── yyy{img_suffix}
        │   │   │   │   ├── zzz{img_suffix}
        │   │   │   ├── val
        │   │   ├── ann_dir
        │   │   │   ├── train
        │   │   │   │   ├── xxx{seg_map_suffix}
        │   │   │   │   ├── yyy{seg_map_suffix}
        │   │   │   │   ├── zzz{seg_map_suffix}
        │   │   │   ├── val

    The img/gt_semantic_seg pair of BaseSegDataset should have the same name
    except for the suffix. A valid img/gt_semantic_seg filename pair should be
    like ``xxx{img_suffix}`` and ``xxx{seg_map_suffix}`` (the extension is
    included in the suffix). If a split file is given, then ``xxx`` is
    specified in that txt file. Otherwise, all files in ``img_dir/`` and
    ``ann_dir`` will be loaded.
    Please refer to ``docs/en/tutorials/new_dataset.md`` for more details.

    Args:
        ann_file (str): Annotation file path. Defaults to ''.
        img_suffix (str): Suffix of images. Defaults to '.jpg'.
        seg_map_suffix (str): Suffix of segmentation maps.
            Defaults to '.png'.
        metainfo (dict, optional): Meta information for dataset, such as
            specify classes to load. Defaults to None.
        data_root (str, optional): The root directory for ``data_prefix`` and
            ``ann_file``. Defaults to None.
        data_prefix (dict, optional): Prefix for training data. Defaults to
            dict(img_path='', seg_map_path='').
        filter_cfg (dict, optional): Config for filter data. Defaults to None.
        indices (int or Sequence[int], optional): Support using first few
            data in annotation file to facilitate training/testing on a
            smaller dataset. Defaults to None which means using all
            ``data_infos``.
        serialize_data (bool, optional): Whether to hold memory using
            serialized objects, when enabled, data loader workers can use
            shared RAM from master process instead of making a copy. Defaults
            to True.
        pipeline (list, optional): Processing pipeline. Defaults to [].
        test_mode (bool, optional): ``test_mode=True`` means in test phase.
            Defaults to False.
        lazy_init (bool, optional): Whether to load annotation during
            instantiation. In some cases, such as visualization, only the
            meta information of the dataset is needed, which is not necessary
            to load annotation file. ``BaseDataset`` can skip load
            annotations to save time by set ``lazy_init=True``.
            Defaults to False.
        use_label_map (bool, optional): Whether to use label map.
            Defaults to False.
        max_refetch (int, optional): If ``BaseDataset.prepare_data`` get a
            None img. The maximum extra number of cycles to get a valid
            image. Defaults to 1000.
        backend_args (dict, optional): Arguments to instantiate a file
            backend. See
            https://mmengine.readthedocs.io/en/latest/api/fileio.html
            for details. Defaults to None.
            Notes: mmcv>=2.0.0rc4 required.
    """
    METAINFO: dict = dict()

    def __init__(self,
                 ann_file: str = '',
                 img_suffix: str = '.jpg',
                 seg_map_suffix: str = '.png',
                 metainfo: Optional[dict] = None,
                 data_root: Optional[str] = None,
                 data_prefix: dict = dict(img_path='', seg_map_path=''),
                 filter_cfg: Optional[dict] = None,
                 indices: Optional[Union[int, Sequence[int]]] = None,
                 serialize_data: bool = True,
                 pipeline: List[Union[dict, Callable]] = [],
                 test_mode: bool = False,
                 lazy_init: bool = False,
                 use_label_map: bool = False,
                 max_refetch: int = 1000,
                 backend_args: Optional[dict] = None) -> None:

        self.img_suffix = img_suffix
        self.seg_map_suffix = seg_map_suffix
        # Copy so later changes to the caller's dict do not leak in here.
        self.backend_args = backend_args.copy() if backend_args else None

        self.data_root = data_root
        # The default ``data_prefix`` dict is shared between calls, so it is
        # copied before ``_join_prefix`` gets a chance to touch it.
        self.data_prefix = copy.copy(data_prefix)
        self.ann_file = ann_file
        self.filter_cfg = copy.deepcopy(filter_cfg)
        self._indices = indices
        self.serialize_data = serialize_data
        self.test_mode = test_mode
        self.max_refetch = max_refetch
        self.data_list: List[dict] = []
        self.data_bytes: np.ndarray

        # Set meta information.
        self._metainfo = self._load_metainfo(copy.deepcopy(metainfo))

        # Get the label map for custom classes (only when requested).
        new_classes = self._metainfo.get('classes', None)
        self.label_map = self.get_label_map(
            new_classes) if use_label_map else None
        self._metainfo.update(dict(label_map=self.label_map))

        # Update the palette based on the label map, or generate a palette
        # if none is defined.
        updated_palette = self._update_palette()
        self._metainfo.update(dict(palette=updated_palette))

        # Join paths.
        if self.data_root is not None:
            self._join_prefix()

        # Build pipeline.
        self.pipeline = Compose(pipeline)
        # Fully initialize the dataset unless the caller asked to defer it.
        if not lazy_init:
            self.full_init()

        if test_mode:
            assert self._metainfo.get('classes') is not None, \
                'dataset metainfo `classes` should be specified when testing'

    @classmethod
    def get_label_map(cls,
                      new_classes: Optional[Sequence] = None
                      ) -> Union[Dict, None]:
        """Require label mapping.

        The ``label_map`` is a dictionary, its keys are the old label ids and
        its values are the new label ids, and is used for changing pixel
        labels in load_annotations. ``label_map`` is not None if and only if
        the old classes in ``cls.METAINFO`` differ from the new classes in
        ``self._metainfo`` and neither of them is None.

        Old classes that are absent from ``new_classes`` are mapped to 0.

        Args:
            new_classes (list, tuple, optional): The new classes name from
                metainfo. Default to None.

        Returns:
            dict, optional: The mapping from old classes in cls.METAINFO to
            new classes in self._metainfo

        Raises:
            ValueError: If ``new_classes`` is not a subset of the classes
                declared in ``cls.METAINFO``.
        """
        old_classes = cls.METAINFO.get('classes', None)
        if (new_classes is None or old_classes is None
                or list(new_classes) == list(old_classes)):
            return None

        if not set(new_classes).issubset(old_classes):
            raise ValueError(
                f'new classes {new_classes} is not a '
                f'subset of classes {old_classes} in METAINFO.')

        # First position of every new class name (same result as
        # ``new_classes.index(name)``, but computed once).
        new_index: Dict = {}
        for idx, name in enumerate(new_classes):
            new_index.setdefault(name, idx)

        # Removed classes fall back to 0, which is treated as background.
        return {
            old_id: new_index.get(name, 0)
            for old_id, name in enumerate(old_classes)
        }

    def _update_palette(self) -> list:
        """Update palette after loading metainfo.

        If length of palette is equal to classes, just return the palette.
        If palette is not defined, it will randomly generate a palette.
        If classes is updated by customer, it will return the subset of
        palette.

        Returns:
            Sequence: Palette for current dataset.

        Raises:
            ValueError: If the palette is non-empty but shorter than the
                class list.
        """
        palette = self._metainfo.get('palette', [])
        classes = self._metainfo.get('classes', [])
        # palette does match classes
        if len(palette) == len(classes):
            return palette

        if len(palette) == 0:
            # Get a random color palette with a fixed seed so that the
            # palette is the same on every run, while leaving the global
            # NumPy random state untouched for the rest of the program.
            state = np.random.get_state()
            try:
                np.random.seed(42)
                new_palette = np.random.randint(
                    0, 255, size=(len(classes), 3)).tolist()
            finally:
                np.random.set_state(state)
        elif len(palette) >= len(classes) and self.label_map is not None:
            # Return the subset of the palette that belongs to kept classes,
            # ordered by new label id.
            #
            # ``get_label_map`` assigns 0 both to removed classes and to the
            # class kept at new index 0. Filtering on ``new_id != 0`` alone
            # would therefore also drop the colour of the first kept class
            # and leave the palette one entry short; the old class name is
            # checked to tell the two cases apart.
            old_classes = self.METAINFO.get('classes') or ()
            kept_names = set(classes)
            new_palette = []
            for old_id, new_id in sorted(
                    self.label_map.items(), key=lambda x: x[1]):
                is_kept_first_class = (
                    new_id == 0 and old_id < len(old_classes)
                    and old_classes[old_id] in kept_names)
                if new_id != 0 or is_kept_first_class:
                    new_palette.append(palette[old_id])
            new_palette = type(palette)(new_palette)
        elif len(palette) >= len(classes):
            # Allow palette length to be greater than classes.
            return palette
        else:
            raise ValueError('palette does not match classes '
                             f'as metainfo is {self._metainfo}.')
        return new_palette

    def load_data_list(self) -> List[dict]:
        """Load annotation from directory or annotation file.

        When ``ann_file`` is a non-empty path that is not a directory, each
        non-blank line of it is taken as an image name without suffix and the
        entries keep the order of the file. Otherwise ``img_path`` from
        ``data_prefix`` is scanned recursively for files ending with
        ``img_suffix`` and the entries are sorted by ``img_path``.

        Returns:
            list[dict]: All data info of dataset.
        """
        data_list = []
        img_dir = self.data_prefix.get('img_path', None)
        ann_dir = self.data_prefix.get('seg_map_path', None)
        # Truthiness is tested first so an empty/None ``ann_file`` never
        # reaches ``osp.isdir``.
        if self.ann_file and not osp.isdir(self.ann_file):
            assert osp.isfile(self.ann_file), \
                f'Failed to load `ann_file` {self.ann_file}'
            lines = mmengine.list_from_file(
                self.ann_file, backend_args=self.backend_args)
            for line in lines:
                img_name = line.strip()
                if not img_name:
                    # A blank line would otherwise yield a bogus entry such
                    # as ``img_dir/.jpg``.
                    continue
                data_info = dict(
                    img_path=osp.join(img_dir, img_name + self.img_suffix))
                if ann_dir is not None:
                    seg_map = img_name + self.seg_map_suffix
                    data_info['seg_map_path'] = osp.join(ann_dir, seg_map)
                data_info['label_map'] = self.label_map
                data_list.append(data_info)
        else:
            suffix_len = len(self.img_suffix)
            for img in fileio.list_dir_or_file(
                    dir_path=img_dir,
                    list_dir=False,
                    suffix=self.img_suffix,
                    recursive=True,
                    backend_args=self.backend_args):
                data_info = dict(img_path=osp.join(img_dir, img))
                if ann_dir is not None:
                    # Swap only the trailing suffix. ``str.replace`` would
                    # also rewrite the suffix where it occurs earlier in the
                    # relative path (e.g. ``a.jpg_dir/b.jpg``).
                    stem = img[:len(img) - suffix_len]
                    seg_map = stem + self.seg_map_suffix
                    data_info['seg_map_path'] = osp.join(ann_dir, seg_map)
                data_info['label_map'] = self.label_map
                data_list.append(data_info)
            data_list = sorted(data_list, key=lambda x: x['img_path'])
        return data_list
| |
|