| | |
| | import itertools |
| | import math |
| | import operator |
| | import unittest |
| | import torch |
| | from torch.utils import data |
| | from torch.utils.data.sampler import SequentialSampler |
| |
|
| | from detectron2.data.build import worker_init_reset_seed |
| | from detectron2.data.common import DatasetFromList, ToIterableDataset |
| | from detectron2.data.samplers import ( |
| | GroupedBatchSampler, |
| | InferenceSampler, |
| | RepeatFactorTrainingSampler, |
| | TrainingSampler, |
| | ) |
| | from detectron2.utils.env import seed_all_rng |
| |
|
| |
|
class TestGroupedBatchSampler(unittest.TestCase):
    """Checks the batches produced by ``GroupedBatchSampler``."""

    def test_missing_group_id(self):
        """Every batch has two items when all 100 indices share group id 1.

        Group id 0 never occurs in ``uniform_groups``; presumably that is the
        "missing" group the test name refers to.
        """
        uniform_groups = [1] * 100
        index_sampler = SequentialSampler(list(range(100)))
        batches = GroupedBatchSampler(index_sampler, uniform_groups, 2)

        for batch in batches:
            self.assertEqual(len(batch), 2)

    def test_groups(self):
        """Both members of a batch must come from the same group."""
        # Odd indices get group 0 and even indices get group 1, so two indices
        # of the same group always add up to an even number.
        alternating_groups = [1, 0] * 50
        index_sampler = SequentialSampler(list(range(100)))
        batches = GroupedBatchSampler(index_sampler, alternating_groups, 2)

        for batch in batches:
            pair_sum = batch[0] + batch[1]
            self.assertEqual(pair_sum % 2, 0)
| |
|
| |
|
class TestSamplerDeterministic(unittest.TestCase):
    """Checks that ``TrainingSampler`` produces a reproducible index stream."""

    def test_to_iterable(self):
        """A DataLoader over ``ToIterableDataset`` must replay the sampler order."""
        num_items = 100
        sampler = TrainingSampler(num_items, seed=10)
        expected = list(itertools.islice(sampler, num_items))
        # The first 100 draws are expected to cover every index exactly once.
        self.assertEqual(set(expected), set(range(num_items)))

        iterable_dataset = ToIterableDataset(
            DatasetFromList(list(range(num_items))), sampler
        )
        # The loader's default batch size is 1, so "collating" just unwraps
        # the single element of each batch.
        take_first = operator.itemgetter(0)

        # Loading in the main process.
        in_process_loader = data.DataLoader(
            iterable_dataset, num_workers=0, collate_fn=take_first
        )
        in_process_output = list(itertools.islice(in_process_loader, num_items))
        self.assertEqual(in_process_output, expected)

        # Loading through two workers whose seeds are reset on start-up; the
        # test expects neither the reseeding nor the sharding across workers
        # to duplicate or reorder the data.
        worker_loader = data.DataLoader(
            iterable_dataset,
            num_workers=2,
            collate_fn=take_first,
            worker_init_fn=worker_init_reset_seed,
        )
        worker_output = list(itertools.islice(worker_loader, num_items))
        self.assertEqual(worker_output, expected)

    def test_training_sampler_seed(self):
        """A sampler's stream is fixed by the global seed active at construction."""
        seed_all_rng(42)
        first_sampler = TrainingSampler(30)
        first_draw = list(itertools.islice(first_sampler, 65))

        seed_all_rng(42)
        second_sampler = TrainingSampler(30)
        # Reseeding after the sampler has been built is expected to have no
        # effect on what it yields.
        seed_all_rng(999)
        second_draw = list(itertools.islice(second_sampler, 65))

        self.assertEqual(first_draw, second_draw)
| |
|
| |
|
class TestRepeatFactorTrainingSampler(unittest.TestCase):
    """Checks the per-image repeat factors derived from category frequency."""

    def test_repeat_factors_from_category_frequency(self):
        """Three images with known annotations yield the expected factors."""
        # Category 0 appears in two of the three images, category 1 in one,
        # and the last image has no annotations at all.
        images = [
            {"annotations": [{"category_id": 0}, {"category_id": 1}]},
            {"annotations": [{"category_id": 0}]},
            {"annotations": []},
        ]
        threshold = 0.5

        # sqrt(3 / 2) == sqrt(0.5 / (1 / 3)); presumably threshold divided by
        # the frequency of category 1 -- confirm against the sampler.
        want = torch.tensor([math.sqrt(3 / 2), 1.0, 1.0])
        got = RepeatFactorTrainingSampler.repeat_factors_from_category_frequency(
            images, threshold
        )

        self.assertTrue(torch.allclose(got, want))
| |
|
| |
|
class TestInferenceSampler(unittest.TestCase):
    """Checks how ``InferenceSampler`` splits a dataset across ranks."""

    def test_local_indices(self):
        """Each rank receives the expected contiguous range of indices."""
        # (dataset size, world size, expected index range for each rank)
        cases = [
            (0, 5, [range(0)] * 5),
            (16, 2, [range(8), range(8, 16)]),
            (2, 3, [range(1), range(1, 2), range(0)]),
            (42, 4, [range(11), range(11, 22), range(22, 32), range(32, 42)]),
        ]

        for size, world_size, expected_result in cases:
            with self.subTest(f"size={size}, world_size={world_size}"):
                per_rank = []
                for rank in range(world_size):
                    per_rank.append(
                        InferenceSampler._get_local_indices(size, world_size, rank)
                    )
                self.assertEqual(per_rank, expected_result)
| |
|