| |
| |
| |
| |
|
|
| import unittest |
| from collections import namedtuple |
| from pathlib import Path |
|
|
| import torch |
| from tqdm import tqdm |
|
|
| import fairseq |
| from fairseq import utils |
| from fairseq.checkpoint_utils import load_model_ensemble_and_task |
| from fairseq.scoring.bleu import SacrebleuScorer |
| from fairseq.tasks import import_tasks |
| from tests.speech import S3_BASE_URL, TestFairseqSpeech |
|
|
|
|
@unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
class TestLibrispeechDualInputWavTransformer(TestFairseqSpeech):
    """Checkpoint test for the LibriSpeech dual-input wav transformer.

    The whole class is skipped when CUDA is unavailable.  Fixture handling
    and the evaluation itself are delegated to helpers inherited from
    ``TestFairseqSpeech`` (``_set_up``, ``download``, ``base_test``).
    """

    def setUp(self):
        """Register the evaluation data and download the model artifacts."""
        # Evaluation manifest + audio archive, handed to the inherited helper
        # together with the data set id and the "s2t" tag.
        # NOTE(review): presumably ``_set_up`` fetches these and sets
        # ``self.root`` -- confirm in ``tests/speech``.
        evaluation_files = [
            "librispeech_flac_test-other.tsv",
            "librispeech_flac_test-other.zip",
        ]
        self._set_up("librispeech_wvtrasnformer", "s2t", evaluation_files)

        # Checkpoint, tokenizer model, dictionaries and config are published
        # under a separate URL, so each one is downloaded explicitly.
        model_url = "https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/librispeech/finetuned"
        model_artifacts = (
            "checkpoint_ave_10.pt",
            "spm.model",
            "src_dict.txt",
            "tgt_dict.txt",
            "config.yaml",
        )
        for artifact in model_artifacts:
            self.download(model_url, self.root, artifact)

    def import_user_module(self):
        """Import the ``speech_text_joint_to_text`` example as a user module."""
        # The example directory is located relative to the fairseq package:
        # two levels above ``fairseq/__init__.py``, then into ``examples``.
        package_parent = Path(fairseq.__file__).parent.parent
        example_dir = package_parent / "examples/speech_text_joint_to_text"

        # ``utils.import_user_module`` is given an object exposing a
        # ``user_dir`` attribute; a one-field namedtuple stands in for it.
        Arg = namedtuple("Arg", ["user_dir"])
        utils.import_user_module(Arg(user_dir=str(example_dir)))

    @torch.no_grad()
    def test_librispeech_dualinput_wav_transformer_checkpoint(self):
        # Runs the averaged checkpoint on LibriSpeech test-other through the
        # inherited ``base_test`` helper, with gradients disabled.
        # (Kept as comments: unittest prints a test's docstring in its
        # verbose report, which would change the reported test name.)
        self.import_user_module()

        # Values layered over the arguments stored with the checkpoint; the
        # two pretrained-module paths are overridden with empty strings.
        overrides = {
            "config_yaml": "config.yaml",
            "load_pretrained_speech_text_encoder": "",
            "load_pretrained_speech_text_decoder": "",
            "beam": 10,
            "nbest": 1,
            "lenpen": 1.0,
            "load_speech_only": True,
        }

        # NOTE(review): 4.6 is presumably the reference score ``base_test``
        # compares against -- confirm its meaning in ``TestFairseqSpeech``.
        reference_score = 4.6
        token_budget = 800000
        self.base_test(
            "checkpoint_ave_10.pt",
            reference_score,
            dataset="librispeech_flac_test-other",
            max_tokens=token_budget,
            max_positions=(token_budget, 1024),
            arg_overrides=overrides,
        )
|
|
|
|
if __name__ == "__main__":
    # Executed directly as a script: hand control to unittest's runner.
    unittest.main()
|
|