import tempfile
import unittest
from pathlib import Path
from typing import Any, Dict, Optional, Sequence

import fairseq.data.indexed_dataset as indexed_dataset
import fairseq.options
import fairseq.tasks.online_backtranslation as obt
import torch
from tests import utils


def mk_sample(tokens: Sequence[int], batch_size: int = 2) -> Dict[str, Any]:
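    """Build a toy fairseq batch: `batch_size` identical rows repeating `tokens`."""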
    batch = torch.stack([torch.tensor(tokens, dtype=torch.long)] * batch_size)
    sample = {
        "net_input": {
            "src_tokens": batch,
            "prev_output_tokens": batch,
            "src_lengths": torch.tensor([len(tokens)] * batch_size, dtype=torch.long),
        },
        "target": batch[:, 1:],
    }
    return sample


def mk_dataset(num_samples: int, max_len: int, output: Path):
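    """Write `num_samples` random sequences to `output` as an indexed dataset (.bin/.idx pair)."""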
    output.parent.mkdir(parents=True, exist_ok=True)
    idx = indexed_dataset.IndexedDatasetBuilder(str(output))
    data = torch.randint(5, 100, (num_samples, max_len))
    lengths = torch.randint(3, max_len, (num_samples,))
    for d, l in zip(data, lengths):
        # Pin the first token so every sample starts with the same symbol.
        d[0] = 0
        idx.add_item(d[:l])
    idx.finalize(output.with_suffix(".idx"))
    assert output.exists()
    assert output.with_suffix(".idx").exists()


class OnlineBacktranslationTest(unittest.TestCase):

    tmp_dir = Path(tempfile.mkdtemp(suffix="OnlineBacktranslationTest"))

    @classmethod
    def obt_task(
        cls,
        languages: Sequence[str],
        data: Optional[Path] = None,
        language_mapping: Optional[str] = None,
    ):
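        """Set up a small online-backtranslation task and model over `languages`."""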
        dict_path = cls.tmp_dir / "dict.txt"
        if not dict_path.exists():
            dictionary = utils.dummy_dictionary(100)
            dictionary.save(str(dict_path))

        if data is not None:
            (data / "dict.txt").write_text(dict_path.read_text())
        else:
            data = cls.tmp_dir
        assert len(languages) >= 2

        kwargs = {
            "arch": "transformer",
            # One sentence per batch keeps batching deterministic.
            "max_sentences": 1,
            # Tiny model dimensions so the tests run fast.
            "encoder_layers": 3,
            "encoder_embed_dim": 12,
            "encoder_ffn_embed_dim": 14,
            "encoder_attention_heads": 4,
            "decoder_layers": 3,
            "decoder_embed_dim": 12,
            "decoder_output_dim": 12,
            "decoder_ffn_embed_dim": 14,
            "decoder_attention_heads": 4,
            # Disable dropout everywhere so runs are comparable.
            "dropout": 0,
            "attention_dropout": 0,
            "activation_dropout": 0,
            "encoder_layerdrop": 0,
        }

        args = fairseq.options.get_args(
            data,
            task="online_backtranslation",
            mono_langs=",".join(languages),
            valid_lang_pairs=f"{languages[0]}-{languages[1]}",
            tokens_per_sample=256,
            language_mapping=language_mapping,
            **kwargs,
        )
        task = obt.OnlineBackTranslationTask.setup_task(args)
        # Build the model too: the tests exercise both task and model.
        model = task.build_model(task.args)
        return task, model

    def tmp_path(self, test_case: str) -> Path:
        return Path(tempfile.mkdtemp(test_case, dir=self.tmp_dir))

    def test_lang_tokens(self):
        task, model = self.obt_task(["en", "ro", "zh"])
        assert obt._lang_token("en") in task.dictionary
        assert obt._lang_token("ro") in task.dictionary
        assert obt._lang_token("zh") in task.dictionary

        en_bos = obt._lang_token_index(task.common_dict, "en")
        assert "en" == task.common_dict[en_bos].strip("_")
        zh_bos = obt._lang_token_index(task.common_dict, "zh")
        assert "zh" == task.common_dict[zh_bos].strip("_")
        zh_sample = mk_sample([zh_bos, 16, 14, 12, 10])

        # A zh sample is backtranslated into en, so its BOS must be the en token.
        assert task.get_bos_token_from_sample(zh_sample) == en_bos

    def test_backtranslate_sample(self):
        task, model = self.obt_task(["en", "ro", "zh"])

        en_bos = obt._lang_token_index(task.common_dict, "en")
        zh_bos = obt._lang_token_index(task.common_dict, "zh")
        sample = mk_sample([zh_bos, 16, 14, 12, 10])

        task.backtranslate_sample(sample, "zh", "en")
        # The original zh tokens (minus the language BOS) become the target...
        target_zh = list(sample["target"][0])
        assert target_zh == [16, 14, 12, 10]
        # ...and the generated source must start with the en language token.
        generated_en = sample["net_input"]["src_tokens"][0]
        assert generated_en[0] == en_bos

    def test_train_dataset(self):
        data = self.tmp_path("test_train_dataset")
        mk_dataset(20, 10, data / "en" / "train.bin")
        mk_dataset(10, 10, data / "zh" / "train.bin")
        task, model = self.obt_task(["en", "zh"], data)
        task.load_dataset("train")

        en_bos = obt._lang_token_index(task.common_dict, "en")
        zh_bos = obt._lang_token_index(task.common_dict, "zh")

        train = task.datasets["train"]
        train.ordered_indices()
        train.prefetch([0, 19])
        sample_0 = train[0]
        sample_19 = train[19]
        self.assertEqual(
            set(sample_0.keys()), {"en-BT", "en-DENOISE", "zh-BT", "zh-DENOISE"}
        )
        for sample in (sample_0, sample_19):
            self.assertEqual(sample["en-BT"]["source"][0], en_bos)
            # The BT target is only produced at training time, so check the source only.
            self.assertEqual(sample["en-DENOISE"]["source"][0], en_bos)

        for i in range(10):
            # The zh split is half the size of en, so it is repeated to match.
            train.prefetch([i, i + 10])
            self.assertEqual(
                list(train[i]["zh-DENOISE"]["source"]),
                list(train[i + 10]["zh-DENOISE"]["source"]),
            )
            self.assertEqual(train[i]["zh-DENOISE"]["source"][0].item(), zh_bos)

        # Samples are ordered by increasing length.
        self.assertLess(
            len(sample_0["en-BT"]["source"]), len(sample_19["en-BT"]["source"])
        )

    def test_valid_dataset(self):
        data = self.tmp_path("test_valid_dataset")
        mk_dataset(10, 21, data / "valid.en-zh.en.bin")
        mk_dataset(10, 21, data / "valid.en-zh.zh.bin")

        task, model = self.obt_task(["en", "zh"], data)
        valid = task.load_dataset("valid")
        en_bos = obt._lang_token_index(task.common_dict, "en")

        assert valid is not None
        valid.prefetch(range(10))
        sample_0 = valid[0]
        sample_9 = valid[9]
        self.assertEqual(sample_0["id"], 0)
        self.assertEqual(sample_9["id"], 9)
        self.assertEqual(sample_0["source"][0], en_bos)
        self.assertEqual(sample_9["source"][0], en_bos)

    def assertFnMatch(self, fn, values):
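        """Check `fn` against a dict mapping inputs to expected outputs."""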
        for x, y in values.items():
            fn_x = fn(x)
            self.assertEqual(fn_x, y, f"Fn has wrong value: fn({x}) = {fn_x} != {y}")

    def test_piecewise_linear_fn(self):
        self.assertFnMatch(
            obt.PiecewiseLinearFn.from_string("1.0"), {0: 1, 100: 1, 500: 1, 1000: 1}
        )
        self.assertFnMatch(
            obt.PiecewiseLinearFn.from_string("0:1,1000:0"),
            {0: 1, 500: 0.5, 1000: 0, 2000: 0},
        )
        self.assertFnMatch(
            obt.PiecewiseLinearFn.from_string("0:0,1000:1"),
            {0: 0, 500: 0.5, 1000: 1, 2000: 1},
        )
        self.assertFnMatch(
            obt.PiecewiseLinearFn.from_string("0:0,1000:1,2000:0"),
            {0: 0, 500: 0.5, 1000: 1, 1500: 0.5, 2000: 0, 3000: 0},
        )