# camenduru's picture
# thanks to show ❤
# 3bbb319
# Copyright (c) OpenMMLab. All rights reserved.
import bisect
import math
from collections import defaultdict
from unittest.mock import MagicMock
import numpy as np
import pytest
from mmdet.datasets import (ClassBalancedDataset, ConcatDataset, CustomDataset,
MultiImageMixDataset, RepeatDataset)
def _random_sample():
    """Build one fake sample dict with random content.

    Returns:
        dict: ``img`` is an all-ones array of shape ``(height, width, 3)``
        with both sides drawn from ``[10, 30)``; ``gt_bboxes`` is a
        ``(2, 4)`` integer array whose first two columns are drawn from
        ``[1, 5)`` and last two from ``[6, 10)``; ``gt_labels`` holds two
        integers drawn from ``[0, 80)``.
    """
    height = np.random.randint(10, 30)
    width = np.random.randint(10, 30)
    img = np.ones((height, width, 3))
    # Adding 5 to the second half keeps columns 2-3 strictly larger than
    # columns 0-1 (presumably so each row is a valid x1, y1, x2, y2 box).
    gt_bboxes = np.concatenate(
        [
            np.random.randint(1, 5, (2, 2)),
            np.random.randint(1, 5, (2, 2)) + 5,
        ],
        axis=1)
    gt_labels = np.random.randint(0, 80, 2)
    return dict(gt_bboxes=gt_bboxes, gt_labels=gt_labels, img=img)


def _build_mock_dataset(length, with_ann_info=True):
    """Create a ``CustomDataset`` whose data accessors are mocked.

    ``CustomDataset.load_annotations`` must already be patched by the
    caller, because the dataset is constructed with a ``MagicMock`` as
    ``ann_file``.

    Args:
        length (int): Value reported by ``len(dataset.data_infos)``.
        with_ann_info (bool): Whether to also mock ``get_ann_info`` with
            one random sample per index.

    Returns:
        tuple: ``(dataset, cat_ids_list, ann_info_list)``;
        ``ann_info_list`` is ``None`` when ``with_ann_info`` is False.
    """
    dataset = CustomDataset(
        ann_file=MagicMock(), pipeline=[], test_mode=True, img_prefix='')
    # Each image gets between 1 and 19 random category ids in [0, 80).
    cat_ids_list = [
        np.random.randint(0, 80, num).tolist()
        for num in np.random.randint(1, 20, length)
    ]
    dataset.data_infos = MagicMock()
    dataset.data_infos.__len__.return_value = length
    dataset.get_cat_ids = MagicMock(side_effect=lambda idx: cat_ids_list[idx])

    ann_info_list = None
    if with_ann_info:
        ann_info_list = [_random_sample() for _ in range(length)]
        dataset.get_ann_info = MagicMock(
            side_effect=lambda idx: ann_info_list[idx])
    return dataset, cat_ids_list, ann_info_list


def _assert_palette_optional(build_wrapper):
    """Check that a wrapper reports ``PALETTE is None`` without the attribute.

    ``CustomDataset.PALETTE`` is removed while ``build_wrapper`` runs and is
    restored afterwards even if the assertion (or the constructor) fails, so
    a failure here cannot leave the class in a broken state.

    Args:
        build_wrapper (callable): Zero-argument factory returning the
            dataset wrapper under test.
    """
    palette_backup = CustomDataset.PALETTE
    delattr(CustomDataset, 'PALETTE')
    try:
        assert build_wrapper().PALETTE is None
    finally:
        CustomDataset.PALETTE = palette_backup


def _expected_class_balance(cat_ids_list):
    """Compute the reference repeat threshold and cumulative repeat counts.

    The per-category frequency is the fraction of images containing the
    category; the threshold is the mean of those frequencies; each image is
    repeated ``ceil(max(1, sqrt(thr / freq)))`` times for its rarest
    category.

    Args:
        cat_ids_list (list[list[int]]): Category ids per image.

    Returns:
        tuple: ``(repeat_thr, repeat_factors_cumsum)``.
    """
    num_images = len(cat_ids_list)
    category_count = defaultdict(int)
    for cat_ids in cat_ids_list:
        # A category counts once per image, however often it occurs in it.
        for cat_id in set(cat_ids):
            category_count[cat_id] += 1
    category_freq = {
        cat_id: count / num_images
        for cat_id, count in category_count.items()
    }

    repeat_thr = np.mean(list(category_freq.values()))
    category_repeat = {
        cat_id: max(1.0, math.sqrt(repeat_thr / cat_freq))
        for cat_id, cat_freq in category_freq.items()
    }
    repeat_factors = [
        math.ceil(max(category_repeat[cat_id] for cat_id in set(cat_ids)))
        for cat_ids in cat_ids_list
    ]
    return repeat_thr, np.cumsum(repeat_factors)


