| import os |
| import os.path as osp |
| from collections import OrderedDict |
| from functools import reduce |
|
|
| import annotator.uniformer.mmcv as mmcv |
| import numpy as np |
| from annotator.uniformer.mmcv.utils import print_log |
| from prettytable import PrettyTable |
| from torch.utils.data import Dataset |
|
|
| from annotator.uniformer.mmseg.core import eval_metrics |
| from annotator.uniformer.mmseg.utils import get_root_logger |
| from .builder import DATASETS |
| from .pipelines import Compose |
|
|
|
|
| @DATASETS.register_module() |
| class CustomDataset(Dataset): |
| """Custom dataset for semantic segmentation. An example of file structure |
| is as followed. |
| |
| .. code-block:: none |
| |
| ├── data |
| │ ├── my_dataset |
| │ │ ├── img_dir |
| │ │ │ ├── train |
| │ │ │ │ ├── xxx{img_suffix} |
| │ │ │ │ ├── yyy{img_suffix} |
| │ │ │ │ ├── zzz{img_suffix} |
| │ │ │ ├── val |
| │ │ ├── ann_dir |
| │ │ │ ├── train |
| │ │ │ │ ├── xxx{seg_map_suffix} |
| │ │ │ │ ├── yyy{seg_map_suffix} |
| │ │ │ │ ├── zzz{seg_map_suffix} |
| │ │ │ ├── val |
| |
| The img/gt_semantic_seg pair of CustomDataset should be of the same |
| except suffix. A valid img/gt_semantic_seg filename pair should be like |
| ``xxx{img_suffix}`` and ``xxx{seg_map_suffix}`` (extension is also included |
| in the suffix). If split is given, then ``xxx`` is specified in txt file. |
| Otherwise, all files in ``img_dir/``and ``ann_dir`` will be loaded. |
| Please refer to ``docs/tutorials/new_dataset.md`` for more details. |
| |
| |
| Args: |
| pipeline (list[dict]): Processing pipeline |
| img_dir (str): Path to image directory |
| img_suffix (str): Suffix of images. Default: '.jpg' |
| ann_dir (str, optional): Path to annotation directory. Default: None |
| seg_map_suffix (str): Suffix of segmentation maps. Default: '.png' |
| split (str, optional): Split txt file. If split is specified, only |
| file with suffix in the splits will be loaded. Otherwise, all |
| images in img_dir/ann_dir will be loaded. Default: None |
| data_root (str, optional): Data root for img_dir/ann_dir. Default: |
| None. |
| test_mode (bool): If test_mode=True, gt wouldn't be loaded. |
| ignore_index (int): The label index to be ignored. Default: 255 |
| reduce_zero_label (bool): Whether to mark label zero as ignored. |
| Default: False |
| classes (str | Sequence[str], optional): Specify classes to load. |
| If is None, ``cls.CLASSES`` will be used. Default: None. |
| palette (Sequence[Sequence[int]]] | np.ndarray | None): |
| The palette of segmentation map. If None is given, and |
| self.PALETTE is None, random palette will be generated. |
| Default: None |
| """ |
|
|
| CLASSES = None |
|
|
| PALETTE = None |
|
|
| def __init__(self, |
| pipeline, |
| img_dir, |
| img_suffix='.jpg', |
| ann_dir=None, |
| seg_map_suffix='.png', |
| split=None, |
| data_root=None, |
| test_mode=False, |
| ignore_index=255, |
| reduce_zero_label=False, |
| classes=None, |
| palette=None): |
| self.pipeline = Compose(pipeline) |
| self.img_dir = img_dir |
| self.img_suffix = img_suffix |
| self.ann_dir = ann_dir |
| self.seg_map_suffix = seg_map_suffix |
| self.split = split |
| self.data_root = data_root |
| self.test_mode = test_mode |
| self.ignore_index = ignore_index |
| self.reduce_zero_label = reduce_zero_label |
| self.label_map = None |
| self.CLASSES, self.PALETTE = self.get_classes_and_palette( |
| classes, palette) |
|
|
| |
| if self.data_root is not None: |
| if not osp.isabs(self.img_dir): |
| self.img_dir = osp.join(self.data_root, self.img_dir) |
| if not (self.ann_dir is None or osp.isabs(self.ann_dir)): |
| self.ann_dir = osp.join(self.data_root, self.ann_dir) |
| if not (self.split is None or osp.isabs(self.split)): |
| self.split = osp.join(self.data_root, self.split) |
|
|
| |
| self.img_infos = self.load_annotations(self.img_dir, self.img_suffix, |
| self.ann_dir, |
| self.seg_map_suffix, self.split) |
|
|
| def __len__(self): |
| """Total number of samples of data.""" |
| return len(self.img_infos) |
|
|
| def load_annotations(self, img_dir, img_suffix, ann_dir, seg_map_suffix, |
| split): |
| """Load annotation from directory. |
| |
| Args: |
| img_dir (str): Path to image directory |
| img_suffix (str): Suffix of images. |
| ann_dir (str|None): Path to annotation directory. |
| seg_map_suffix (str|None): Suffix of segmentation maps. |
| split (str|None): Split txt file. If split is specified, only file |
| with suffix in the splits will be loaded. Otherwise, all images |
| in img_dir/ann_dir will be loaded. Default: None |
| |
| Returns: |
| list[dict]: All image info of dataset. |
| """ |
|
|
| img_infos = [] |
| if split is not None: |
| with open(split) as f: |
| for line in f: |
| img_name = line.strip() |
| img_info = dict(filename=img_name + img_suffix) |
| if ann_dir is not None: |
| seg_map = img_name + seg_map_suffix |
| img_info['ann'] = dict(seg_map=seg_map) |
| img_infos.append(img_info) |
| else: |
| for img in mmcv.scandir(img_dir, img_suffix, recursive=True): |
| img_info = dict(filename=img) |
| if ann_dir is not None: |
| seg_map = img.replace(img_suffix, seg_map_suffix) |
| img_info['ann'] = dict(seg_map=seg_map) |
| img_infos.append(img_info) |
|
|
| print_log(f'Loaded {len(img_infos)} images', logger=get_root_logger()) |
| return img_infos |
|
|
| def get_ann_info(self, idx): |
| """Get annotation by index. |
| |
| Args: |
| idx (int): Index of data. |
| |
| Returns: |
| dict: Annotation info of specified index. |
| """ |
|
|
| return self.img_infos[idx]['ann'] |
|
|
| def pre_pipeline(self, results): |
| """Prepare results dict for pipeline.""" |
| results['seg_fields'] = [] |
| results['img_prefix'] = self.img_dir |
| results['seg_prefix'] = self.ann_dir |
| if self.custom_classes: |
| results['label_map'] = self.label_map |
|
|
| def __getitem__(self, idx): |
| """Get training/test data after pipeline. |
| |
| Args: |
| idx (int): Index of data. |
| |
| Returns: |
| dict: Training/test data (with annotation if `test_mode` is set |
| False). |
| """ |
|
|
| if self.test_mode: |
| return self.prepare_test_img(idx) |
| else: |
| return self.prepare_train_img(idx) |
|
|
| def prepare_train_img(self, idx): |
| """Get training data and annotations after pipeline. |
| |
| Args: |
| idx (int): Index of data. |
| |
| Returns: |
| dict: Training data and annotation after pipeline with new keys |
| introduced by pipeline. |
| """ |
|
|
| img_info = self.img_infos[idx] |
| ann_info = self.get_ann_info(idx) |
| results = dict(img_info=img_info, ann_info=ann_info) |
| self.pre_pipeline(results) |
| return self.pipeline(results) |
|
|
| def prepare_test_img(self, idx): |
| """Get testing data after pipeline. |
| |
| Args: |
| idx (int): Index of data. |
| |
| Returns: |
| dict: Testing data after pipeline with new keys introduced by |
| pipeline. |
| """ |
|
|
| img_info = self.img_infos[idx] |
| results = dict(img_info=img_info) |
| self.pre_pipeline(results) |
| return self.pipeline(results) |
|
|
| def format_results(self, results, **kwargs): |
| """Place holder to format result to dataset specific output.""" |
|
|
| def get_gt_seg_maps(self, efficient_test=False): |
| """Get ground truth segmentation maps for evaluation.""" |
| gt_seg_maps = [] |
| for img_info in self.img_infos: |
| seg_map = osp.join(self.ann_dir, img_info['ann']['seg_map']) |
| if efficient_test: |
| gt_seg_map = seg_map |
| else: |
| gt_seg_map = mmcv.imread( |
| seg_map, flag='unchanged', backend='pillow') |
| gt_seg_maps.append(gt_seg_map) |
| return gt_seg_maps |
|
|
| def get_classes_and_palette(self, classes=None, palette=None): |
| """Get class names of current dataset. |
| |
| Args: |
| classes (Sequence[str] | str | None): If classes is None, use |
| default CLASSES defined by builtin dataset. If classes is a |
| string, take it as a file name. The file contains the name of |
| classes where each line contains one class name. If classes is |
| a tuple or list, override the CLASSES defined by the dataset. |
| palette (Sequence[Sequence[int]]] | np.ndarray | None): |
| The palette of segmentation map. If None is given, random |
| palette will be generated. Default: None |
| """ |
| if classes is None: |
| self.custom_classes = False |
| return self.CLASSES, self.PALETTE |
|
|
| self.custom_classes = True |
| if isinstance(classes, str): |
| |
| class_names = mmcv.list_from_file(classes) |
| elif isinstance(classes, (tuple, list)): |
| class_names = classes |
| else: |
| raise ValueError(f'Unsupported type {type(classes)} of classes.') |
|
|
| if self.CLASSES: |
| if not set(classes).issubset(self.CLASSES): |
| raise ValueError('classes is not a subset of CLASSES.') |
|
|
| |
| |
| |
| self.label_map = {} |
| for i, c in enumerate(self.CLASSES): |
| if c not in class_names: |
| self.label_map[i] = -1 |
| else: |
| self.label_map[i] = classes.index(c) |
|
|
| palette = self.get_palette_for_custom_classes(class_names, palette) |
|
|
| return class_names, palette |
|
|
| def get_palette_for_custom_classes(self, class_names, palette=None): |
|
|
| if self.label_map is not None: |
| |
| palette = [] |
| for old_id, new_id in sorted( |
| self.label_map.items(), key=lambda x: x[1]): |
| if new_id != -1: |
| palette.append(self.PALETTE[old_id]) |
| palette = type(self.PALETTE)(palette) |
|
|
| elif palette is None: |
| if self.PALETTE is None: |
| palette = np.random.randint(0, 255, size=(len(class_names), 3)) |
| else: |
| palette = self.PALETTE |
|
|
| return palette |
|
|
| def evaluate(self, |
| results, |
| metric='mIoU', |
| logger=None, |
| efficient_test=False, |
| **kwargs): |
| """Evaluate the dataset. |
| |
| Args: |
| results (list): Testing results of the dataset. |
| metric (str | list[str]): Metrics to be evaluated. 'mIoU', |
| 'mDice' and 'mFscore' are supported. |
| logger (logging.Logger | None | str): Logger used for printing |
| related information during evaluation. Default: None. |
| |
| Returns: |
| dict[str, float]: Default metrics. |
| """ |
|
|
| if isinstance(metric, str): |
| metric = [metric] |
| allowed_metrics = ['mIoU', 'mDice', 'mFscore'] |
| if not set(metric).issubset(set(allowed_metrics)): |
| raise KeyError('metric {} is not supported'.format(metric)) |
| eval_results = {} |
| gt_seg_maps = self.get_gt_seg_maps(efficient_test) |
| if self.CLASSES is None: |
| num_classes = len( |
| reduce(np.union1d, [np.unique(_) for _ in gt_seg_maps])) |
| else: |
| num_classes = len(self.CLASSES) |
| ret_metrics = eval_metrics( |
| results, |
| gt_seg_maps, |
| num_classes, |
| self.ignore_index, |
| metric, |
| label_map=self.label_map, |
| reduce_zero_label=self.reduce_zero_label) |
|
|
| if self.CLASSES is None: |
| class_names = tuple(range(num_classes)) |
| else: |
| class_names = self.CLASSES |
|
|
| |
| ret_metrics_summary = OrderedDict({ |
| ret_metric: np.round(np.nanmean(ret_metric_value) * 100, 2) |
| for ret_metric, ret_metric_value in ret_metrics.items() |
| }) |
|
|
| |
| ret_metrics.pop('aAcc', None) |
| ret_metrics_class = OrderedDict({ |
| ret_metric: np.round(ret_metric_value * 100, 2) |
| for ret_metric, ret_metric_value in ret_metrics.items() |
| }) |
| ret_metrics_class.update({'Class': class_names}) |
| ret_metrics_class.move_to_end('Class', last=False) |
|
|
| |
| class_table_data = PrettyTable() |
| for key, val in ret_metrics_class.items(): |
| class_table_data.add_column(key, val) |
|
|
| summary_table_data = PrettyTable() |
| for key, val in ret_metrics_summary.items(): |
| if key == 'aAcc': |
| summary_table_data.add_column(key, [val]) |
| else: |
| summary_table_data.add_column('m' + key, [val]) |
|
|
| print_log('per class results:', logger) |
| print_log('\n' + class_table_data.get_string(), logger=logger) |
| print_log('Summary:', logger) |
| print_log('\n' + summary_table_data.get_string(), logger=logger) |
|
|
| |
| for key, value in ret_metrics_summary.items(): |
| if key == 'aAcc': |
| eval_results[key] = value / 100.0 |
| else: |
| eval_results['m' + key] = value / 100.0 |
|
|
| ret_metrics_class.pop('Class', None) |
| for key, value in ret_metrics_class.items(): |
| eval_results.update({ |
| key + '.' + str(name): value[idx] / 100.0 |
| for idx, name in enumerate(class_names) |
| }) |
|
|
| if mmcv.is_list_of(results, str): |
| for file_name in results: |
| os.remove(file_name) |
| return eval_results |
|
|