daihui.zhang committed on
Commit
93d2288
·
1 Parent(s): 05e062b

add export data for test

Browse files
Files changed (5) hide show
  1. .gitignore +1 -0
  2. api_model.py +15 -0
  3. config.py +2 -0
  4. transcribe/utils.py +22 -1
  5. transcribe/whisper_llm_serve.py +46 -10
.gitignore CHANGED
@@ -173,3 +173,4 @@ cython_debug/
173
 
174
  .idea/
175
  pywhispercpp/
 
 
173
 
174
  .idea/
175
  pywhispercpp/
176
+ test_data.csv
api_model.py CHANGED
@@ -15,6 +15,21 @@ class TransResult(BaseModel):
15
  populate_by_name = True
16
 
17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
 
20
  class Message(BaseModel):
 
15
  populate_by_name = True
16
 
17
 
18
class DebugResult(BaseModel):
    """Per-segment debug record queued for export to the test CSV.

    Fields with an alias are dumped under that alias when
    ``model_dump(by_alias=True)`` is used; ``populate_by_name`` additionally
    allows constructing the model with the Python field names
    (``context=...``, ``from_=...``, ``tran_content=...``).
    """

    # trans_pattern: str
    seg_id: int
    # Time cost of the transcription step, as recorded by the caller.
    transcrible_time: float
    # Time cost of the translation step, as recorded by the caller.
    translate_time: float
    # Transcribed text; exported as "transcribleContent".
    context: str = Field(alias="transcribleContent")
    # Source language; "from" is a Python keyword, hence the trailing underscore.
    from_: str = Field(alias="from")
    # Target language.
    to: str
    # Translated text; exported as "translateContent".
    tran_content: str = Field(alias="translateContent")
    partial: bool = True

    # pydantic v2 configuration. The nested ``class Config`` form is deprecated
    # in v2 and emits a deprecation warning; a plain dict is accepted here.
    model_config = {"populate_by_name": True}
31
+
32
+
33
 
34
 
35
  class Message(BaseModel):
config.py CHANGED
@@ -9,6 +9,8 @@ logging.basicConfig(
9
  datefmt="%H:%M:%S"
10
  )
11
 
 
 
12
  logging.getLogger("pywhispercpp").setLevel(logging.WARNING)
13
 
14
 
 
9
  datefmt="%H:%M:%S"
10
  )
11
 
12
+ TEST = True
13
+
14
  logging.getLogger("pywhispercpp").setLevel(logging.WARNING)
15
 
16
 
transcribe/utils.py CHANGED
@@ -5,7 +5,7 @@ import logging
5
  import numpy as np
6
  from scipy.io.wavfile import write
7
  import config
8
-
9
  import av
10
  def log_block(key: str, value, unit=''):
11
  if config.DEBUG:
@@ -94,3 +94,24 @@ def resample(file: str, sr: int = 16000):
94
 
95
  def save_to_wave(filename, data:np.ndarray, sample_rate=16000):
96
  write(filename, sample_rate, data)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  import numpy as np
6
  from scipy.io.wavfile import write
7
  import config
8
+ import csv
9
  import av
10
  def log_block(key: str, value, unit=''):
11
  if config.DEBUG:
 
94
 
95
  def save_to_wave(filename, data:np.ndarray, sample_rate=16000):
96
  write(filename, sample_rate, data)
97
+
98
+
99
class TestDataWriter:
    """Append debug records to a CSV file, one row per record.

    The header row is written when the file is missing or empty. Every
    :meth:`write` call re-opens the file in append mode and adds one row.

    Args:
        file_path: Destination CSV path.
        encoding: Text encoding used for every read/write of the file.
            Defaults to UTF-8 so that non-ASCII transcript/translation text
            is written identically on every platform instead of depending on
            the locale's default encoding.
    """

    def __init__(self, file_path='test_data.csv', encoding='utf-8'):
        self.file_path = file_path
        self.encoding = encoding
        # Column order of the CSV. The names are expected to match the keys
        # produced by ``result.model_dump(by_alias=True)`` in ``write``.
        self.fieldnames = [
            'seg_id', 'transcrible_time', 'translate_time',
            'transcribleContent', 'from', 'to', 'translateContent', 'partial'
        ]
        self._ensure_file_has_header()

    def _ensure_file_has_header(self):
        """Create the file with a header row if it is missing or empty."""
        if not os.path.exists(self.file_path) or os.path.getsize(self.file_path) == 0:
            with open(self.file_path, mode='w', newline='', encoding=self.encoding) as file:
                writer = csv.DictWriter(file, fieldnames=self.fieldnames)
                writer.writeheader()

    def write(self, result: 'DebugResult'):
        """Append ``result`` as a single CSV row.

        Args:
            result: Object exposing ``model_dump(by_alias=True)`` that returns
                a mapping keyed by ``self.fieldnames``.

        Raises:
            ValueError: Propagated from ``csv.DictWriter`` when the dumped
                mapping contains keys that are not in ``self.fieldnames``.
        """
        with open(self.file_path, mode='a', newline='', encoding=self.encoding) as file:
            writer = csv.DictWriter(file, fieldnames=self.fieldnames)
            writer.writerow(result.model_dump(by_alias=True))
