| | import os |
| | import os.path as osp |
| | from functools import reduce |
| |
|
| | import mmcv |
| | import numpy as np |
| | from mmcv.utils import print_log |
| | from terminaltables import AsciiTable |
| | from torch.utils.data import Dataset |
| |
|
| | from mmseg.core import eval_metrics |
| | from mmseg.utils import get_root_logger |
| | from .builder import DATASETS |
| | from .pipelines import Compose |
| |
|
| |
|
@DATASETS.register_module()
class CustomDataset(Dataset):
    """Custom dataset for semantic segmentation.

    An example of the expected file structure is as follows.

    .. code-block:: none

        ├── data
        │   ├── my_dataset
        │   │   ├── img_dir
        │   │   │   ├── train
        │   │   │   │   ├── xxx{img_suffix}
        │   │   │   │   ├── yyy{img_suffix}
        │   │   │   │   ├── zzz{img_suffix}
        │   │   │   ├── val
        │   │   ├── ann_dir
        │   │   │   ├── train
        │   │   │   │   ├── xxx{seg_map_suffix}
        │   │   │   │   ├── yyy{seg_map_suffix}
        │   │   │   │   ├── zzz{seg_map_suffix}
        │   │   │   ├── val

    The img/gt_semantic_seg pair of CustomDataset should share the same name
    except for the suffix. A valid img/gt_semantic_seg filename pair should be
    like ``xxx{img_suffix}`` and ``xxx{seg_map_suffix}`` (the extension is
    also included in the suffix). If split is given, then ``xxx`` is specified
    in the txt file. Otherwise, all files in ``img_dir/`` and ``ann_dir`` will
    be loaded.
    Please refer to ``docs/tutorials/new_dataset.md`` for more details.

    Args:
        pipeline (list[dict]): Processing pipeline
        img_dir (str): Path to image directory
        img_suffix (str): Suffix of images. Default: '.jpg'
        ann_dir (str, optional): Path to annotation directory. Default: None
        seg_map_suffix (str): Suffix of segmentation maps. Default: '.png'
        split (str, optional): Split txt file. If split is specified, only
            file with suffix in the splits will be loaded. Otherwise, all
            images in img_dir/ann_dir will be loaded. Default: None
        data_root (str, optional): Data root for img_dir/ann_dir. Default:
            None.
        test_mode (bool): If test_mode=True, gt wouldn't be loaded.
        ignore_index (int): The label index to be ignored. Default: 255
        reduce_zero_label (bool): Whether to mark label zero as ignored.
            Default: False
        classes (str | Sequence[str], optional): Specify classes to load.
            If is None, ``cls.CLASSES`` will be used. Default: None.
        palette (Sequence[Sequence[int]] | np.ndarray | None):
            The palette of segmentation map. If None is given, and
            self.PALETTE is None, random palette will be generated.
            Default: None
    """

    # Default label set: the integers 0..103, one entry per class.
    CLASSES = tuple(range(104))

    # One RGB triple per entry of CLASSES (104 rows). After the first row the
    # colours repeat every 24 entries, so several classes share a colour.
    PALETTE = [[0, 0, 0], [40, 100, 150], [80, 150, 200], [120, 200, 10], [160, 10, 60],
               [200, 60, 110], [0, 110, 160], [40, 160, 210], [80, 210, 20], [120, 20, 70],
               [160, 70, 120], [200, 120, 170], [0, 170, 220], [40, 220, 30], [80, 30, 80],
               [120, 80, 130], [160, 130, 180], [200, 180, 230], [0, 230, 40], [40, 40, 90],
               [80, 90, 140], [120, 140, 190], [160, 190, 0], [200, 0, 50], [0, 50, 100],
               [40, 100, 150], [80, 150, 200], [120, 200, 10], [160, 10, 60], [200, 60, 110],
               [0, 110, 160], [40, 160, 210], [80, 210, 20], [120, 20, 70], [160, 70, 120],
               [200, 120, 170], [0, 170, 220], [40, 220, 30], [80, 30, 80], [120, 80, 130],
               [160, 130, 180], [200, 180, 230], [0, 230, 40], [40, 40, 90], [80, 90, 140],
               [120, 140, 190], [160, 190, 0], [200, 0, 50], [0, 50, 100], [40, 100, 150],
               [80, 150, 200], [120, 200, 10], [160, 10, 60], [200, 60, 110], [0, 110, 160],
               [40, 160, 210], [80, 210, 20], [120, 20, 70], [160, 70, 120], [200, 120, 170],
               [0, 170, 220], [40, 220, 30], [80, 30, 80], [120, 80, 130], [160, 130, 180],
               [200, 180, 230], [0, 230, 40], [40, 40, 90], [80, 90, 140], [120, 140, 190],
               [160, 190, 0], [200, 0, 50], [0, 50, 100], [40, 100, 150], [80, 150, 200],
               [120, 200, 10], [160, 10, 60], [200, 60, 110], [0, 110, 160], [40, 160, 210],
               [80, 210, 20], [120, 20, 70], [160, 70, 120], [200, 120, 170], [0, 170, 220],
               [40, 220, 30], [80, 30, 80], [120, 80, 130], [160, 130, 180], [200, 180, 230],
               [0, 230, 40], [40, 40, 90], [80, 90, 140], [120, 140, 190], [160, 190, 0],
               [200, 0, 50], [0, 50, 100], [40, 100, 150], [80, 150, 200], [120, 200, 10],
               [160, 10, 60], [200, 60, 110], [0, 110, 160], [40, 160, 210]]

    def __init__(self,
                 pipeline,
                 img_dir,
                 img_suffix='.jpg',
                 ann_dir=None,
                 seg_map_suffix='.png',
                 split=None,
                 data_root=None,
                 test_mode=False,
                 ignore_index=255,
                 reduce_zero_label=False,
                 classes=None,
                 palette=None):
        self.pipeline = Compose(pipeline)
        self.img_dir = img_dir
        self.img_suffix = img_suffix
        self.ann_dir = ann_dir
        self.seg_map_suffix = seg_map_suffix
        self.split = split
        self.data_root = data_root
        self.test_mode = test_mode
        self.ignore_index = ignore_index
        self.reduce_zero_label = reduce_zero_label
        # Must be initialised before get_classes_and_palette(), which may
        # replace it with an old-id -> new-id mapping.
        self.label_map = None
        self.CLASSES, self.PALETTE = self.get_classes_and_palette(
            classes, palette)

        # Resolve relative paths against data_root; absolute paths are kept.
        if self.data_root is not None:
            if not osp.isabs(self.img_dir):
                self.img_dir = osp.join(self.data_root, self.img_dir)
            if not (self.ann_dir is None or osp.isabs(self.ann_dir)):
                self.ann_dir = osp.join(self.data_root, self.ann_dir)
            if not (self.split is None or osp.isabs(self.split)):
                self.split = osp.join(self.data_root, self.split)

        # Build the list of per-image info dicts.
        self.img_infos = self.load_annotations(self.img_dir, self.img_suffix,
                                               self.ann_dir,
                                               self.seg_map_suffix, self.split)

    def __len__(self):
        """Total number of samples of data."""
        return len(self.img_infos)

    def load_annotations(self, img_dir, img_suffix, ann_dir, seg_map_suffix,
                         split):
        """Load annotation from directory.

        Args:
            img_dir (str): Path to image directory
            img_suffix (str): Suffix of images.
            ann_dir (str|None): Path to annotation directory.
            seg_map_suffix (str|None): Suffix of segmentation maps.
            split (str|None): Split txt file. If split is specified, only file
                with suffix in the splits will be loaded. Otherwise, all images
                in img_dir/ann_dir will be loaded. Blank lines in the split
                file are skipped.

        Returns:
            list[dict]: All image info of dataset. Each dict has a
                ``filename`` key and, when ``ann_dir`` is given, an ``ann``
                key holding ``dict(seg_map=...)``.
        """
        img_infos = []
        if split is not None:
            with open(split) as f:
                for line in f:
                    img_name = line.strip()
                    if not img_name:
                        # A blank (e.g. trailing) line would otherwise yield
                        # an entry whose filename is just the suffix.
                        continue
                    img_info = dict(filename=img_name + img_suffix)
                    if ann_dir is not None:
                        seg_map = img_name + seg_map_suffix
                        img_info['ann'] = dict(seg_map=seg_map)
                    img_infos.append(img_info)
        else:
            for img in mmcv.scandir(img_dir, img_suffix, recursive=True):
                img_info = dict(filename=img)
                if ann_dir is not None:
                    # Swap only the trailing suffix. str.replace would also
                    # rewrite any earlier occurrence of the suffix inside the
                    # relative path (e.g. a directory named 'a.jpg').
                    if img.endswith(img_suffix):
                        stem = img[:len(img) - len(img_suffix)]
                        seg_map = stem + seg_map_suffix
                    else:
                        seg_map = img.replace(img_suffix, seg_map_suffix)
                    img_info['ann'] = dict(seg_map=seg_map)
                img_infos.append(img_info)

        print_log(f'Loaded {len(img_infos)} images', logger=get_root_logger())
        return img_infos

    def get_ann_info(self, idx):
        """Get annotation by index.

        Args:
            idx (int): Index of data.

        Returns:
            dict: Annotation info of specified index.
        """
        return self.img_infos[idx]['ann']

    def pre_pipeline(self, results):
        """Prepare results dict for pipeline.

        Adds ``seg_fields``, ``img_prefix`` and ``seg_prefix`` to ``results``
        in place, plus ``label_map`` when custom classes are in use.
        """
        results['seg_fields'] = []
        results['img_prefix'] = self.img_dir
        results['seg_prefix'] = self.ann_dir
        if self.custom_classes:
            results['label_map'] = self.label_map

    def __getitem__(self, idx):
        """Get training/test data after pipeline.

        Args:
            idx (int): Index of data.

        Returns:
            dict: Training/test data (with annotation if `test_mode` is set
                False).
        """
        if self.test_mode:
            return self.prepare_test_img(idx)
        return self.prepare_train_img(idx)

    def prepare_train_img(self, idx):
        """Get training data and annotations after pipeline.

        Args:
            idx (int): Index of data.

        Returns:
            dict: Training data and annotation after pipeline with new keys
                introduced by pipeline.
        """
        img_info = self.img_infos[idx]
        ann_info = self.get_ann_info(idx)
        results = dict(img_info=img_info, ann_info=ann_info)
        self.pre_pipeline(results)
        return self.pipeline(results)

    def prepare_test_img(self, idx):
        """Get testing data after pipeline.

        Args:
            idx (int): Index of data.

        Returns:
            dict: Testing data after pipeline with new keys introduced by
                pipeline.
        """
        img_info = self.img_infos[idx]
        results = dict(img_info=img_info)
        self.pre_pipeline(results)
        return self.pipeline(results)

    def format_results(self, results, **kwargs):
        """Place holder to format result to dataset specific output."""

    def get_gt_seg_maps(self, efficient_test=False):
        """Get ground truth segmentation maps for evaluation.

        Args:
            efficient_test (bool): If True, return the file path of every
                segmentation map instead of reading it. Default: False.

        Returns:
            list: One entry per image, either a path (str) or the map read
                by ``mmcv.imread``.
        """
        gt_seg_maps = []
        for img_info in self.img_infos:
            seg_map = osp.join(self.ann_dir, img_info['ann']['seg_map'])
            if efficient_test:
                gt_seg_map = seg_map
            else:
                gt_seg_map = mmcv.imread(
                    seg_map, flag='unchanged', backend='pillow')
            gt_seg_maps.append(gt_seg_map)
        return gt_seg_maps

    def get_classes_and_palette(self, classes=None, palette=None):
        """Get class names and palette of current dataset.

        Also sets ``self.custom_classes`` and, when custom classes are given
        and ``self.CLASSES`` is non-empty, ``self.label_map`` (index in
        ``self.CLASSES`` -> index in the custom classes, or -1 if dropped).

        Args:
            classes (Sequence[str] | str | None): If classes is None, use
                default CLASSES defined by builtin dataset. If classes is a
                string, take it as a file name. The file contains the name of
                classes where each line contains one class name. If classes is
                a tuple or list, override the CLASSES defined by the dataset.
            palette (Sequence[Sequence[int]] | np.ndarray | None):
                The palette of segmentation map. If None is given, random
                palette will be generated. Default: None

        Returns:
            tuple: ``(class_names, palette)``.

        Raises:
            ValueError: If ``classes`` has an unsupported type or is not a
                subset of ``self.CLASSES``.
        """
        if classes is None:
            self.custom_classes = False
            return self.CLASSES, self.PALETTE

        self.custom_classes = True
        if isinstance(classes, str):
            # take it as a file path
            class_names = mmcv.list_from_file(classes)
        elif isinstance(classes, (tuple, list)):
            class_names = classes
        else:
            raise ValueError(f'Unsupported type {type(classes)} of classes.')

        if self.CLASSES:
            # Validate and index against the resolved names, not the raw
            # argument: when `classes` is a file path the raw argument is a
            # str, whose elements are characters.
            if not set(class_names).issubset(self.CLASSES):
                raise ValueError('classes is not a subset of CLASSES.')

            # Map every original label id to its position in the custom
            # class list; -1 marks classes that are not kept.
            self.label_map = {
                old_id: class_names.index(name) if name in class_names else -1
                for old_id, name in enumerate(self.CLASSES)
            }

        palette = self.get_palette_for_custom_classes(class_names, palette)

        return class_names, palette

    def get_palette_for_custom_classes(self, class_names, palette=None):
        """Pick the palette matching a custom class list.

        Args:
            class_names (Sequence): The custom class names.
            palette (Sequence[Sequence[int]] | np.ndarray | None): Palette
                supplied by the caller. Default: None.

        Returns:
            The rows of ``self.PALETTE`` for the kept classes, ordered by new
            label id, when both ``self.label_map`` and ``self.PALETTE`` are
            set. Otherwise ``palette`` if given, else ``self.PALETTE``, else
            a random palette with one row per class name.
        """
        if self.label_map is not None and self.PALETTE is not None:
            # Subset of the default palette, ordered by new label id.
            subset = [
                self.PALETTE[old_id]
                for old_id, new_id in sorted(
                    self.label_map.items(), key=lambda item: item[1])
                if new_id != -1
            ]
            if isinstance(self.PALETTE, np.ndarray):
                # np.ndarray(...) treats its argument as a shape, so build
                # the array from the rows explicitly.
                return np.asarray(subset)
            return type(self.PALETTE)(subset)

        if palette is None:
            if self.PALETTE is None:
                palette = np.random.randint(0, 255, size=(len(class_names), 3))
            else:
                palette = self.PALETTE

        return palette

    def evaluate(self,
                 results,
                 metric='mIoU',
                 logger=None,
                 efficient_test=False,
                 **kwargs):
        """Evaluate the dataset.

        Args:
            results (list): Testing results of the dataset. If it is a list
                of str, every entry is removed with ``os.remove`` once the
                metrics have been computed.
            metric (str | list[str]): Metrics to be evaluated. 'mIoU' and
                'mDice' are supported.
            logger (logging.Logger | None | str): Logger used for printing
                related information during evaluation. Default: None.
            efficient_test (bool): Forwarded to ``get_gt_seg_maps``; if True
                the ground truth is handed over as file paths.
                Default: False.

        Returns:
            dict[str, float]: Summary metrics keyed by the summary table
                column names, as fractions (table percentage / 100).

        Raises:
            KeyError: If an unsupported metric is requested.
        """
        if isinstance(metric, str):
            metric = [metric]
        allowed_metrics = ['mIoU', 'mDice']
        if not set(metric).issubset(set(allowed_metrics)):
            raise KeyError(f'metric {metric} is not supported')

        eval_results = {}
        gt_seg_maps = self.get_gt_seg_maps(efficient_test)
        if self.CLASSES is None:
            # NOTE(review): with efficient_test=True the entries here are
            # file paths, so this count looks unreliable in that mode -
            # confirm before relying on it.
            num_classes = len(
                reduce(np.union1d, [np.unique(_) for _ in gt_seg_maps]))
        else:
            num_classes = len(self.CLASSES)

        # Indexed below as: [0] -> 'aAcc', [1] -> per-class 'Acc',
        # [2:] -> one entry per requested metric.
        ret_metrics = eval_metrics(
            results,
            gt_seg_maps,
            num_classes,
            self.ignore_index,
            metric,
            label_map=self.label_map,
            reduce_zero_label=self.reduce_zero_label)

        # Per-class table; column names drop the leading 'm' of the metric.
        class_table_data = [['Class'] + [m[1:] for m in metric] + ['Acc']]
        if self.CLASSES is None:
            class_names = tuple(range(num_classes))
        else:
            class_names = self.CLASSES
        ret_metrics_round = [
            np.round(ret_metric * 100, 2) for ret_metric in ret_metrics
        ]
        for i in range(num_classes):
            class_table_data.append([class_names[i]] +
                                    [m[i] for m in ret_metrics_round[2:]] +
                                    [ret_metrics_round[1][i]])

        # Summary table: NaN-ignoring mean of every metric, in percent.
        summary_header = (['Scope'] +
                          ['m' + head for head in class_table_data[0][1:]] +
                          ['aAcc'])
        ret_metrics_mean = [
            np.round(np.nanmean(ret_metric) * 100, 2)
            for ret_metric in ret_metrics
        ]
        summary_row = (['global'] + ret_metrics_mean[2:] +
                       [ret_metrics_mean[1]] + [ret_metrics_mean[0]])
        summary_table_data = [summary_header, summary_row]

        print_log('per class results:', logger)
        table = AsciiTable(class_table_data)
        print_log('\n' + table.table, logger=logger)
        print_log('Summary:', logger)
        table = AsciiTable(summary_table_data)
        print_log('\n' + table.table, logger=logger)

        # Skip the 'Scope'/'global' column and convert back to fractions.
        for key, value in zip(summary_header[1:], summary_row[1:]):
            eval_results[key] = value / 100.0

        # String results are file paths; delete them after use.
        if mmcv.is_list_of(results, str):
            for file_name in results:
                os.remove(file_name)
        return eval_results
| |
|