| | |
| | import os |
| | from typing import Dict, List, Optional, Sequence, Union |
| |
|
| | from mmseg.registry import DATASETS |
| | from .basesegdataset import BaseSegDataset |
| |
|
| | try: |
| | from dsdl.dataset import DSDLDataset |
| | except ImportError: |
| | DSDLDataset = None |
| |
|
| |
|
@DATASETS.register_module()
class DSDLSegDataset(BaseSegDataset):
    """Dataset for DSDL segmentation.

    Args:
        specific_key_path (dict, optional): Path of specific keys which can
            not be loaded by their field name. Defaults to None, which is
            treated as an empty dict.
        pre_transform (dict, optional): Pre-transform functions applied
            before loading. Defaults to None, which is treated as an empty
            dict.
        used_labels (sequence, optional): Class names actually used in
            training steps. They must be a subset of ``'background'`` plus
            the DSDL class domain. Defaults to None (use all classes).
        **kwargs: Forwarded to :class:`BaseSegDataset`. ``ann_file`` is
            required and names the DSDL yaml file.
    """

    METAINFO = {}

    def __init__(self,
                 specific_key_path: Optional[Dict] = None,
                 pre_transform: Optional[Dict] = None,
                 used_labels: Optional[Sequence] = None,
                 **kwargs) -> None:
        if DSDLDataset is None:
            raise RuntimeError(
                'Package dsdl is not installed. Please run "pip install dsdl".'
            )
        self.used_labels = used_labels

        # Fresh dicts per instance: ``{}`` defaults would be shared by every
        # call, so any mutation made downstream would leak into later
        # instances.
        if specific_key_path is None:
            specific_key_path = {}
        if pre_transform is None:
            pre_transform = {}

        # Build the yaml path locally instead of rewriting kwargs['ann_file'].
        # BaseSegDataset (mmengine BaseDataset._join_prefix) already prefixes
        # a relative ann_file with data_root, so rewriting kwargs here made
        # ``self.ann_file`` contain data_root twice.
        dsdl_yaml = kwargs['ann_file']
        if kwargs.get('data_root'):
            dsdl_yaml = os.path.join(kwargs['data_root'], dsdl_yaml)

        self.dsdldataset = DSDLDataset(
            dsdl_yaml=dsdl_yaml,
            location_config=dict(type='LocalFileReader', working_dir=''),
            required_fields=['Image', 'LabelMap'],
            specific_key_path=specific_key_path,
            transform=pre_transform,
        )
        # Must run after ``self.dsdldataset`` exists: get_label_map() and
        # load_data_list() below both read it.
        BaseSegDataset.__init__(self, **kwargs)

    def _domain_classes(self) -> List[str]:
        """Return ``'background'`` followed by the DSDL class names."""
        return ['background'] + list(self.dsdldataset.class_names)

    def load_data_list(self) -> List[Dict]:
        """Build the data list from the parsed DSDL dataset.

        Also sets ``self._metainfo['classes']``. When ``used_labels`` is
        given, the classes become ``used_labels`` and ``self.label_map`` is
        recomputed to match. Otherwise the classes become ``'background'``
        followed by the full DSDL class domain.

        Returns:
            List[dict]: One dict per sample with ``img_path``,
            ``seg_map_path``, ``label_map``, ``reduce_zero_label`` and
            ``seg_fields`` keys.
        """
        if self.used_labels:
            self._metainfo['classes'] = tuple(self.used_labels)
            self.label_map = self.get_label_map(self.used_labels)
        else:
            self._metainfo['classes'] = tuple(self._domain_classes())

        img_prefix = self.data_prefix['img_path']
        seg_prefix = self.data_prefix['seg_map_path']
        return [
            dict(
                img_path=os.path.join(img_prefix, data['Image'][0].location),
                seg_map_path=os.path.join(seg_prefix,
                                          data['LabelMap'][0].location),
                label_map=self.label_map,
                reduce_zero_label=self.reduce_zero_label,
                seg_fields=[],
            ) for data in self.dsdldataset
        ]

    def get_label_map(self,
                      new_classes: Optional[Sequence] = None
                      ) -> Union[Dict, None]:
        """Build the mapping from old class ids to new class ids.

        The old classes are ``'background'`` followed by the DSDL class
        domain. A mapping is returned if and only if ``new_classes`` is not
        None and differs from the old classes. The mapping is stored in the
        per-sample info dicts and used to remap pixel labels.

        Args:
            new_classes (sequence, optional): The new class names, e.g. from
                metainfo or ``used_labels``. Defaults to None.

        Returns:
            dict, optional: ``{old_id: new_id}``. Old classes absent from
            ``new_classes`` map to 255. Returns None when no remapping is
            needed.

        Raises:
            ValueError: If ``new_classes`` is not a subset of the old
                classes.
        """
        old_classes = self._domain_classes()
        if new_classes is None or list(new_classes) == old_classes:
            return None
        if not set(new_classes).issubset(old_classes):
            raise ValueError(
                f'new classes {new_classes} is not a '
                f'subset of classes {old_classes} in class_dom.')

        # First occurrence wins, matching the ``new_classes.index(c)``
        # semantics, but with O(1) lookups instead of a scan per class.
        new_ids = {}
        for new_id, name in enumerate(new_classes):
            new_ids.setdefault(name, new_id)
        return {
            old_id: new_ids.get(name, 255)
            for old_id, name in enumerate(old_classes)
        }
| |
|