| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import os |
|
|
| import pytest |
| import soundfile as sf |
| import torch |
|
|
| from nemo.collections.asr.parts.utils.manifest_utils import read_manifest |
|
|
|
|
@pytest.fixture()
def set_device():
    """Pick the compute device for the test session: first CUDA GPU if present, else CPU."""
    device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
    return torch.device(device_name)
|
|
|
|
@pytest.fixture()
def language_specific_text_example():
    """Map each supported language ID to one sample sentence in that language."""
    examples = dict(
        en="Caslon's type is clear and neat, and fairly well designed;",
        de="Ich trinke gerne Kräutertee mit Lavendel.",
        es="Los corazones de pollo son una delicia.",
        zh="双辽境内除东辽河、西辽河等5条河流",
    )
    return examples
|
|
|
|
@pytest.fixture()
def supported_languages(language_specific_text_example):
    """Sorted list of the language IDs that have a text example."""
    # Iterating a dict yields its keys, so no explicit .keys() call is needed.
    return sorted(language_specific_text_example)
|
|
|
|
@pytest.fixture()
def get_language_id_from_pretrained_model_name(supported_languages):
    """Return a validator that extracts the language ID from a pretrained model name.

    The returned callable expects names following the `tts_languageID_model_*`
    convention and fails the test (via `pytest.fail`) when the name does not
    split into at least two `_`-separated parts or when the extracted language
    ID is not in `supported_languages`.
    """

    def _validate(pretrained_model_name):
        # Fix: the original indexed split("_")[1] unconditionally, so a name
        # with no underscore raised IndexError instead of reaching the
        # intended pytest.fail diagnostic about the naming convention.
        parts = pretrained_model_name.split("_")
        if len(parts) < 2 or parts[1] not in supported_languages:
            pytest.fail(
                f"`PretrainedModelInfo.pretrained_model_name={pretrained_model_name}` does not follow the naming "
                f"convention as `tts_languageID_model_*`, or `languageID` is not supported in {supported_languages}."
            )
        return parts[1]

    return _validate
|
|
|
|
@pytest.fixture()
def mel_spec_example(set_device):
    """Random log-mel spectrogram tensor of shape (1, 80, 330) with values in [-11.0, 0.5]."""
    lower_bound = -11.0
    upper_bound = 0.5
    shape = (1, 80, 330)  # (batch_size, n_mel_channels, n_frames)
    # torch.rand is uniform on [0, 1); scale and shift it onto [lower_bound, upper_bound].
    uniform = torch.rand(*shape, device=set_device)
    return uniform * (lower_bound - upper_bound) + upper_bound
|
|
|
|
@pytest.fixture()
def audio_text_pair_example_english(test_data_dir, set_device):
    """Load the last mini-LJSpeech manifest entry as (audio, audio_len, raw text).

    `audio` is a float tensor of shape (1, num_samples) on `set_device`;
    `audio_len` is a long tensor of shape (1,) holding the sample count.
    """
    manifest_path = os.path.join(test_data_dir, 'tts/mini_ljspeech/manifest.json')
    last_entry = read_manifest(manifest_path)[-1]

    samples, _ = sf.read(last_entry["audio_filepath"])
    audio = torch.tensor(samples, dtype=torch.float, device=set_device).unsqueeze(0)
    audio_len = torch.tensor(samples.shape[0], device=set_device).long().unsqueeze(0)

    return audio, audio_len, last_entry["text"]
|
|