|
|
|
|
|
import os.path as osp |
|
|
|
|
|
from mmengine.dataset import ConcatDataset, RepeatDataset |
|
|
from mmengine.registry import init_default_scope |
|
|
|
|
|
from mmseg.datasets import MultiImageMixDataset |
|
|
from mmseg.registry import DATASETS |
|
|
|
|
|
# Make 'mmseg' the default mmengine registry scope at import time, presumably
# so the string ``type`` values used below ('ToyDataset', 'BaseSegDataset',
# 'ConcatDataset') resolve against mmseg's registries — confirm if the
# registry setup changes.
init_default_scope('mmseg')
|
|
|
|
|
|
|
|
@DATASETS.register_module()
class ToyDataset:
    """Minimal dataset registered in ``DATASETS`` for registry build tests.

    Args:
        cnt (int): Arbitrary value stored on the instance; the tests use it
            to check that ``default_args`` given to ``DATASETS.build`` reach
            the constructor. Defaults to 0.
    """

    def __init__(self, cnt=0):
        self.cnt = cnt

    def __getitem__(self, idx):
        """Return ``idx`` itself as the sample at position ``idx``."""
        return idx

    # ``__item__`` is not a Python protocol method (indexing calls
    # ``__getitem__``), so the original spelling made ``dataset[i]`` raise
    # TypeError. Keep the old name as an alias for any direct caller.
    __item__ = __getitem__

    def __len__(self):
        """Report a fixed length of 100 samples."""
        return 100
|
|
|
|
|
|
|
|
def _pseudo_seg_cfg(data_root, data_prefix, **extra):
    """Return a fresh ``BaseSegDataset`` config over the pseudo dataset.

    Every dataset config in ``test_build_dataset`` shares the type, an empty
    pipeline, the data root and ``serialize_data=False``; only
    ``data_prefix`` and optional keys such as ``ann_file``, ``test_mode`` and
    ``metainfo`` (passed through ``extra``) differ between cases.
    """
    return dict(
        type='BaseSegDataset',
        pipeline=[],
        data_root=data_root,
        data_prefix=data_prefix,
        serialize_data=False,
        **extra)


def _build_concat(cfg1, cfg2):
    """Build ``cfg1`` then ``cfg2`` via ``DATASETS`` and concatenate them."""
    dataset_concat = ConcatDataset(
        datasets=[DATASETS.build(cfg1), DATASETS.build(cfg2)])
    assert isinstance(dataset_concat, ConcatDataset)
    return dataset_concat


def test_build_dataset():
    """Check registry building and the dataset wrapper classes.

    Covers building a registered class with and without ``default_args``,
    ``RepeatDataset`` / ``ConcatDataset`` over ``BaseSegDataset`` instances
    read from ``../data/pseudo_dataset``, ``MultiImageMixDataset`` wrapping
    either a built dataset or a config, and concatenation with split files
    and in ``test_mode``. The expected lengths encode the fixture contents:
    5 samples in total, with ``splits/val.txt`` listing 1 of them and
    ``splits/train.txt`` the other 4.
    """
    # Registry build of a plain class; default_args supplies unset kwargs.
    cfg = dict(type='ToyDataset')
    dataset = DATASETS.build(cfg)
    assert isinstance(dataset, ToyDataset)
    assert dataset.cnt == 0
    dataset = DATASETS.build(cfg, default_args=dict(cnt=1))
    assert isinstance(dataset, ToyDataset)
    assert dataset.cnt == 1

    data_root = osp.join(osp.dirname(__file__), '../data/pseudo_dataset')
    data_prefix = dict(img_path='imgs/', seg_map_path='gts/')

    def test_mode_cfg(**extra):
        # Image-only prefix in test_mode; fresh dicts for every config.
        return _pseudo_seg_cfg(
            data_root,
            dict(img_path='imgs/'),
            test_mode=True,
            metainfo=dict(classes=('pseudo_class', )),
            **extra)

    # RepeatDataset: 5 samples repeated 5 times.
    dataset = DATASETS.build(_pseudo_seg_cfg(data_root, data_prefix))
    dataset_repeat = RepeatDataset(dataset=dataset, times=5)
    assert isinstance(dataset_repeat, RepeatDataset)
    assert len(dataset_repeat) == 25

    # ConcatDataset of two identical full datasets: 5 + 5.
    cfg1 = _pseudo_seg_cfg(data_root, data_prefix)
    cfg2 = _pseudo_seg_cfg(data_root, data_prefix)
    dataset_concat = _build_concat(cfg1, cfg2)
    assert len(dataset_concat) == 10

    # MultiImageMixDataset accepts an already-built dataset ...
    dataset = MultiImageMixDataset(dataset=dataset_concat, pipeline=[])
    assert isinstance(dataset, MultiImageMixDataset)
    assert len(dataset) == 10

    # ... or a config dict that it builds itself.
    cfg = dict(type='ConcatDataset', datasets=[cfg1, cfg2])
    dataset = MultiImageMixDataset(dataset=cfg, pipeline=[])
    assert isinstance(dataset, MultiImageMixDataset)
    assert len(dataset) == 10

    # Split files restrict each dataset; train + val together give 5.
    dataset_concat = _build_concat(
        _pseudo_seg_cfg(data_root, data_prefix, ann_file='splits/train.txt'),
        _pseudo_seg_cfg(data_root, data_prefix, ann_file='splits/val.txt'))
    assert len(dataset_concat) == 5

    # test_mode without seg maps still lists every image: 5 + 5.
    dataset_concat = _build_concat(test_mode_cfg(), test_mode_cfg())
    assert len(dataset_concat) == 10

    # test_mode restricted to the val split: 1 + 1.
    dataset_concat = _build_concat(
        test_mode_cfg(ann_file='splits/val.txt'),
        test_mode_cfg(ann_file='splits/val.txt'))
    assert len(dataset_concat) == 2
|
|
|