Spaces:
Runtime error
Runtime error
| import unittest | |
| import torch | |
| from torch.utils.data import DataLoader | |
| from training.datasets import LibriTTSDatasetVocoder | |
class TestLibriTTSDatasetAcoustic(unittest.TestCase):
    """Integration tests for ``LibriTTSDatasetVocoder``.

    NOTE(review): the class name says "Acoustic" but every test here exercises
    the *vocoder* dataset. The name is kept unchanged so existing test
    selectors (e.g. ``-k TestLibriTTSDatasetAcoustic``) keep working.

    The dataset is built with ``download=False`` from
    ``datasets_cache/LIBRITTS``, so the data must already be present locally
    (presumably the constructor errors otherwise -- confirm).
    """

    # Expected values for the local dataset copy under datasets_cache/LIBRITTS.
    # Update them together if the dataset root or its contents change.
    EXPECTED_NUM_SAMPLES = 33236
    EXPECTED_MEL_SHAPE = torch.Size([100, 64])
    EXPECTED_AUDIO_SHAPE = torch.Size([16384])
    EXPECTED_FIRST_SPEAKER_ID = 1034
    # Number of top-level fields returned by ``collate_fn`` for one batch.
    NUM_BATCH_FIELDS = 4

    def setUp(self):
        """Create a fresh dataset instance before every test."""
        self.batch_size = 2
        self.download = False
        self.dataset = LibriTTSDatasetVocoder(
            root="datasets_cache/LIBRITTS",
            batch_size=self.batch_size,
            download=self.download,
        )

    def test_len(self):
        """The dataset reports the expected number of samples."""
        self.assertEqual(len(self.dataset), self.EXPECTED_NUM_SAMPLES)

    def test_getitem(self):
        """The first sample has the expected mel/audio shapes and speaker id."""
        sample = self.dataset[0]
        self.assertEqual(sample["mel"].shape, self.EXPECTED_MEL_SHAPE)
        self.assertEqual(sample["audio"].shape, self.EXPECTED_AUDIO_SHAPE)
        self.assertEqual(sample["speaker_id"], self.EXPECTED_FIRST_SPEAKER_ID)

    def test_collate_fn(self):
        """``collate_fn`` returns the expected number of batch-sized fields."""
        data = [self.dataset[0], self.dataset[2]]

        result = self.dataset.collate_fn(data)

        self.assertEqual(len(result), self.NUM_BATCH_FIELDS)
        # Every collated field should contain ``batch_size`` entries.
        for field_idx, field in enumerate(result):
            with self.subTest(field=field_idx):
                self.assertEqual(len(field), self.batch_size)

    def test_dataloader(self):
        """A ``DataLoader`` using ``collate_fn`` yields well-formed batches."""
        dataloader = DataLoader(
            self.dataset,
            batch_size=self.batch_size,
            shuffle=False,  # deterministic order, so the first batches are stable
            collate_fn=self.dataset.collate_fn,
        )
        batches = iter(dataloader)

        # Only the first two batches are checked; walking the whole
        # (EXPECTED_NUM_SAMPLES-sized) dataset would make the test slow.
        # ``next`` raises if fewer than two batches are available.
        for batch_idx in range(2):
            items = next(batches)
            with self.subTest(batch=batch_idx):
                self.assertEqual(len(items), self.NUM_BATCH_FIELDS)
                for field in items:
                    self.assertEqual(len(field), self.batch_size)
if __name__ == "__main__":
    # Executing this file directly runs the test cases defined in it.
    unittest.main(module="__main__")