Spaces:
Runtime error
Runtime error
| from pathlib import Path | |
| import unittest | |
| from torch import Tensor | |
| import torchaudio | |
| from voicefixer import Vocoder | |
| from training.datasets.hifi_libri_dataset import HifiLibriDataset, HifiLibriItem | |
class TestHifiLibriDataset(unittest.TestCase):
    """Integration tests for ``HifiLibriDataset``.

    Every test works against a real ``HifiLibriDataset`` rooted at
    ``datasets_cache``. The underlying HiFi-TTS / LibriTTS data must therefore
    be reachable (presumably downloaded or prepared on first use — TODO
    confirm).
    """

    def setUp(self):
        """Create a caching dataset rooted at ``self.cache_dir``.

        No vocoder is loaded here: none of the tests needs one, and loading a
        model before every test would only add avoidable start-up cost.
        """
        self.cache_dir = "datasets_cache"
        # cache=True is expected to make item access write a per-item cache
        # file; test_getitem asserts those files exist afterwards.
        self.dataset = HifiLibriDataset(cache_dir=self.cache_dir, cache=True)

    def _check_item(self, idx, expected_type, *, check_cache=False):
        """Fetch ``self.dataset[idx]`` and check its class and source dataset.

        Args:
            idx: Index into ``self.dataset``.
            expected_type: Expected value of the item's ``dataset_type``.
            check_cache: If true, also assert that the cache file for ``idx``
                exists after the access.
        """
        item = self.dataset[idx]
        self.assertIsInstance(item, HifiLibriItem)
        self.assertEqual(item.dataset_type, expected_type)
        if check_cache:
            self.assertTrue(self.dataset.get_cache_file_path(idx).exists())

    def test_init(self):
        """The default dataset's cutset has the expected number of cuts."""
        self.assertEqual(len(self.dataset.cutset), 129751)

    def test_get_cache_subdir_path(self):
        """Index 1234 maps to the "2000" subdirectory of the cache dir.

        Presumably indices are bucketed in groups of 1000 — confirm against
        ``get_cache_subdir_path``.
        """
        idx = 1234
        expected_path = Path(self.cache_dir) / "cache-hifitts-librittsr" / "2000"
        self.assertEqual(self.dataset.get_cache_subdir_path(idx), expected_path)

    def test_get_cache_file_path(self):
        """The cache file for an index is ``<idx>.pt`` inside its subdirectory."""
        idx = 1234
        expected_path = (
            Path(self.cache_dir) / "cache-hifitts-librittsr" / "2000" / f"{idx}.pt"
        )
        self.assertEqual(self.dataset.get_cache_file_path(idx), expected_path)

    def test_getitem(self):
        """HiFi-TTS items lead the dataset, LibriTTS items trail it, and item
        access writes cache files.

        NOTE: a mel -> waveform round trip (e.g. through voicefixer's
        ``Vocoder``) is not checked here, because that vocoder expects the mel
        spectrogram to be prepared in a specific way.
        """
        n = len(self.dataset)
        # (index, expected dataset_type, whether to assert the cache file).
        # The order matches the original sequence of accesses. Index 0 is read
        # twice so that the second read exercises the path after its cache
        # file has been written.
        cases = [
            (0, "hifitts", True),
            (0, "hifitts", False),
            (10, "hifitts", True),
            (20, "hifitts", False),
            (n - 20, "libritts", True),
            (n - 10, "libritts", False),
            (n - 5, "libritts", False),
        ]
        for idx, expected_type, check_cache in cases:
            with self.subTest(idx=idx, expected_type=expected_type):
                self._check_item(idx, expected_type, check_cache=check_cache)

    def test_collate_fn(self):
        """collate_fn returns a list of per-field batch values of the expected types."""
        batch = [self.dataset[0] for _ in range(10)]
        collated = self.dataset.collate_fn(batch)
        self.assertIsInstance(collated, list)
        # Position in the collated list -> (field name, expected type).
        expected_fields = [
            ("ids", list),
            ("raw_texts", list),
            ("speakers", Tensor),
            ("texts", Tensor),
            ("src_lens", Tensor),
            ("mels", Tensor),
            ("pitches", Tensor),
            ("pitches_stat", list),
            ("mel_lens", Tensor),
            ("langs", Tensor),
            ("attn_priors", Tensor),
            ("wavs", Tensor),
            ("energy", Tensor),
        ]
        # Fail with a clear message instead of an IndexError if fields are missing.
        self.assertGreaterEqual(len(collated), len(expected_fields))
        for pos, (field, expected_type) in enumerate(expected_fields):
            with self.subTest(field=field):
                self.assertIsInstance(collated[pos], expected_type)

    def test_include_libri(self):
        """include_libri controls whether LibriTTS items are part of the dataset."""
        dataset_with_libri = HifiLibriDataset(
            cache_dir=self.cache_dir,
            include_libri=True,
        )
        dataset_without_libri = HifiLibriDataset(
            cache_dir=self.cache_dir,
            include_libri=False,
        )
        # Including LibriTTS must make the dataset strictly larger.
        self.assertGreater(len(dataset_with_libri), len(dataset_without_libri))
        # With LibriTTS, the tail of the dataset consists of 'libritts' items.
        libri_item = dataset_with_libri[len(dataset_with_libri) - 10]
        self.assertIsInstance(libri_item, HifiLibriItem)
        self.assertEqual(libri_item.dataset_type, "libritts")
        # Without LibriTTS, even the tail of the dataset is 'hifitts'.
        hifi_item = dataset_without_libri[len(dataset_without_libri) - 10]
        self.assertIsInstance(hifi_item, HifiLibriItem)
        self.assertEqual(hifi_item.dataset_type, "hifitts")

    def test_dur_filter(self):
        """dur_filter rejects too-short and too-long durations (presumably seconds)."""
        cases = [(0.2, False), (1.0, True), (2.0, True), (30.0, False)]
        for duration, expected in cases:
            with self.subTest(duration=duration):
                # bool() keeps the original truthiness semantics of assertTrue/assertFalse.
                self.assertEqual(bool(self.dataset.dur_filter(duration)), expected)
if __name__ == "__main__":
    # Executed as a script: run every TestCase defined in this module.
    # module="__main__" is unittest.main's default, spelled out for clarity.
    unittest.main(module="__main__")