def _check_index_wrappers():
    """Exercise ConcatDataset, RepeatDataset and ClassBalancedDataset.

    Expects ``CustomDataset.__getitem__`` to be patched so that it returns
    the index it receives; the assertions compare wrapper outputs against
    the index the wrapper is expected to forward.
    """
    dataset_a, cat_ids_list_a, ann_info_list_a = _build_mock_dataset(10)
    dataset_b, cat_ids_list_b, ann_info_list_b = _build_mock_dataset(20)

    # ConcatDataset: global index 25 falls into dataset_b at local index 15.
    concat_dataset = ConcatDataset([dataset_a, dataset_b])
    assert concat_dataset[5] == 5
    assert concat_dataset[25] == 15
    assert concat_dataset.get_cat_ids(5) == cat_ids_list_a[5]
    assert concat_dataset.get_cat_ids(25) == cat_ids_list_b[15]
    assert concat_dataset.get_ann_info(5) == ann_info_list_a[5]
    assert concat_dataset.get_ann_info(25) == ann_info_list_b[15]
    assert len(concat_dataset) == len(dataset_a) + len(dataset_b)
    # Test if ConcatDataset allows dataset classes without the PALETTE
    # attribute
    _assert_palette_optional(lambda: ConcatDataset([dataset_a, dataset_b]))

    # RepeatDataset: indices wrap around modulo len(dataset_a).
    repeat_dataset = RepeatDataset(dataset_a, 10)
    assert repeat_dataset[5] == 5
    assert repeat_dataset[15] == 5
    assert repeat_dataset[27] == 7
    assert repeat_dataset.get_cat_ids(5) == cat_ids_list_a[5]
    assert repeat_dataset.get_cat_ids(15) == cat_ids_list_a[5]
    assert repeat_dataset.get_cat_ids(27) == cat_ids_list_a[7]
    assert repeat_dataset.get_ann_info(5) == ann_info_list_a[5]
    assert repeat_dataset.get_ann_info(15) == ann_info_list_a[5]
    assert repeat_dataset.get_ann_info(27) == ann_info_list_a[7]
    assert len(repeat_dataset) == 10 * len(dataset_a)
    # Test if RepeatDataset allows dataset classes without the PALETTE
    # attribute
    _assert_palette_optional(lambda: RepeatDataset(dataset_a, 10))

    # ClassBalancedDataset: compare against an independent reference
    # computation of the repeat factors.
    repeat_thr, repeat_factors_cumsum = _expected_class_balance(
        cat_ids_list_a)
    repeat_factor_dataset = ClassBalancedDataset(dataset_a, repeat_thr)
    assert len(repeat_factor_dataset) == repeat_factors_cumsum[-1]
    for idx in np.random.randint(0, len(repeat_factor_dataset), 3):
        # bisect_right on the cumulative counts gives the source image that
        # owns position ``idx`` in the expanded dataset.
        expected_idx = bisect.bisect_right(repeat_factors_cumsum, idx)
        assert repeat_factor_dataset[idx] == expected_idx
        assert repeat_factor_dataset.get_ann_info(idx) == ann_info_list_a[
            expected_idx]
    # Test if ClassBalancedDataset allows dataset classes without the PALETTE
    # attribute
    _assert_palette_optional(
        lambda: ClassBalancedDataset(dataset_a, repeat_thr))


def _check_multi_image_mix_dataset():
    """Exercise MultiImageMixDataset with a Mosaic/MixUp style pipeline.

    Expects ``CustomDataset.__getitem__`` to be patched so that it returns
    sample dicts (as built by ``_random_sample``) for indices 0 and 1.
    """
    img_scale = (60, 60)
    pipeline = [
        dict(type='Mosaic', img_scale=img_scale, pad_val=114.0),
        dict(
            type='RandomAffine',
            scaling_ratio_range=(0.1, 2),
            border=(-img_scale[0] // 2, -img_scale[1] // 2)),
        dict(
            type='MixUp',
            img_scale=img_scale,
            ratio_range=(0.8, 1.6),
            pad_val=114.0),
        dict(type='RandomFlip', flip_ratio=0.5),
        dict(type='Resize', img_scale=img_scale, keep_ratio=True),
        dict(type='Pad', pad_to_square=True, pad_val=114.0),
    ]
    expected_shape = (img_scale[0], img_scale[1], 3)

    len_a = 2
    dataset_a, _, _ = _build_mock_dataset(len_a, with_ann_info=False)

    # test dynamic_scale deprecated
    with pytest.raises(RuntimeError):
        MultiImageMixDataset(dataset_a, pipeline, (80, 80))

    multi_image_mix_dataset = MultiImageMixDataset(dataset_a, pipeline)
    for idx in range(len_a):
        results_ = multi_image_mix_dataset[idx]
        assert results_['img'].shape == expected_shape

    # test skip_type_keys
    multi_image_mix_dataset = MultiImageMixDataset(
        dataset_a,
        pipeline,
        skip_type_keys=('MixUp', 'RandomFlip', 'Resize', 'Pad'))
    for idx in range(len_a):
        results_ = multi_image_mix_dataset[idx]
        assert results_['img'].shape == expected_shape

    # Test if MultiImageMixDataset allows dataset classes without the PALETTE
    # attribute
    _assert_palette_optional(
        lambda: MultiImageMixDataset(dataset_a, pipeline))


def test_dataset_wrapper():
    """Test the dataset wrappers on top of mocked ``CustomDataset`` objects.

    Covers ``ConcatDataset``, ``RepeatDataset``, ``ClassBalancedDataset`` and
    ``MultiImageMixDataset``. The class-level patches of
    ``CustomDataset.load_annotations`` and ``CustomDataset.__getitem__`` are
    scoped with ``patch.object`` so the real methods are restored when the
    test finishes or fails, instead of leaking mocks into other tests.
    """
    # Local import: the file-level import only brings in MagicMock.
    from unittest.mock import patch

    # Phase 1: __getitem__ echoes the requested index, which lets the
    # assertions observe how each wrapper remaps indices.
    with patch.object(CustomDataset, 'load_annotations', MagicMock()), \
            patch.object(CustomDataset, '__getitem__',
                         MagicMock(side_effect=lambda idx: idx)):
        _check_index_wrappers()

    # Phase 2: __getitem__ returns real sample dicts so the mixing pipeline
    # has images and boxes to operate on.
    results = [_random_sample() for _ in range(2)]
    with patch.object(CustomDataset, 'load_annotations', MagicMock()), \
            patch.object(CustomDataset, '__getitem__',
                         MagicMock(side_effect=lambda idx: results[idx])):
        _check_multi_image_mix_dataset()