| | |
| | |
| | |
| | |
| |
|
| | import unittest |
| |
|
| | import torch |
| | from fairseq.data import MonolingualDataset |
| | from fairseq.tasks.language_modeling import LanguageModelingTask, LanguageModelingConfig |
| | from tests import utils as test_utils |
| |
|
| |
|
class TestLMContextWindow(unittest.TestCase):
    """Checks the batches yielded by ``LanguageModelingTask.eval_lm_dataloader``
    when a non-zero ``context_window`` is requested."""

    def test_eval_dataloader(self):
        """Iterate three samples with ``context_window=2`` and verify each batch.

        The dataset holds the samples ``[4, 5, 6, 7]``, ``[8, 9, 10, 11]`` and
        ``[12, 13]``; the loader is built with ``tokens_per_sample=4`` and
        ``batch_size=1``. The expectations pinned here are:

        * batch 0: the first sample followed by two pad tokens, in both
          ``src_tokens`` and ``target``;
        * batches 1 and 2: ``src_tokens`` begins with the last two tokens of
          the previous sample, and ``target`` holds pad in those two leading
          positions followed by the sample's own tokens.
        """
        dictionary = test_utils.dummy_dictionary(10)
        # The expected token lists below depend on these two facts, so fail
        # early with a clear message if the dummy dictionary ever changes.
        # (14 entries for 10 requested symbols: presumably 4 special symbols
        # are added by the helper -- confirm in tests.utils.)
        self.assertEqual(len(dictionary), 14)
        self.assertEqual(dictionary.pad(), 1)
        pad = dictionary.pad()

        samples = test_utils.TestDataset(
            [
                torch.tensor([4, 5, 6, 7], dtype=torch.long),
                torch.tensor([8, 9, 10, 11], dtype=torch.long),
                torch.tensor([12, 13], dtype=torch.long),
            ]
        )
        dataset = MonolingualDataset(samples, sizes=[4, 4, 2], src_vocab=dictionary)

        config = LanguageModelingConfig(tokens_per_sample=4)
        task = LanguageModelingTask(config, dictionary)

        eval_dataloader = task.eval_lm_dataloader(
            dataset=dataset,
            batch_size=1,
            context_window=2,
        )

        # One (src_tokens, target) pair per batch, in iteration order.
        expected_batches = [
            ([4, 5, 6, 7, pad, pad], [4, 5, 6, 7, pad, pad]),
            ([6, 7, 8, 9, 10, 11], [pad, pad, 8, 9, 10, 11]),
            ([10, 11, 12, 13], [pad, pad, 12, 13]),
        ]
        for index, (expected_src, expected_target) in enumerate(expected_batches):
            batch = next(eval_dataloader)
            # subTest labels a failure with the batch index and lets the
            # remaining batches still be checked.
            with self.subTest(batch=index):
                self.assertEqual(
                    batch["net_input"]["src_tokens"][0].tolist(), expected_src
                )
                self.assertEqual(batch["target"][0].tolist(), expected_target)
| |
|
| |
|
# Script entry point: run this module's tests with the unittest runner when
# the file is executed directly rather than imported.
if __name__ == "__main__":
    unittest.main()
| |
|