daihui.zhang committed on
Commit
7dc6a6f
·
1 Parent(s): c556c3a

fix bug in counting word length

Browse files
transcribe/pipelines/pipe_vad.py CHANGED
@@ -41,7 +41,7 @@ class VadPipe(BasePipe):
41
  def process(self, in_data: MetaItem) -> MetaItem:
42
  source_audio = in_data.source_audio
43
  source_audio = np.frombuffer(source_audio, dtype=np.float32)
44
-
45
  send_audio = b""
46
  speech_timestamps = get_speech_timestamps(torch.Tensor(source_audio), self.model.silero_vad, sampling_rate=16000)
47
 
 
41
  def process(self, in_data: MetaItem) -> MetaItem:
42
  source_audio = in_data.source_audio
43
  source_audio = np.frombuffer(source_audio, dtype=np.float32)
44
+ # source_audio = self.reduce_noise(source_audio)
45
  send_audio = b""
46
  speech_timestamps = get_speech_timestamps(torch.Tensor(source_audio), self.model.silero_vad, sampling_rate=16000)
47
 
transcribe/strategy.py CHANGED
@@ -149,11 +149,12 @@ class TranscriptBuffer:
149
 
150
  """
151
 
152
- def __init__(self):
153
  self._segments: List[str] = collections.deque(maxlen=2) # 确认的完整段落
154
  self._sentences: List[str] = [] # 当前段落中的短句
155
  self._buffer: str = "" # 当前缓冲中的文本
156
  self._current_seg_id: int = 0
 
157
 
158
  def get_seg_id(self) -> int:
159
  return self._current_seg_id
@@ -189,7 +190,8 @@ class TranscriptBuffer:
189
  if is_end_sentence:
190
  self.update_pending_text(stable_string)
191
  self.commit_line()
192
- current_text_len = len(self.current_not_commit_text)
 
193
  self.update_pending_text(remaining_string)
194
  if current_text_len >=20:
195
  self.commit_paragraph()
@@ -224,16 +226,18 @@ class TranscriptBuffer:
224
 
225
 
226
  class TranscriptStabilityAnalyzer:
227
- def __init__(self) -> None:
228
- self._transcript_buffer = TranscriptBuffer()
229
  self._transcript_history = TranscriptHistory()
 
 
230
 
231
  def merge_chunks(self, chunks: List[TranscriptChunk])->str:
232
  return "".join(r.join() for r in chunks)
233
 
234
 
235
- def analysis(self, separator, current: TranscriptChunk, buffer_duration: float) -> Iterator[TranscriptResult]:
236
- current = TranscriptChunk(items=current, separator=separator)
237
  self._transcript_history.add(current)
238
 
239
  prev = self._transcript_history.previous_chunk()
 
149
 
150
  """
151
 
152
+ def __init__(self, separator):
153
  self._segments: List[str] = collections.deque(maxlen=2) # 确认的完整段落
154
  self._sentences: List[str] = [] # 当前段落中的短句
155
  self._buffer: str = "" # 当前缓冲中的文本
156
  self._current_seg_id: int = 0
157
+ self._separator = separator
158
 
159
  def get_seg_id(self) -> int:
160
  return self._current_seg_id
 
190
  if is_end_sentence:
191
  self.update_pending_text(stable_string)
192
  self.commit_line()
193
+ current_text_len = len(self.current_not_commit_text.split(self._separator)) if self._separator else len(self.current_not_commit_text)
194
+ # current_text_len = len(self.current_not_commit_text.split(self._separator))
195
  self.update_pending_text(remaining_string)
196
  if current_text_len >=20:
197
  self.commit_paragraph()
 
226
 
227
 
228
  class TranscriptStabilityAnalyzer:
229
+ def __init__(self, separator) -> None:
230
+ self._transcript_buffer = TranscriptBuffer(separator=separator)
231
  self._transcript_history = TranscriptHistory()
232
+ self._separator = separator
233
+ logger.debug(f"Current separator: {self._separator}")
234
 
235
  def merge_chunks(self, chunks: List[TranscriptChunk])->str:
236
  return "".join(r.join() for r in chunks)
237
 
238
 
239
+ def analysis(self, current: TranscriptChunk, buffer_duration: float) -> Iterator[TranscriptResult]:
240
+ current = TranscriptChunk(items=current, separator=self._separator)
241
  self._transcript_history.add(current)
242
 
243
  prev = self._transcript_history.previous_chunk()
transcribe/whisper_llm_serve.py CHANGED
@@ -29,7 +29,7 @@ class WhisperTranscriptionService(ServeClientBase):
29
  self.target_language = dst_lang # 目标翻译语言
30
 
31
  # 转录结果稳定性管理
32
- self._transcrible_analysis = TranscriptStabilityAnalyzer()
33
  self._translate_pipe = pipe
34
 
35
  # 音频处理相关
@@ -43,7 +43,7 @@ class WhisperTranscriptionService(ServeClientBase):
43
 
44
  # 发送就绪状态
45
  self.send_ready_state()
46
-
47
  # 启动处理线程
48
  self._translate_thread_stop = threading.Event()
49
  self._frame_processing_thread_stop = threading.Event()
@@ -75,6 +75,7 @@ class WhisperTranscriptionService(ServeClientBase):
75
  self.source_language = source_lang
76
  self.target_language = target_lang
77
  self.text_separator = self._get_text_separator(source_lang)
 
78
 
79
  def add_audio_frames(self, frame_np: np.ndarray) -> None:
80
  """添加音频帧到处理队列"""
@@ -209,8 +210,7 @@ class WhisperTranscriptionService(ServeClientBase):
209
  if not segments:
210
  return
211
 
212
- for ana_result in self._transcrible_analysis.analysis(
213
- self.text_separator,segments, len(audio_buffer)/self.sample_rate):
214
  if (cut_index :=ana_result.cut_index)>0:
215
  # 更新音频缓冲区,移除已处理部分
216
  self._update_audio_buffer(cut_index)
 
29
  self.target_language = dst_lang # 目标翻译语言
30
 
31
  # 转录结果稳定性管理
32
+
33
  self._translate_pipe = pipe
34
 
35
  # 音频处理相关
 
43
 
44
  # 发送就绪状态
45
  self.send_ready_state()
46
+ self._transcrible_analysis = None
47
  # 启动处理线程
48
  self._translate_thread_stop = threading.Event()
49
  self._frame_processing_thread_stop = threading.Event()
 
75
  self.source_language = source_lang
76
  self.target_language = target_lang
77
  self.text_separator = self._get_text_separator(source_lang)
78
+ self._transcrible_analysis = TranscriptStabilityAnalyzer(self.text_separator)
79
 
80
  def add_audio_frames(self, frame_np: np.ndarray) -> None:
81
  """添加音频帧到处理队列"""
 
210
  if not segments:
211
  return
212
 
213
+ for ana_result in self._transcrible_analysis.analysis(segments, len(audio_buffer)/self.sample_rate):
 
214
  if (cut_index :=ana_result.cut_index)>0:
215
  # 更新音频缓冲区,移除已处理部分
216
  self._update_audio_buffer(cut_index)