Xin Zhang
[fix]: whisper_full_with_state: input is too short - 990 ms < 1000 ms. consider padding the input audio with silence.
1b8024b | import queue | |
| import threading | |
| import time | |
| from logging import getLogger | |
| import asyncio | |
| import numpy as np | |
| import config | |
| import collections | |
| from api_model import TransResult, Message | |
| from .utils import log_block, start_thread, get_text_separator, filter_words | |
| from .translatepipes import TranslatePipes | |
| from transcribe.pipelines import MetaItem | |
# Module-level logger used throughout the transcription service below.
logger = getLogger(name="TranscriptionService")
class WhisperTranscriptionService:
    """
    Whisper speech-transcription service: consumes a streaming audio feed,
    transcribes it, translates the text and pushes results to a websocket client.

    Two worker threads are started in ``__init__``:
      * ``_read_frame_processing_loop`` runs voice activity detection on raw
        frames and cuts the speech buffer into committed segments;
      * ``_transcription_processing_loop`` transcribes/translates committed
        segments (final rows) or a growing prefix of the live buffer
        (partial rows) and sends the results.
    """

    def __init__(self, websocket, pipe: "TranslatePipes", language=None, dst_lang=None, client_uid=None):
        print('>>>>>>>>>>>>>>>> init service >>>>>>>>>>>>>>>>>>>>>>')
        print('src_lang:', language)
        self.source_language = language  # source language
        self.target_language = dst_lang  # target translation language
        self.client_uid = client_uid
        self.websocket = websocket
        self.translate_pipe = pipe
        # Audio processing settings
        self.sample_rate = config.SAMPLE_RATE
        # Guards frames_np, frames_np_start_timestamp and full_segments_queue,
        # which are shared between the two worker threads.
        self.lock = threading.Lock()
        # Separator used to join transcript segments, chosen per language.
        self.text_separator = get_text_separator(language)
        # Captured so worker threads can schedule websocket sends on it
        # (see _send_result_to_client).
        self.loop = asyncio.get_event_loop()
        # Raw frames pushed by add_frames, consumed by the VAD reader thread.
        self.frame_queue = queue.Queue()
        # Speech audio accumulated for the sentence currently being spoken.
        self.frames_np = np.array([], dtype=np.float32)
        # Time speech started; used to enforce a minimum duration before cutting on END.
        self.frames_np_start_timestamp = None
        # Committed (complete) segments waiting for final transcription.
        self.full_segments_queue = collections.deque()
        # Row number sent as seg_id. Must exist before the worker threads
        # start, because the transcription loop reads it.
        self.row_number = 0
        # Start the worker threads last, once all state they touch exists.
        self._stop = threading.Event()
        self.translate_thread = start_thread(self._transcription_processing_loop)
        self.frame_processing_thread = start_thread(self._read_frame_processing_loop)

    def add_frames(self, frame_np: np.ndarray) -> None:
        """Queue a raw audio frame for voice activity detection."""
        self.frame_queue.put(frame_np)

    def _apply_voice_activity_detection(self, frame_np: np.ndarray):
        """Run voice activity detection on one raw frame.

        Returns:
            ``(speech_audio, speech_status)``. ``speech_audio`` is the audio
            returned by the VAD pipe, decoded as float32, or ``None`` when the
            pipe returned no audio payload. ``speech_status`` is the pipe's
            status value, which the caller compares with "START" / "END".
        """
        processed_audio = self.translate_pipe.voice_detect(frame_np.tobytes())
        speech_status = processed_audio.speech_status
        if processed_audio.audio is None:
            # np.frombuffer(None) would raise and kill the reader thread;
            # the caller already skips frames whose audio is None.
            return None, speech_status
        return np.frombuffer(processed_audio.audio, dtype=np.float32), speech_status

    def _read_frame_processing_loop(self) -> None:
        """Reader thread: run VAD on queued frames and cut the buffer into segments.

        Speech audio is appended to ``frames_np``. The buffer is committed to
        ``full_segments_queue`` when it reaches MAX_SPEECH_DURATION_S (a
        forced cut). It is also committed on an "END" status, provided at
        least FRAME_SCOPE_TIME_THRESHOLD seconds have passed since speech
        started; otherwise it is kept and keeps growing.
        """
        while not self._stop.is_set():
            try:
                frame_np = self.frame_queue.get(timeout=0.1)
            except queue.Empty:
                continue
            try:
                speech_audio, speech_status = self._apply_voice_activity_detection(frame_np)
            except Exception:
                # Keep the reader alive; one bad frame must not stop the service.
                logger.exception("Voice activity detection failed; dropping frame")
                continue
            if speech_audio is None:
                continue
            with self.lock:
                self.frames_np = np.append(self.frames_np, speech_audio)
                # Remember when speech started, so END can enforce a minimum segment duration.
                if speech_status == "START" and self.frames_np_start_timestamp is None:
                    self.frames_np_start_timestamp = time.time()
                if len(self.frames_np) >= self.sample_rate * config.MAX_SPEECH_DURATION_S:
                    # Maximum buffer length reached: force a cut. The start time is
                    # reset to "now" rather than cleared, because speech is still ongoing.
                    # np.append always returns a new array and frames_np is rebound
                    # below, so the queued array is never mutated afterwards.
                    self.full_segments_queue.appendleft(self.frames_np)
                    self.frames_np_start_timestamp = time.time()
                    self.frames_np = np.array([], dtype=np.float32)
                # Usual pattern: START -- END -- START -- END.
                # START -- END -- END also happens; an END chunk that carries audio is
                # usually a short sound cut off within one (4096-sized?) chunk.
                elif speech_status == "END" and len(self.frames_np) > 0 and self.frames_np_start_timestamp:
                    time_diff = time.time() - self.frames_np_start_timestamp
                    if time_diff >= config.FRAME_SCOPE_TIME_THRESHOLD:
                        self.full_segments_queue.appendleft(self.frames_np)
                        self.frames_np_start_timestamp = None
                        self.frames_np = np.array([], dtype=np.float32)
                    else:
                        logger.debug("🥳 当前时间与上一句的时间差: %.2fs,继续保留在缓冲区", time_diff)

    def _take_next_buffer(self, frame_epoch):
        """Pick the next audio to transcribe, or return None when there is nothing.

        A committed segment, if any is queued, takes priority. Segments are
        queued with appendleft and popped from the right, so the oldest comes
        first. The result is ``(segment, False)``. Otherwise the result is the
        first ``frame_epoch * 1.5`` seconds of the live buffer, as
        ``(copy, True)``.
        """
        with self.lock:
            # Check the queue before the live buffer. The reader empties
            # frames_np at the same moment it queues a segment, so gating on
            # frames_np alone would leave the last sentence unprocessed.
            if self.full_segments_queue:
                return self.full_segments_queue.pop(), False
            if len(self.frames_np) == 0:
                return None
            window = int(frame_epoch * 1.5 * self.sample_rate)
            return self.frames_np[:window].copy(), True

    @staticmethod
    def _fit_audio_length(audio_buffer: np.ndarray, sample_rate: int, max_samples: int) -> np.ndarray:
        """Make *audio_buffer* acceptable to whisper.

        Buffers shorter than 1 s plus a 10 ms margin are left-padded with
        silence up to that length. This avoids whisper's
        "input is too short ... < 1000 ms" rejection; a buffer of exactly
        1 s can still be measured as 990 ms. Buffers longer than
        *max_samples* are truncated to their first *max_samples* samples.
        Anything else is returned unchanged.
        """
        padding_samples = int(sample_rate * 0.01)  # 10 ms safety margin
        target_length = int(sample_rate) + padding_samples
        if len(audio_buffer) < target_length:
            padded = np.zeros(target_length, dtype=np.float32)
            # Place the audio at the end, with silence in front. An explicit
            # start index also handles an empty buffer (a negative "-0" slice would not).
            padded[target_length - len(audio_buffer):] = audio_buffer
            return padded
        if len(audio_buffer) > max_samples:
            return audio_buffer[:max_samples]
        return audio_buffer

    def _transcription_processing_loop(self) -> None:
        """Transcription thread: transcribe, translate and send results until stopped.

        Committed segments become final rows, translated with the large model.
        A growing prefix of the live buffer becomes partial rows, translated
        with the regular model. The prefix grows by 1.5 s for every partial
        result sent.
        """
        frame_epoch = 1  # partial window size, in 1.5 s steps
        while not self._stop.is_set():
            next_item = self._take_next_buffer(frame_epoch)
            if next_item is None:
                time.sleep(0.1)
                continue
            audio_buffer, partial = next_item
            if not partial:
                # A committed segment closes the current sentence; the next
                # partial window starts over on a fresh buffer.
                frame_epoch = 1
            audio_buffer = self._fit_audio_length(
                audio_buffer,
                self.sample_rate,
                int(self.sample_rate * config.MAX_SPEECH_DURATION_S),
            )
            logger.debug("audio buffer size: %.2fs", len(audio_buffer) / self.sample_rate)
            try:
                meta_item = self._transcribe_audio(audio_buffer)
                logger.debug("Segments: %s", meta_item.segments)
                segments = filter_words(meta_item.segments)
                seg_text = self.text_separator.join(seg.text for seg in segments)
                if not seg_text.strip():  # skip empty / whitespace-only transcripts
                    continue
                if partial:
                    translated_content = self._translate_text(seg_text)
                else:
                    translated_content = self._translate_text_large(seg_text)
            except Exception:
                # Keep the thread alive on pipeline errors instead of silently
                # ending the service; back off briefly to avoid a hot error loop.
                logger.exception("Transcription or translation failed; skipping this buffer")
                time.sleep(0.1)
                continue
            if partial:
                frame_epoch += 1
            else:
                # NOTE(review): row_number is incremented before the final row is
                # sent, so the final row and the following sentence's partials
                # share a seg_id. Confirm this matches the client protocol.
                self.row_number += 1
            result = TransResult(
                seg_id=self.row_number,
                context=seg_text,
                from_=self.source_language,
                to=self.target_language,
                tran_content=translated_content,
                partial=partial,
            )
            self._send_result_to_client(result)

    def _transcribe_audio(self, audio_buffer: np.ndarray) -> "MetaItem":
        """Transcribe *audio_buffer* with the pipe and return its result (with ``.segments``)."""
        log_block("Audio buffer length", f"{audio_buffer.shape[0]/self.sample_rate:.2f}", "s")
        result = self.translate_pipe.transcribe(audio_buffer.tobytes(), self.source_language)
        log_block("📝 transcribe output", f"{self.text_separator.join(seg.text for seg in result.segments)}", "")
        return result

    def _translate_text(self, text: str) -> str:
        """Translate *text* into the target language; blank input returns ""."""
        if not text.strip():
            return ""
        log_block("🐧 Translation input ", f"{text}")
        result = self.translate_pipe.translate(text, self.source_language, self.target_language)
        translated_text = result.translate_content
        log_block("🐧 Translation out ", f"{translated_text}")
        return translated_text

    def _translate_text_large(self, text: str) -> str:
        """Translate *text* with the pipe's ``translate_large`` path; blank input returns ""."""
        if not text.strip():
            return ""
        log_block("Translation input", f"{text}")
        result = self.translate_pipe.translate_large(text, self.source_language, self.target_language)
        translated_text = result.translate_content
        log_block("Translation large model output", f"{translated_text}")
        return translated_text

    def _send_result_to_client(self, result: "TransResult") -> None:
        """Serialize *result* and schedule sending it on ``self.loop``, without waiting.

        A RuntimeError while scheduling, or a send that later fails or is
        cancelled (see _on_send_done), stops the service. Other errors are logged.
        """
        try:
            message = Message(result=result, request_id=self.client_uid).model_dump_json(by_alias=True)
            coro = self.websocket.send_text(message)
            future = asyncio.run_coroutine_threadsafe(coro, self.loop)
            future.add_done_callback(self._on_send_done)
        except RuntimeError:
            self.stop()
        except Exception as e:
            logger.error(f"Error sending result to client: {e}")

    def _on_send_done(self, fut) -> None:
        """Done-callback for a scheduled send: stop the service unless it succeeded."""
        if fut.cancelled():
            # fut.exception() would raise CancelledError here; the send did
            # not happen, so handle it like a failed send.
            self.stop()
            return
        exc = fut.exception()
        if exc is not None:
            logger.error(f"Error sending result to client: {exc}")
            self.stop()

    def stop(self) -> None:
        """Signal both worker threads to exit; they stop on their next loop check."""
        self._stop.set()
        logger.info(f"Stopping transcription service for client: {self.client_uid}")