| | import unittest |
| | from transcribe.strategy import TranscriptChunk, TranscriptToken, SplitMode |
| |
|
class TestTranscriptChunk(unittest.TestCase):
    """Exercise TranscriptChunk against a fixed four-token fixture.

    The fixture holds the tokens "Hello", ",", "world" and "." joined
    with a single-space separator.
    """

    def setUp(self):
        """Build the token list and the chunk shared by every test."""
        # (text, t0, t1) triples; each token starts where the previous ends.
        spans = (
            ("Hello", 0, 100),
            (",", 100, 200),
            ("world", 200, 300),
            (".", 300, 400),
        )
        self.tokens = [
            TranscriptToken(text=text, t0=start, t1=end)
            for text, start, end in spans
        ]
        self.chunk = TranscriptChunk(items=self.tokens, separator=" ")

    def test_split_by_punctuation(self):
        """Splitting on punctuation yields three chunks, the last one empty."""
        pieces = self.chunk.split_by(SplitMode.PUNCTUATION)
        self.assertEqual(len(pieces), 3)
        expected_texts = ("Hello ,", "world .", "")
        for piece, expected in zip(pieces, expected_texts):
            self.assertEqual(piece.join(), expected)

    def test_get_split_first_rest(self):
        """The first split is returned separately from the remaining two."""
        head, tail = self.chunk.get_split_first_rest(SplitMode.PUNCTUATION)
        self.assertEqual(head.join(), "Hello ,")
        self.assertEqual(len(tail), 2)
        for piece, expected in zip(tail, ("world .", "")):
            self.assertEqual(piece.join(), expected)

    def test_punctuation_numbers(self):
        """The fixture contains two punctuation tokens."""
        # "puncation_numbers" is the spelling used by the API under test.
        count = self.chunk.puncation_numbers()
        self.assertEqual(count, 2)

    def test_length(self):
        """The chunk reports one unit of length per fixture token."""
        size = self.chunk.length()
        self.assertEqual(size, 4)

    def test_join(self):
        """Joining places the separator between every pair of tokens."""
        joined = self.chunk.join()
        self.assertEqual(joined, "Hello , world .")

    def test_compare(self):
        """A partially matching chunk scores strictly between 0 and 1."""
        partial_tokens = [
            TranscriptToken(text="Hello", t0=0, t1=100),
            TranscriptToken(text="!", t0=100, t1=200),
        ]
        partial = TranscriptChunk(items=partial_tokens, separator=" ")
        score = self.chunk.compare(partial)
        self.assertTrue(0 < score < 1)

    def test_has_punctuation(self):
        """The fixture chunk is reported as containing punctuation."""
        found = self.chunk.has_punctuation()
        self.assertTrue(found)

    def test_get_buffer_index(self):
        """The buffer index for the fixture chunk is 64000."""
        # NOTE(review): 64000 is presumably derived from the last token's
        # t1 (400) and an audio sample rate -- confirm against the
        # implementation of get_buffer_index.
        index = self.chunk.get_buffer_index()
        self.assertEqual(index, 64000)

    def test_is_end_sentence(self):
        """A chunk whose final token is "." counts as a sentence end."""
        ended = self.chunk.is_end_sentence()
        self.assertTrue(ended)
| |
|
# Run the test suite only when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()
| |
|