# File size: 4,575 Bytes
# Revision: ea1014e
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from mmengine.dataset import ConcatDataset, RepeatDataset
from mmengine.registry import init_default_scope
from mmseg.datasets import MultiImageMixDataset
from mmseg.registry import DATASETS
init_default_scope('mmseg')
@DATASETS.register_module()
class ToyDataset:
    """Minimal registrable dataset used to exercise ``DATASETS.build``.

    Args:
        cnt (int): Arbitrary value stored on the instance so a test can
            check that build-time arguments reach the constructor.
            Defaults to 0.
    """

    def __init__(self, cnt=0):
        self.cnt = cnt

    def __getitem__(self, idx):
        """Return ``idx`` itself as the sample stored at position ``idx``.

        Raises:
            IndexError: If ``idx`` is outside ``[0, len(self))``. Raising
                here is what lets plain ``for`` iteration over the
                dataset terminate.
        """
        if not 0 <= idx < len(self):
            raise IndexError(
                f'index {idx} is out of range for {type(self).__name__} '
                f'of length {len(self)}')
        return idx

    # ``__item__`` is not a Python special method, so on its own it never
    # made the class indexable. It is kept, unchanged, only so that any
    # caller invoking it by name keeps working.
    def __item__(self, idx):
        return idx

    def __len__(self):
        return 100
def test_build_dataset():
    """Check registry builds and dataset wrappers on the pseudo dataset.

    Covers ``ToyDataset`` construction through ``DATASETS.build`` (with and
    without ``default_args``), then ``RepeatDataset``, ``ConcatDataset`` and
    ``MultiImageMixDataset`` on top of ``BaseSegDataset`` configs.
    """
    # Registry build: constructor default first, then ``default_args``.
    toy_cfg = {'type': 'ToyDataset'}
    toy = DATASETS.build(toy_cfg)
    assert isinstance(toy, ToyDataset)
    assert toy.cnt == 0
    toy = DATASETS.build(toy_cfg, default_args={'cnt': 1})
    assert isinstance(toy, ToyDataset)
    assert toy.cnt == 1

    data_root = osp.join(osp.dirname(__file__), '../data/pseudo_dataset')
    data_prefix = {'img_path': 'imgs/', 'seg_map_path': 'gts/'}

    def seg_cfg(prefix, **extra):
        """Return a ``BaseSegDataset`` config rooted at ``data_root``."""
        return {
            'type': 'BaseSegDataset',
            'pipeline': [],
            'data_root': data_root,
            'data_prefix': prefix,
            **extra,
            'serialize_data': False,
        }

    def infer_cfg(**leading):
        """Return a test-mode config that only declares an image prefix."""
        return seg_cfg(
            {'img_path': 'imgs/'},
            **leading,
            test_mode=True,
            metainfo={'classes': ('pseudo_class', )})

    def concat_of(first_cfg, second_cfg):
        """Build both configs in order and wrap them in a ConcatDataset."""
        built = [DATASETS.build(first_cfg), DATASETS.build(second_cfg)]
        return ConcatDataset(datasets=built)

    # test RepeatDataset
    base = DATASETS.build(seg_cfg(data_prefix))
    repeated = RepeatDataset(dataset=base, times=5)
    assert isinstance(repeated, RepeatDataset)
    assert len(repeated) == 25

    # test ConcatDataset
    # The same directory is used twice for simplicity,
    # with data_prefix.seg_map_path
    left_cfg = seg_cfg(data_prefix)
    right_cfg = seg_cfg(data_prefix)
    joined = concat_of(left_cfg, right_cfg)
    assert isinstance(joined, ConcatDataset)
    assert len(joined) == 10

    # test MultiImageMixDataset, first from an already built dataset ...
    mixed = MultiImageMixDataset(dataset=joined, pipeline=[])
    assert isinstance(mixed, MultiImageMixDataset)
    assert len(mixed) == 10

    # ... then from a config that it has to build itself.
    concat_cfg = {'type': 'ConcatDataset', 'datasets': [left_cfg, right_cfg]}
    mixed = MultiImageMixDataset(dataset=concat_cfg, pipeline=[])
    assert isinstance(mixed, MultiImageMixDataset)
    assert len(mixed) == 10

    # with data_prefix.seg_map_path, ann_file
    joined = concat_of(
        seg_cfg(data_prefix, ann_file='splits/train.txt'),
        seg_cfg(data_prefix, ann_file='splits/val.txt'))
    assert isinstance(joined, ConcatDataset)
    assert len(joined) == 5

    # test mode
    joined = concat_of(infer_cfg(), infer_cfg())
    assert isinstance(joined, ConcatDataset)
    assert len(joined) == 10

    # test mode with ann_files
    joined = concat_of(
        infer_cfg(ann_file='splits/val.txt'),
        infer_cfg(ann_file='splits/val.txt'))
    assert isinstance(joined, ConcatDataset)
    assert len(joined) == 2
|