david
commited on
Commit
·
aca8f40
1
Parent(s):
cd7fb92
fix seg id error
Browse files- tests/test_transcript_analysis.py +68 -0
- transcribe/strategy.py +2 -2
tests/test_transcript_analysis.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import unittest
|
| 2 |
+
from unittest.mock import MagicMock, patch
|
| 3 |
+
from transcribe.strategy import TranscriptStabilityAnalyzer, TranscriptChunk, TranscriptResult, SplitMode
|
| 4 |
+
|
| 5 |
+
class TestTranscriptStabilityAnalyzer(unittest.TestCase):
|
| 6 |
+
def setUp(self):
|
| 7 |
+
self.analyzer = TranscriptStabilityAnalyzer()
|
| 8 |
+
|
| 9 |
+
def test_first_chunk_yields_pending_text(self):
|
| 10 |
+
mock_chunk = MagicMock(spec=TranscriptChunk)
|
| 11 |
+
mock_chunk.join.return_value = "Hello world."
|
| 12 |
+
|
| 13 |
+
with patch.object(self.analyzer._transcript_history, 'previous_chunk', return_value=None):
|
| 14 |
+
results = list(self.analyzer.analysis(" ", mock_chunk, buffer_duration=5.0))
|
| 15 |
+
|
| 16 |
+
self.assertEqual(len(results), 1)
|
| 17 |
+
self.assertIsInstance(results[0], TranscriptResult)
|
| 18 |
+
self.assertIn("Hello", results[0].context)
|
| 19 |
+
|
| 20 |
+
def test_short_buffer_with_high_similarity_and_end_sentence(self):
|
| 21 |
+
curr_chunk = MagicMock(spec=TranscriptChunk)
|
| 22 |
+
curr_first = MagicMock()
|
| 23 |
+
curr_rest = [MagicMock()]
|
| 24 |
+
prev_chunk = MagicMock(spec=TranscriptChunk)
|
| 25 |
+
prev_first = MagicMock()
|
| 26 |
+
|
| 27 |
+
# Mock the items attribute
|
| 28 |
+
curr_chunk.items = [curr_first, curr_rest[0]] # Ensure it is iterable
|
| 29 |
+
curr_chunk.get_split_first_rest.return_value = (curr_first, curr_rest)
|
| 30 |
+
prev_chunk.get_split_first_rest.return_value = (prev_first, [])
|
| 31 |
+
curr_first.compare.return_value = 0.85
|
| 32 |
+
curr_first.is_end_sentence.return_value = True
|
| 33 |
+
curr_first.has_punctuation.return_value = True
|
| 34 |
+
curr_first.join.return_value = "This is a test sentence."
|
| 35 |
+
curr_first.get_buffer_index.return_value = 0
|
| 36 |
+
curr_rest[0].join.return_value = " Continuing..."
|
| 37 |
+
|
| 38 |
+
with patch.object(self.analyzer._transcript_history, 'previous_chunk', return_value=prev_chunk):
|
| 39 |
+
with patch.object(self.analyzer._transcript_history, 'add'):
|
| 40 |
+
results = list(self.analyzer.analysis(" ", curr_chunk, buffer_duration=5.0))
|
| 41 |
+
|
| 42 |
+
self.assertGreaterEqual(len(results), 1)
|
| 43 |
+
self.assertTrue(any(r.is_end_sentence for r in results))
|
| 44 |
+
self.assertTrue(any("test" in r.context for r in results))
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def test_long_buffer_triggers_commit(self):
|
| 48 |
+
chunk1 = MagicMock()
|
| 49 |
+
chunk2 = MagicMock()
|
| 50 |
+
chunk3 = MagicMock()
|
| 51 |
+
|
| 52 |
+
chunk1.join.return_value = "Hello."
|
| 53 |
+
chunk2.join.return_value = "How are"
|
| 54 |
+
chunk3.join.return_value = " you?"
|
| 55 |
+
|
| 56 |
+
mock_chunk = MagicMock(spec=TranscriptChunk)
|
| 57 |
+
mock_chunk.split_by.return_value = [chunk1, chunk2, chunk3]
|
| 58 |
+
mock_chunk.get_buffer_index.return_value = 0
|
| 59 |
+
|
| 60 |
+
with patch.object(self.analyzer._transcript_history, 'previous_chunk', return_value=MagicMock()):
|
| 61 |
+
with patch.object(self.analyzer._transcript_history, 'add'):
|
| 62 |
+
results = list(self.analyzer.analysis(" ", mock_chunk, buffer_duration=15.0))
|
| 63 |
+
|
| 64 |
+
self.assertTrue(any(r.is_end_sentence for r in results))
|
| 65 |
+
self.assertTrue(any("Hello" in r.context for r in results))
|
| 66 |
+
|
| 67 |
+
if __name__ == '__main__':
|
| 68 |
+
unittest.main()
|
transcribe/strategy.py
CHANGED
|
@@ -281,7 +281,7 @@ class TranscriptStabilityAnalyzer:
|
|
| 281 |
if curr_seg_id > prev_seg_id:
|
| 282 |
# 表示生成了一个新段落 换行
|
| 283 |
yield TranscriptResult(
|
| 284 |
-
seg_id=
|
| 285 |
cut_index=frame_cut_index,
|
| 286 |
context=self._transcript_buffer.latest_paragraph,
|
| 287 |
is_end_sentence=True
|
|
@@ -290,7 +290,7 @@ class TranscriptStabilityAnalyzer:
|
|
| 290 |
# 如果还有挂起的文本
|
| 291 |
if (current_not_commit_text := self._transcript_buffer.current_not_commit_text.strip()):
|
| 292 |
yield TranscriptResult(
|
| 293 |
-
seg_id=curr_seg_id
|
| 294 |
cut_index=frame_cut_index,
|
| 295 |
context=current_not_commit_text
|
| 296 |
)
|
|
|
|
| 281 |
if curr_seg_id > prev_seg_id:
|
| 282 |
# 表示生成了一个新段落 换行
|
| 283 |
yield TranscriptResult(
|
| 284 |
+
seg_id=curr_seg_id-1,
|
| 285 |
cut_index=frame_cut_index,
|
| 286 |
context=self._transcript_buffer.latest_paragraph,
|
| 287 |
is_end_sentence=True
|
|
|
|
| 290 |
# 如果还有挂起的文本
|
| 291 |
if (current_not_commit_text := self._transcript_buffer.current_not_commit_text.strip()):
|
| 292 |
yield TranscriptResult(
|
| 293 |
+
seg_id=curr_seg_id,
|
| 294 |
cut_index=frame_cut_index,
|
| 295 |
context=current_not_commit_text
|
| 296 |
)
|