# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import unittest
from collections import OrderedDict

import numpy as np
import torch
from fairseq.data import LanguagePairDataset, TokenBlockDataset
from fairseq.data.multi_corpus_sampled_dataset import MultiCorpusSampledDataset
from tests.test_train import mock_dict
|
|
|
|
class TestMultiCorpusSampledDataset(unittest.TestCase):
    """Tests for MultiCorpusSampledDataset's default (uniform) and
    user-supplied (weighted) sampling strategies.

    Two single-example LanguagePairDatasets are built whose only source
    tokens are 1 and 2 respectively, so the origin of each drawn sample
    can be identified from the first src token in the collated batch.
    """

    def _make_lang_pair_dataset(self, token, dictionary):
        """Return a one-example LanguagePairDataset whose single source
        token is ``token`` (TokenBlockDataset is configured with pad=0,
        eos=1, block_size=1 and no targets)."""
        tokens = torch.LongTensor([token]).view(1, -1)
        tokens_ds = TokenBlockDataset(
            tokens,
            sizes=[tokens.size(-1)],
            block_size=1,
            pad=0,
            eos=1,
            include_targets=False,
        )
        # shuffle=False keeps ordering deterministic for the sampling tests.
        return LanguagePairDataset(
            tokens_ds, tokens_ds.sizes, dictionary, shuffle=False
        )

    def setUp(self):
        d = mock_dict()
        # Datasets are distinguishable by their single source token: 1 vs 2.
        self.dataset_1 = self._make_lang_pair_dataset(1, d)
        self.dataset_2 = self._make_lang_pair_dataset(2, d)

    def _test_sample_helper(
        self,
        expected_sample_from_first_ds_percentage,
        num_samples=1000,
        sampling_func=None,
    ):
        """Draw ``num_samples`` batches from a MultiCorpusSampledDataset and
        assert the observed fraction drawn from ``dataset_1`` is within 0.01
        of ``expected_sample_from_first_ds_percentage``.

        sampling_func: optional sampler forwarded to
            MultiCorpusSampledDataset; None exercises its default
            (uniform) strategy.
        """
        np.random.seed(0)  # fixed seed so the observed ratio is reproducible
        if sampling_func is None:
            m = MultiCorpusSampledDataset(
                OrderedDict({0: self.dataset_1, 1: self.dataset_2}),
            )
        else:
            m = MultiCorpusSampledDataset(
                OrderedDict({0: self.dataset_1, 1: self.dataset_2}),
                sampling_func=sampling_func,
            )
        m.ordered_indices()
        count_sample_from_first_dataset = 0
        for _ in range(num_samples):
            # dataset_1's only source token is 1, so a leading 1 in
            # src_tokens identifies a draw from the first dataset.
            if m.collater([m[0], m[1]])["net_input"]["src_tokens"][0] == 1:
                count_sample_from_first_dataset += 1
        sample_from_first_ds_percentage = (
            1.0 * count_sample_from_first_dataset / num_samples
        )
        self.assertLess(
            abs(
                sample_from_first_ds_percentage
                - expected_sample_from_first_ds_percentage
            ),
            0.01,
        )

    def test_multi_corpus_sampled_dataset_uniform_sample(self):
        """Default sampling should draw from each dataset ~50% of the time."""
        self._test_sample_helper(expected_sample_from_first_ds_percentage=0.5)

    def test_multi_corpus_sampled_dataset_weighted_sample(self):
        """A 0.9/0.1 weighted sampler should draw dataset_1 ~90% of the time."""

        def naive_weighted_sample(weights):
            # Build a sampler that picks index i with probability weights[i]
            # via inverse-CDF sampling; it ignores its argument (presumably
            # the candidate keys — confirm against MultiCorpusSampledDataset).
            def f(_input):  # renamed from `input` to avoid shadowing the builtin
                v = np.random.random()
                agg = 0
                for i, weight in enumerate(weights):
                    agg += weight
                    if agg > v:
                        return i

            return f

        self._test_sample_helper(
            expected_sample_from_first_ds_percentage=0.9,
            sampling_func=naive_weighted_sample(weights=[0.9, 0.1]),
        )
|
|