transcribe/whisper_llm_serve.py CHANGED
@@ -7,13 +7,15 @@ from logging import getLogger
7
  from typing import List, Optional, Iterator, Tuple, Any
8
  import asyncio
9
  import numpy as np
 
10
  # import wordninja
11
- from api_model import TransResult, Message
12
  from .server import ServeClientBase
13
- from .utils import log_block, save_to_wave
14
  from .translatepipes import TranslatePipes
15
  from .strategy import (
16
  TranscriptStabilityAnalyzer, TranscriptToken)
 
17
 
18
  logger = getLogger("TranscriptionService")
19
 
@@ -37,6 +39,7 @@ class WhisperTranscriptionService(ServeClientBase):
37
  self.frames_np = None
38
  self.lock = threading.Lock()
39
  self._frame_queue = queue.Queue()
 
40
 
41
  # 文本分隔符，根据语言设置
42
  self.text_separator = self._get_text_separator(language)
@@ -47,10 +50,25 @@ class WhisperTranscriptionService(ServeClientBase):
47
  # ๅฏๅŠจๅค„็†็บฟ็จ‹
48
  self._translate_thread_stop = threading.Event()
49
  self._frame_processing_thread_stop = threading.Event()
 
50
  self.translate_thread = self._start_thread(self._transcription_processing_loop)
51
  self.frame_processing_thread = self._start_thread(self._frame_processing_loop)
52
 
 
 
 
 
 
 
 
 
53
  # self._c = 0
 
 
 
 
 
 
54
 
55
 
56
  def _start_thread(self, target_function) -> threading.Thread:
@@ -114,7 +132,13 @@ class WhisperTranscriptionService(ServeClientBase):
114
  """从音频缓冲区中移除已处理的部分"""
115
  with self.lock:
116
  if self.frames_np is not None and offset > 0:
 
 
117
  self.frames_np = self.frames_np[offset:]
 
 
 
 
118
 
119
  def _get_audio_for_processing(self) -> Optional[np.ndarray]:
120
  """准备用于处理的音频块"""
@@ -147,11 +171,12 @@ class WhisperTranscriptionService(ServeClientBase):
147
 
148
  result = self._translate_pipe.transcrible(audio_buffer.tobytes(), self.source_language)
149
  segments = result.segments
 
150
  logger.debug(f"๐Ÿ“ Transcrible Segments: {segments} ")
151
  logger.debug(f"๐Ÿ“ Transcrible: {self.text_separator.join(seg.text for seg in segments)} ")
152
  log_block("๐Ÿ“ Transcrible output", f"{self.text_separator.join(seg.text for seg in segments)}", "")
153
- log_block("๐Ÿ“ Transcrible time", f"{(time.perf_counter() - start_time):.3f}", "s")
154
-
155
  return [
156
  TranscriptToken(text=s.text, t0=s.t0, t1=s.t1)
157
  for s in segments
@@ -167,10 +192,10 @@ class WhisperTranscriptionService(ServeClientBase):
167
 
168
  result = self._translate_pipe.translate(text, self.source_language, self.target_language)
169
  translated_text = result.translate_content
170
-
171
- log_block("๐Ÿง Translation time ", f"{(time.perf_counter() - start_time):.3f}", "s")
172
  log_block("๐Ÿง Translation out ", f"{translated_text}")
173
-
174
  return translated_text
175
 
176
  def _translate_text_large(self, text: str) -> str:
@@ -183,10 +208,10 @@ class WhisperTranscriptionService(ServeClientBase):
183
 
184
  result = self._translate_pipe.translate_large(text, self.source_language, self.target_language)
185
  translated_text = result.translate_content
186
-
187
- log_block("Translation large model time ", f"{(time.perf_counter() - start_time):.3f}", "s")
188
  log_block("Translation large model output", f"{translated_text}")
189
-
190
  return translated_text
191
 
192
 
@@ -253,6 +278,17 @@ class WhisperTranscriptionService(ServeClientBase):
253
  )
254
  current_time = time.perf_counter()
255
  time_diff = current_time - start_time
 
 
 
 
 
 
 
 
 
 
 
256
  log_block("๐Ÿšฆ Traffic times diff", round(time_diff, 2), 's')
257
 
258
 
 
7
  from typing import List, Optional, Iterator, Tuple, Any
8
  import asyncio
9
  import numpy as np
10
+ import config
11
  # import wordninja
12
+ from api_model import TransResult, Message, DebugResult
13
  from .server import ServeClientBase
14
+ from .utils import log_block, save_to_wave, TestDataWriter
15
  from .translatepipes import TranslatePipes
16
  from .strategy import (
17
  TranscriptStabilityAnalyzer, TranscriptToken)
