| | import os |
| | import shutil |
| | import tempfile |
| | import unittest |
| |
|
| | from fairseq import options |
| | from fairseq.dataclass.utils import convert_namespace_to_omegaconf |
| | from fairseq.data.data_utils import raise_if_valid_subsets_unintentionally_ignored |
| | from .utils import create_dummy_data, preprocess_lm_data, train_language_model |
| |
|
| |
|
def make_lm_config(
    data_dir=None,
    extra_flags=None,
    task="language_modeling",
    arch="transformer_lm_gpt2_tiny",
):
    """Build an OmegaConf training config for a tiny language-model setup.

    Args:
        data_dir: Dataset directory. When given, it is passed positionally
            after the task name and also used as ``--save-dir``. When ``None``,
            neither is passed, so ``--save-dir`` falls back to the parser default.
        extra_flags: Extra command-line flags appended after the defaults.
        task: Value for ``--task``.
        arch: Value for ``--arch``.

    Returns:
        The result of ``convert_namespace_to_omegaconf`` on the parsed args.
    """
    task_args = [task]
    save_dir_args = []
    if data_dir is not None:
        task_args.append(data_dir)
        # argv entries must be strings; never hand argparse a literal None.
        save_dir_args = ["--save-dir", data_dir]

    cli_args = [
        "--task",
        *task_args,
        "--arch",
        arch,
        "--optimizer",
        "adam",
        "--lr",
        "0.0001",
        "--max-tokens",
        "500",
        "--tokens-per-sample",
        "500",
        *save_dir_args,
        "--max-epoch",
        "1",
        *(extra_flags or []),
    ]
    train_parser = options.get_training_parser()
    train_args = options.parse_args_and_arch(train_parser, cli_args)
    return convert_namespace_to_omegaconf(train_args)
| |
|
| |
|
def write_empty_file(path):
    """Create ``path`` as a zero-byte file, truncating any existing contents.

    Asserts afterwards that the path exists on disk.
    """
    # Opening in write mode creates/truncates the file; nothing is written.
    open(path, "w").close()
    assert os.path.exists(path)
| |
|
| |
|
class TestValidSubsetsErrors(unittest.TestCase):
    """Check which combinations of on-disk valid splits and CLI flags make
    ``raise_if_valid_subsets_unintentionally_ignored`` raise."""

    def _test_case(self, paths, extra_flags):
        """Run the valid-subset check against a temporary data directory.

        An empty ``<name>.bin`` file is created for every entry in ``paths``
        plus ``train``. A config is then built for that directory with
        ``extra_flags``, and the check is run. Any exception it raises
        propagates to the caller.
        """
        with tempfile.TemporaryDirectory() as data_dir:
            # Plain loop: the files are created only for their side effect.
            for split in [*paths, "train"]:
                write_empty_file(os.path.join(data_dir, f"{split}.bin"))
            cfg = make_lm_config(data_dir, extra_flags=extra_flags)
            raise_if_valid_subsets_unintentionally_ignored(cfg)

    def test_default_raises(self):
        with self.assertRaises(ValueError):
            self._test_case(["valid", "valid1"], [])
        with self.assertRaises(ValueError):
            self._test_case(
                ["valid", "valid1", "valid2"], ["--valid-subset", "valid,valid1"]
            )

    def test_partially_specified_valid_subsets(self):
        # The ``test_`` prefix is required for unittest's default loader to
        # collect this method.
        with self.assertRaises(ValueError):
            self._test_case(
                ["valid", "valid1", "valid2"], ["--valid-subset", "valid,valid1"]
            )
        # Same layout, but the unused subset is explicitly allowed.
        self._test_case(
            ["valid", "valid1", "valid2"],
            ["--valid-subset", "valid,valid1", "--ignore-unused-valid-subsets"],
        )

    def test_legal_configs(self):
        self._test_case(["valid"], [])
        self._test_case(["valid", "valid1"], ["--ignore-unused-valid-subsets"])
        self._test_case(["valid", "valid1"], ["--combine-val"])
        self._test_case(["valid", "valid1"], ["--valid-subset", "valid,valid1"])
        self._test_case(["valid", "valid1"], ["--valid-subset", "valid1"])
        self._test_case(
            ["valid", "valid1"], ["--combine-val", "--ignore-unused-valid-subsets"]
        )
        self._test_case(["valid1"], ["--valid-subset", "valid1"])

    def test_disable_validation(self):
        self._test_case([], ["--disable-validation"])
        self._test_case(["valid", "valid1"], ["--disable-validation"])

    def test_dummy_task(self):
        # Built without a data directory.
        cfg = make_lm_config(task="dummy_lm")
        raise_if_valid_subsets_unintentionally_ignored(cfg)

    def test_masked_dummy_task(self):
        # Built without a data directory.
        cfg = make_lm_config(task="dummy_masked_lm")
        raise_if_valid_subsets_unintentionally_ignored(cfg)
| |
|
| |
|
class TestCombineValidSubsets(unittest.TestCase):
    """Compare logged validation output for ``--combine-valid-subsets`` versus
    listing the subsets separately with ``--valid-subset``."""

    def _train(self, extra_flags):
        """Run a zero-update LM training job with two identical valid splits.

        ``valid1.bin``/``valid1.idx`` are copies of the preprocessed ``valid``
        split. Returns the message of every log record captured by
        ``assertLogs`` while training.
        """
        with self.assertLogs() as logs:
            # TemporaryDirectory's first positional parameter is the name suffix.
            with tempfile.TemporaryDirectory(suffix="test_transformer_lm") as data_dir:
                create_dummy_data(data_dir, num_examples=20)
                preprocess_lm_data(data_dir)

                for ext in ("bin", "idx"):
                    shutil.copyfile(
                        os.path.join(data_dir, f"valid.{ext}"),
                        os.path.join(data_dir, f"valid1.{ext}"),
                    )
                train_language_model(
                    data_dir,
                    "transformer_lm",
                    ["--max-update", "0", "--log-format", "json"] + extra_flags,
                    run_validation=False,
                )
        return [record.message for record in logs.records]

    # self.assert* is used instead of bare `assert`, which is stripped under -O.
    def test_combined(self):
        logs = self._train(["--combine-valid-subsets"])
        # valid1 is mentioned in the logs, but gets no perplexity of its own.
        self.assertTrue(any("valid1" in msg for msg in logs))
        self.assertFalse(any("valid1_ppl" in msg for msg in logs))

    def test_subsets(self):
        logs = self._train(["--valid-subset", "valid,valid1"])
        # Each listed subset reports its own perplexity.
        self.assertTrue(any("valid_ppl" in msg for msg in logs))
        self.assertTrue(any("valid1_ppl" in msg for msg in logs))
| |
|