Spaces:
Runtime error
Runtime error
| import unittest | |
| import torch | |
| from training.preprocess.wav2vec_aligner import Wav2VecAligner | |
| class TestWav2VecAligner(unittest.TestCase): | |
| def setUp(self): | |
| self.model = Wav2VecAligner() | |
| self.text = "I HAD THAT CURIOSITY BESIDE ME AT THIS MOMENT" | |
| self.wav_path = "./mocks/audio_example.wav" | |
| def test_load_audio(self): | |
| _, sample_rate = self.model.load_audio(self.wav_path) | |
| self.assertEqual(sample_rate, 16_000) | |
| with self.assertRaises(FileNotFoundError): | |
| self.model.load_audio("./nonexistent/path.wav") | |
| def test_encode(self): | |
| tokens = self.model.encode(self.text) | |
| torch.testing.assert_close( | |
| tokens, | |
| torch.tensor( | |
| [ | |
| [ | |
| 10, | |
| 4, | |
| 11, | |
| 7, | |
| 14, | |
| 4, | |
| 6, | |
| 11, | |
| 7, | |
| 6, | |
| 4, | |
| 19, | |
| 16, | |
| 13, | |
| 10, | |
| 8, | |
| 12, | |
| 10, | |
| 6, | |
| 22, | |
| 4, | |
| 24, | |
| 5, | |
| 12, | |
| 10, | |
| 14, | |
| 5, | |
| 4, | |
| 17, | |
| 5, | |
| 4, | |
| 7, | |
| 6, | |
| 4, | |
| 6, | |
| 11, | |
| 10, | |
| 12, | |
| 4, | |
| 17, | |
| 8, | |
| 17, | |
| 5, | |
| 9, | |
| 6, | |
| ], | |
| ], | |
| ), | |
| ) | |
| def test_decode(self): | |
| transcript = self.model.decode( | |
| [ | |
| [ | |
| 10, | |
| 4, | |
| 11, | |
| 7, | |
| 14, | |
| 4, | |
| 6, | |
| 11, | |
| 7, | |
| 6, | |
| 4, | |
| 19, | |
| 16, | |
| 13, | |
| 10, | |
| 8, | |
| 12, | |
| 10, | |
| 6, | |
| 22, | |
| 4, | |
| 24, | |
| 5, | |
| 12, | |
| 10, | |
| 14, | |
| 5, | |
| 4, | |
| 17, | |
| 5, | |
| 4, | |
| 7, | |
| 6, | |
| 4, | |
| 6, | |
| 11, | |
| 10, | |
| 12, | |
| 4, | |
| 17, | |
| 8, | |
| 17, | |
| 5, | |
| 9, | |
| 6, | |
| ], | |
| ], | |
| ) | |
| self.assertEqual(transcript, self.text) | |
| def test_align_single_sample(self): | |
| audio_input, _ = self.model.load_audio(self.wav_path) | |
| emissions, tokens, transcript = self.model.align_single_sample( | |
| audio_input, self.text, | |
| ) | |
| self.assertEqual(emissions.shape, torch.Size([169, 32])) | |
| self.assertEqual( | |
| len(tokens), | |
| 47, | |
| ) | |
| self.assertEqual(transcript, "|I|HAD|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|") | |
| def test_get_trellis(self): | |
| audio_input, _ = self.model.load_audio(self.wav_path) | |
| emissions, tokens, _ = self.model.align_single_sample(audio_input, self.text) | |
| trellis = self.model.get_trellis(emissions, tokens) | |
| self.assertEqual(emissions.shape, torch.Size([169, 32])) | |
| self.assertEqual(len(tokens), 47) | |
| # Add assertions here based on the expected behavior of get_trellis | |
| self.assertIsInstance(trellis, torch.Tensor) | |
| self.assertEqual(trellis.shape, torch.Size([169, 47])) | |
| def test_backtrack(self): | |
| audio_input, _ = self.model.load_audio(self.wav_path) | |
| emissions, tokens, _ = self.model.align_single_sample(audio_input, self.text) | |
| trellis = self.model.get_trellis(emissions, tokens) | |
| path = self.model.backtrack(trellis, emissions, tokens) | |
| # Add assertions here based on the expected behavior of backtrack | |
| self.assertIsInstance(path, list) | |
| self.assertEqual(len(path), 169) | |
| def test_merge_repeats(self): | |
| audio_input, _ = self.model.load_audio(self.wav_path) | |
| emissions, tokens, transcript = self.model.align_single_sample( | |
| audio_input, self.text, | |
| ) | |
| trellis = self.model.get_trellis(emissions, tokens) | |
| path = self.model.backtrack(trellis, emissions, tokens) | |
| merged_path = self.model.merge_repeats(path, transcript) | |
| # Add assertions here based on the expected behavior of merge_repeats | |
| self.assertIsInstance(merged_path, list) | |
| self.assertEqual(len(merged_path), 47) | |
| def test_merge_words(self): | |
| audio_input, _ = self.model.load_audio(self.wav_path) | |
| emissions, tokens, transcript = self.model.align_single_sample( | |
| audio_input, self.text, | |
| ) | |
| trellis = self.model.get_trellis(emissions, tokens) | |
| path = self.model.backtrack(trellis, emissions, tokens) | |
| merged_path = self.model.merge_repeats(path, transcript) | |
| merged_words = self.model.merge_words(merged_path) | |
| # Add assertions here based on the expected behavior of merge_words | |
| self.assertIsInstance(merged_words, list) | |
| self.assertEqual(len(merged_words), 9) | |
| def test_forward(self): | |
| result = self.model(self.wav_path, self.text) | |
| # self.assertEqual(result, expected_result) | |
| self.assertEqual(len(result), 9) | |
| def test_save_segments(self): | |
| # self.model.save_segments(self.wav_path, self.text, "./mocks/wav2vec_aligner/audio") | |
| self.assertEqual(True, True) | |
| if __name__ == "__main__": | |
| unittest.main() | |