test
/
FoodSeg103
/Swin-Transformer-Semantic-Segmentation-main
/tests
/test_data
/test_dataset_builder.py
| import math | |
| import os.path as osp | |
| import pytest | |
| from torch.utils.data import (DistributedSampler, RandomSampler, | |
| SequentialSampler) | |
| from mmseg.datasets import (DATASETS, ConcatDataset, build_dataloader, | |
| build_dataset) | |
class ToyDataset(object):
    """Minimal stand-in dataset used by the builder tests.

    The dataset always reports a length of 100 and yields the requested
    index itself as the "sample", so no files are touched.

    NOTE(review): ``test_build_dataset`` builds this class through
    ``build_dataset(dict(type='ToyDataset'))``, which presumably requires
    the class to be registered in ``DATASETS``. No registration is visible
    in this file view -- confirm it has not been lost.

    Args:
        cnt (int): Arbitrary value stored on the instance so tests can check
            that constructor arguments are forwarded. Defaults to 0.
    """

    def __init__(self, cnt=0):
        self.cnt = cnt

    def __getitem__(self, idx):
        """Return ``idx`` unchanged as the sample stored at that position.

        No bounds check is performed; callers are expected to draw indices
        from ``range(len(self))``.
        """
        return idx

    # The method was originally (mis)spelled ``__item__``, which Python never
    # calls for ``dataset[idx]``. Keep the old name as an alias so any direct
    # caller of it keeps working.
    __item__ = __getitem__

    def __len__(self):
        """Return the fixed dataset size."""
        return 100
def test_build_dataset():
    """Exercise ``build_dataset`` on a plain dataset and on list-valued configs.

    First a ``ToyDataset`` is built with and without ``default_args``. Then
    ``CustomDataset`` configs whose ``img_dir`` / ``ann_dir`` / ``split``
    entries are lists are checked to produce a ``ConcatDataset`` of the
    expected length, and configs with mismatching list lengths are checked
    to raise ``AssertionError``.
    """
    # Plain dataset: constructor default first, then value from default_args.
    toy_cfg = dict(type='ToyDataset')
    built = build_dataset(toy_cfg)
    assert isinstance(built, ToyDataset)
    assert built.cnt == 0
    built = build_dataset(toy_cfg, default_args=dict(cnt=1))
    assert isinstance(built, ToyDataset)
    assert built.cnt == 1

    root = osp.join(osp.dirname(__file__), '../data/pseudo_dataset')
    imgs = 'imgs/'
    gts = 'gts/'
    val_split = 'splits/val.txt'
    train_val_splits = ['splits/train.txt', val_split]

    def custom_cfg(**fields):
        # Every CustomDataset config shares these three entries; a fresh
        # dict (and a fresh pipeline list) is created on each call.
        return dict(
            type='CustomDataset', pipeline=[], data_root=root, **fields)

    # (config fields, expected length of the resulting ConcatDataset)
    concat_cases = [
        # with ann_dir; the same dir is used twice for simplicity
        (dict(img_dir=[imgs, imgs], ann_dir=[gts, gts]), 10),
        # single img_dir and ann_dir, list of splits
        (dict(img_dir=imgs, ann_dir=gts, split=train_val_splits), 5),
        # single img_dir, list of ann_dir, list of splits
        (dict(img_dir=imgs, ann_dir=[gts, gts], split=train_val_splits), 5),
        # test mode
        (dict(img_dir=[imgs, imgs], test_mode=True), 10),
        # test mode with splits
        (dict(
            img_dir=[imgs, imgs],
            split=[val_split, val_split],
            test_mode=True), 2),
    ]
    for fields, expected_len in concat_cases:
        built = build_dataset(custom_cfg(**fields))
        assert isinstance(built, ConcatDataset)
        assert len(built) == expected_len

    three_val_splits = [val_split, val_split, val_split]
    invalid_cases = [
        # len(ann_dir) should be zero or len(img_dir) when len(img_dir) > 1
        dict(img_dir=[imgs, imgs], ann_dir=[gts, gts, gts]),
        # len(splits) should be zero or len(img_dir) when len(img_dir) > 1
        dict(img_dir=[imgs, imgs], split=three_val_splits),
        # len(splits) == len(ann_dir) when only len(img_dir) == 1 and
        # len(ann_dir) > 1
        dict(img_dir=imgs, ann_dir=[gts, gts], split=three_val_splits),
    ]
    for fields in invalid_cases:
        with pytest.raises(AssertionError):
            build_dataset(custom_cfg(**fields))
def test_build_dataloader():
    """Check batch size, length, sampler type and worker count of loaders.

    ``build_dataloader`` is called on a ``ToyDataset`` with the default
    (distributed) setting and with ``dist=False``, with and without
    shuffling, and with ``num_gpus`` of 1 (default) and 8.
    """
    toy = ToyDataset()
    per_gpu = 3
    workers = 2

    def make_loader(**extra):
        # All cases share the dataset, per-GPU batch size and worker count.
        return build_dataloader(
            toy, samples_per_gpu=per_gpu, workers_per_gpu=workers, **extra)

    def expected_batches(gpus=1):
        # Number of batches when each batch holds per_gpu * gpus samples.
        return int(math.ceil(len(toy) / per_gpu / gpus))

    # dist=True, shuffle=True, 1GPU
    loader = make_loader()
    assert loader.batch_size == per_gpu
    assert len(loader) == expected_batches()
    assert isinstance(loader.sampler, DistributedSampler)
    assert loader.sampler.shuffle

    # dist=True, shuffle=False, 1GPU
    loader = make_loader(shuffle=False)
    assert loader.batch_size == per_gpu
    assert len(loader) == expected_batches()
    assert isinstance(loader.sampler, DistributedSampler)
    assert not loader.sampler.shuffle

    # dist=True, shuffle=True, 8GPU
    loader = make_loader(num_gpus=8)
    assert loader.batch_size == per_gpu
    assert len(loader) == expected_batches()
    assert loader.num_workers == 2

    # dist=False, shuffle=True, 1GPU
    loader = make_loader(dist=False)
    assert loader.batch_size == per_gpu
    assert len(loader) == expected_batches()
    assert isinstance(loader.sampler, RandomSampler)
    assert loader.num_workers == 2

    # dist=False, shuffle=False, 1GPU
    loader = make_loader(shuffle=False, dist=False)
    assert loader.batch_size == per_gpu
    assert len(loader) == expected_batches()
    assert isinstance(loader.sampler, SequentialSampler)
    assert loader.num_workers == 2

    # dist=False, shuffle=True, 8GPU
    loader = make_loader(num_gpus=8, dist=False)
    assert loader.batch_size == per_gpu * 8
    assert len(loader) == expected_batches(gpus=8)
    assert isinstance(loader.sampler, RandomSampler)
    assert loader.num_workers == 16