david commited on
Commit
aca8f40
·
1 Parent(s): cd7fb92

fix seg id error

Browse files
tests/test_transcript_analysis.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import unittest
2
+ from unittest.mock import MagicMock, patch
3
+ from transcribe.strategy import TranscriptStabilityAnalyzer, TranscriptChunk, TranscriptResult, SplitMode
4
+
5
+ class TestTranscriptStabilityAnalyzer(unittest.TestCase):
6
+ def setUp(self):
7
+ self.analyzer = TranscriptStabilityAnalyzer()
8
+
9
+ def test_first_chunk_yields_pending_text(self):
10
+ mock_chunk = MagicMock(spec=TranscriptChunk)
11
+ mock_chunk.join.return_value = "Hello world."
12
+
13
+ with patch.object(self.analyzer._transcript_history, 'previous_chunk', return_value=None):
14
+ results = list(self.analyzer.analysis(" ", mock_chunk, buffer_duration=5.0))
15
+
16
+ self.assertEqual(len(results), 1)
17
+ self.assertIsInstance(results[0], TranscriptResult)
18
+ self.assertIn("Hello", results[0].context)
19
+
20
+ def test_short_buffer_with_high_similarity_and_end_sentence(self):
21
+ curr_chunk = MagicMock(spec=TranscriptChunk)
22
+ curr_first = MagicMock()
23
+ curr_rest = [MagicMock()]
24
+ prev_chunk = MagicMock(spec=TranscriptChunk)
25
+ prev_first = MagicMock()
26
+
27
+ # Mock the items attribute
28
+ curr_chunk.items = [curr_first, curr_rest[0]] # Ensure it is iterable
29
+ curr_chunk.get_split_first_rest.return_value = (curr_first, curr_rest)
30
+ prev_chunk.get_split_first_rest.return_value = (prev_first, [])
31
+ curr_first.compare.return_value = 0.85
32
+ curr_first.is_end_sentence.return_value = True
33
+ curr_first.has_punctuation.return_value = True
34
+ curr_first.join.return_value = "This is a test sentence."
35
+ curr_first.get_buffer_index.return_value = 0
36
+ curr_rest[0].join.return_value = " Continuing..."
37
+
38
+ with patch.object(self.analyzer._transcript_history, 'previous_chunk', return_value=prev_chunk):
39
+ with patch.object(self.analyzer._transcript_history, 'add'):
40
+ results = list(self.analyzer.analysis(" ", curr_chunk, buffer_duration=5.0))
41
+
42
+ self.assertGreaterEqual(len(results), 1)
43
+ self.assertTrue(any(r.is_end_sentence for r in results))
44
+ self.assertTrue(any("test" in r.context for r in results))
45
+
46
+
47
+ def test_long_buffer_triggers_commit(self):
48
+ chunk1 = MagicMock()
49
+ chunk2 = MagicMock()
50
+ chunk3 = MagicMock()
51
+
52
+ chunk1.join.return_value = "Hello."
53
+ chunk2.join.return_value = "How are"
54
+ chunk3.join.return_value = " you?"
55
+
56
+ mock_chunk = MagicMock(spec=TranscriptChunk)
57
+ mock_chunk.split_by.return_value = [chunk1, chunk2, chunk3]
58
+ mock_chunk.get_buffer_index.return_value = 0
59
+
60
+ with patch.object(self.analyzer._transcript_history, 'previous_chunk', return_value=MagicMock()):
61
+ with patch.object(self.analyzer._transcript_history, 'add'):
62
+ results = list(self.analyzer.analysis(" ", mock_chunk, buffer_duration=15.0))
63
+
64
+ self.assertTrue(any(r.is_end_sentence for r in results))
65
+ self.assertTrue(any("Hello" in r.context for r in results))
66
+
67
+ if __name__ == '__main__':
68
+ unittest.main()
transcribe/strategy.py CHANGED
@@ -281,7 +281,7 @@ class TranscriptStabilityAnalyzer:
281
  if curr_seg_id > prev_seg_id:
282
  # 表示生成了一个新段落 换行
283
  yield TranscriptResult(
284
- seg_id=prev_seg_id,
285
  cut_index=frame_cut_index,
286
  context=self._transcript_buffer.latest_paragraph,
287
  is_end_sentence=True
@@ -290,7 +290,7 @@ class TranscriptStabilityAnalyzer:
290
  # 如果还有挂起的文本
291
  if (current_not_commit_text := self._transcript_buffer.current_not_commit_text.strip()):
292
  yield TranscriptResult(
293
- seg_id=curr_seg_id+1,
294
  cut_index=frame_cut_index,
295
  context=current_not_commit_text
296
  )
 
281
  if curr_seg_id > prev_seg_id:
282
  # 表示生成了一个新段落 换行
283
  yield TranscriptResult(
284
+ seg_id=curr_seg_id-1,
285
  cut_index=frame_cut_index,
286
  context=self._transcript_buffer.latest_paragraph,
287
  is_end_sentence=True
 
290
  # 如果还有挂起的文本
291
  if (current_not_commit_text := self._transcript_buffer.current_not_commit_text.strip()):
292
  yield TranscriptResult(
293
+ seg_id=curr_seg_id,
294
  cut_index=frame_cut_index,
295
  context=current_not_commit_text
296
  )