18
+ import csv
19
 
20
  logger = getLogger("TranscriptionService")
21
 
 
39
  self.frames_np = None
40
  self.lock = threading.Lock()
41
  self._frame_queue = queue.Queue()
42
+
43
 
44
  # 文本分隔符，根据语言设置
45
  self.text_separator = self._get_text_separator(language)
 
50
  # ๅฏๅŠจๅค„็†็บฟ็จ‹
51
  self._translate_thread_stop = threading.Event()
52
  self._frame_processing_thread_stop = threading.Event()
53
+
54
  self.translate_thread = self._start_thread(self._transcription_processing_loop)
55
  self.frame_processing_thread = self._start_thread(self._frame_processing_loop)
56
 
57
+ # for test
58
+ self._transcrible_time_cost = 0.
59
+ self._translate_time_cost = 0.
60
+ if config.TEST:
61
+ self._test_task_event = threading.Event()
62
+ self._test_queue = queue.Queue()
63
+ self._test_thread = self._start_thread(self.test_data_loop)
64
+
65
  # self._c = 0
66
+
67
def test_data_loop(self):
    """Move queued debug records from ``self._test_queue`` into the test CSV.

    Loops until ``self._test_task_event`` is set. Each record taken from the
    queue is appended to the CSV through a ``TestDataWriter`` created with its
    default path.
    """
    writer = TestDataWriter()
    while not self._test_task_event.is_set():
        try:
            # Bounded wait: a bare get() blocks indefinitely on an empty
            # queue, so the stop event above would never be re-checked and
            # the thread could not be shut down.
            test_data = self._test_queue.get(timeout=0.5)
        except queue.Empty:
            continue
        writer.write(test_data)  # Save test_data to CSV
72
 
73
 
74
  def _start_thread(self, target_function) -> threading.Thread:
 
132
  """从音频缓冲区中移除已处理的部分"""
133
  with self.lock:
134
  if self.frames_np is not None and offset > 0:
135
+ # self._c += 1
136
+ # before = self.frames_np.copy()
137
  self.frames_np = self.frames_np[offset:]
138
+ # after = self.frames_np.copy()
139
+ # save_to_wave(f"./tests/{self._c}_before_cut.wav", before)
140
+ # save_to_wave(f"./tests/{self._c}_after_cut.wav", after)
141
+
142
 
143
  def _get_audio_for_processing(self) -> Optional[np.ndarray]:
144
  """准备用于处理的音频块"""
 
171
 
172
  result = self._translate_pipe.transcrible(audio_buffer.tobytes(), self.source_language)
173
  segments = result.segments
174
+ time_diff = (time.perf_counter() - start_time)
175
  logger.debug(f"๐Ÿ“ Transcrible Segments: {segments} ")
176
  logger.debug(f"๐Ÿ“ Transcrible: {self.text_separator.join(seg.text for seg in segments)} ")
177
  log_block("๐Ÿ“ Transcrible output", f"{self.text_separator.join(seg.text for seg in segments)}", "")
178
+ log_block("๐Ÿ“ Transcrible time", f"{time_diff:.3f}", "s")
179
+ self._transcrible_time_cost = round(time_diff, 3)
180
  return [
181
  TranscriptToken(text=s.text, t0=s.t0, t1=s.t1)
182
  for s in segments
 
192
 
193
  result = self._translate_pipe.translate(text, self.source_language, self.target_language)
194
  translated_text = result.translate_content
195
+ time_diff = (time.perf_counter() - start_time)
196
+ log_block("๐Ÿง Translation time ", f"{time_diff:.3f}", "s")
197
  log_block("๐Ÿง Translation out ", f"{translated_text}")
198
+ self._translate_time_cost = round(time_diff, 3)
199
  return translated_text
200
 
201
  def _translate_text_large(self, text: str) -> str:
 
208
 
209
  result = self._translate_pipe.translate_large(text, self.source_language, self.target_language)
210
  translated_text = result.translate_content
211
+ time_diff = (time.perf_counter() - start_time)
212
+ log_block("Translation large model time ", f"{time_diff:.3f}", "s")
213
  log_block("Translation large model output", f"{translated_text}")
214
+ self._translate_time_cost = round(time_diff, 3)
215
  return translated_text
216
 
217
 
 
278
  )
279
  current_time = time.perf_counter()
280
  time_diff = current_time - start_time
281
+ if config.TEST:
282
+ self._test_queue.put(DebugResult(
283
+ seg_id=ana_result.seg_id,
284
+ transcrible_time=self._transcrible_time_cost,
285
+ translate_time=self._translate_time_cost,
286
+ context=ana_result.context,
287
+ from_=self.source_language,
288
+ to=self.target_language,
289
+ tran_content=translated_context,
290
+ partial=ana_result.partial()
291
+ ))
292
  log_block("๐Ÿšฆ Traffic times diff", round(time_diff, 2), 's')
293
 
294