| import asyncio |
| import json |
| import queue |
| import threading |
| import time |
| from logging import getLogger |
| from typing import List, Optional, Iterator, Tuple, Any |
| import asyncio |
| import numpy as np |
| import config |
| |
| from api_model import TransResult, Message, DebugResult |
| from .server import ServeClientBase |
| from .utils import log_block, save_to_wave, TestDataWriter |
| from .translatepipes import TranslatePipes |
| from .strategy import ( |
| TranscriptStabilityAnalyzer, TranscriptToken) |
| import csv |
|
|
# Module-level logger: all records emitted by this file use the "TranscriptionService" logger name.
| logger = getLogger("TranscriptionService") |
|
|
|
|
class WhisperTranscriptionService(ServeClientBase):
    """Whisper speech transcription service for one streaming client.

    Audio frames pushed through :meth:`add_audio_frames` are merged into a
    shared buffer by one daemon thread, while a second daemon thread
    repeatedly runs voice-activity detection, transcription and translation
    on that buffer and sends the results to the client's websocket.
    """

    def __init__(self, websocket, pipe: TranslatePipes, language=None, dst_lang=None, client_uid=None):
        """Initialise per-client state, announce readiness and start the workers.

        Args:
            websocket: Connection used to talk to the client.
            pipe: Pipeline object providing ``voice_detect``, ``transcrible``,
                ``translate`` and ``translate_large``.
            language: Source language code (``"zh"`` selects an empty text
                separator, anything else a single space).
            dst_lang: Target language code passed to the translation calls.
            client_uid: Identifier echoed back in messages to the client.
        """
        super().__init__(client_uid, websocket)
        self.source_language = language
        self.target_language = dst_lang
        self._translate_pipe = pipe

        # Audio buffer shared between the frame thread and the transcription
        # thread; every access to ``frames_np`` that mutates it holds ``lock``.
        self.sample_rate = 16000
        self.frames_np = None
        self.lock = threading.Lock()
        self._frame_queue = queue.Queue()

        self.text_separator = self._get_text_separator(language)
        self.loop = asyncio.get_event_loop()

        # Created by set_language(), or lazily on first use when the caller
        # only supplied languages through the constructor.
        self._transcrible_analysis = None

        # Durations (seconds) of the most recent transcription / translation.
        self._transcrible_time_cost = 0.
        self._translate_time_cost = 0.

        self._translate_thread_stop = threading.Event()
        self._frame_processing_thread_stop = threading.Event()
        if config.TEST:
            self._test_task_stop = threading.Event()
            self._test_queue = queue.Queue()

        self.send_ready_state()

        # Start the workers only after every attribute they read exists, so a
        # thread can never observe a half-initialised instance.
        self.translate_thread = self._start_thread(self._transcription_processing_loop)
        self.frame_processing_thread = self._start_thread(self._frame_processing_loop)
        if config.TEST:
            self._test_thread = self._start_thread(self.test_data_loop)

    def test_data_loop(self):
        """Drain the debug queue into a ``TestDataWriter`` until stopped."""
        writer = TestDataWriter()
        while not self._test_task_stop.is_set():
            try:
                # A timeout is required: a bare get() would block forever on
                # an empty queue and the stop event would never be checked.
                test_data = self._test_queue.get(timeout=0.1)
            except queue.Empty:
                continue
            writer.write(test_data)

    def _start_thread(self, target_function) -> threading.Thread:
        """Start ``target_function`` in a daemon thread and return the thread."""
        thread = threading.Thread(target=target_function)
        thread.daemon = True
        thread.start()
        return thread

    def _get_text_separator(self, language: Optional[str]) -> str:
        """Return the separator used to join text segments for ``language``.

        Chinese (``"zh"``) is joined without a separator; every other value,
        including ``None``, uses a single space.
        """
        return "" if language == "zh" else " "

    def send_ready_state(self) -> None:
        """Send the "server ready" JSON message to the client."""
        self.websocket.send(json.dumps({
            "uid": self.client_uid,
            "message": self.SERVER_READY,
            "backend": "whisper_transcription"
        }))

    def set_language(self, source_lang: str, target_lang: str) -> None:
        """Set source/target languages and rebuild the stability analyzer."""
        self.source_language = source_lang
        self.target_language = target_lang
        self.text_separator = self._get_text_separator(source_lang)
        self._transcrible_analysis = TranscriptStabilityAnalyzer(self.source_language, self.text_separator)

    def add_audio_frames(self, frame_np: np.ndarray) -> None:
        """Queue an audio frame for merging into the processing buffer."""
        self._frame_queue.put(frame_np)

    def _frame_processing_loop(self) -> None:
        """Move queued audio frames into the shared buffer until stopped.

        A ``None`` frame terminates the loop.
        """
        while not self._frame_processing_thread_stop.is_set():
            try:
                frame_np = self._frame_queue.get(timeout=0.1)
            except queue.Empty:
                continue

            if frame_np is None:
                # Actually stop here: falling through would either raise on
                # ``None.copy()`` or append ``None`` to the audio buffer.
                logger.error("Received None frame, stopping thread")
                break

            with self.lock:
                if self.frames_np is None:
                    self.frames_np = frame_np.copy()
                else:
                    self.frames_np = np.append(self.frames_np, frame_np)

    def _apply_voice_activity_detection(self) -> Optional[np.ndarray]:
        """Run voice-activity detection over the buffer and store the result.

        The buffer is replaced by the audio returned from the pipeline's
        ``voice_detect`` call, reinterpreted as float32 samples.

        Returns:
            A copy of the updated buffer, or ``None`` when no audio has been
            buffered yet.
        """
        with self.lock:
            if self.frames_np is None:
                return None
            frame = self.frames_np.copy()
            processed_audio = self._translate_pipe.voice_detect(frame.tobytes())
            self.frames_np = np.frombuffer(processed_audio.audio, dtype=np.float32).copy()
            return self.frames_np.copy()

    def _update_audio_buffer(self, offset: int) -> None:
        """Drop the first ``offset`` samples (already processed) from the buffer."""
        with self.lock:
            if self.frames_np is not None and offset > 0:
                self.frames_np = self.frames_np[offset:]

    def _get_audio_for_processing(self) -> Optional[np.ndarray]:
        """Return the audio chunk to transcribe, or ``None`` if there is too little.

        Buffers of 10 samples or fewer are skipped. Buffers shorter than one
        second are right-aligned inside a zero-filled array of
        ``sample_rate + 1000`` samples, i.e. padded with leading silence.
        """
        frames = self._apply_voice_activity_detection()
        if frames is None:
            return None

        if len(frames) <= 10:
            return None

        if len(frames) < self.sample_rate:
            silence_audio = np.zeros((self.sample_rate + 1000,), dtype=np.float32)
            silence_audio[-len(frames):] = frames
            return silence_audio

        return frames

    def _transcribe_audio(self, audio_buffer: np.ndarray) -> List[TranscriptToken]:
        """Transcribe ``audio_buffer`` and return one token per segment.

        Also records the elapsed time (rounded to milliseconds) in
        ``_transcrible_time_cost``.
        """
        log_block("Audio buffer length", f"{audio_buffer.shape[0]/self.sample_rate:.2f}", "s")
        start_time = time.perf_counter()

        result = self._translate_pipe.transcrible(audio_buffer.tobytes(), self.source_language)
        segments = result.segments
        time_diff = time.perf_counter() - start_time

        joined_text = self.text_separator.join(seg.text for seg in segments)
        logger.debug(f"📝 Transcrible Segments: {segments} ")
        logger.debug(f"📝 Transcrible: {joined_text} ")
        log_block("📝 Transcrible output", f"{joined_text}", "")
        log_block("📝 Transcrible time", f"{time_diff:.3f}", "s")
        self._transcrible_time_cost = round(time_diff, 3)
        return [
            TranscriptToken(text=s.text, t0=s.t0, t1=s.t1)
            for s in segments
        ]

    def _timed_translate(self, text: str, translate_fn, input_label: str,
                         time_label: str, output_label: str) -> str:
        """Translate ``text`` with ``translate_fn``, logging input, output and timing.

        Blank input short-circuits to an empty string without calling the
        model. The elapsed time is stored in ``_translate_time_cost``.
        """
        if not text.strip():
            return ""

        log_block(input_label, f"{text}")
        start_time = time.perf_counter()

        result = translate_fn(text, self.source_language, self.target_language)
        translated_text = result.translate_content
        time_diff = time.perf_counter() - start_time
        log_block(time_label, f"{time_diff:.3f}", "s")
        log_block(output_label, f"{translated_text}")
        self._translate_time_cost = round(time_diff, 3)
        return translated_text

    def _translate_text(self, text: str) -> str:
        """Translate ``text`` into the target language with the default model."""
        return self._timed_translate(
            text,
            self._translate_pipe.translate,
            "🐧 Translation input ",
            "🐧 Translation time ",
            "🐧 Translation out ",
        )

    def _translate_text_large(self, text: str) -> str:
        """Translate ``text`` into the target language with the large model."""
        return self._timed_translate(
            text,
            self._translate_pipe.translate_large,
            "Translation input",
            "Translation large model time ",
            "Translation large model output",
        )

    def _transcription_processing_loop(self) -> None:
        """Main worker loop: transcribe buffered audio and send results."""
        while not self._translate_thread_stop.is_set():
            if self.exit:
                logger.info("Exiting transcription thread")
                break

            # Nothing buffered yet: wait for the frame thread to deliver audio.
            if self.frames_np is None:
                time.sleep(0.2)
                logger.info("Waiting for audio data...")
                continue

            audio_buffer = self._get_audio_for_processing()
            if audio_buffer is None:
                time.sleep(0.2)
                continue
            logger.debug(f"🥤 Buffer Length: {len(audio_buffer)/self.sample_rate:.2f} ")

            segments = self._transcribe_audio(audio_buffer)

            for result in self._process_transcription_results(segments, audio_buffer):
                self._send_result_to_client(result)

    def _process_transcription_results(self, segments: List[TranscriptToken], audio_buffer: np.ndarray) -> Iterator[TransResult]:
        """Analyse transcribed segments and yield translated results.

        Partial results are translated with the default model, finished ones
        with the large model. When the analyzer reports a positive
        ``cut_index`` that many samples are dropped from the audio buffer.

        Args:
            segments: Tokens produced by the transcription step.
            audio_buffer: The audio that was transcribed; only its length is
                used, to tell the analyzer the buffer duration in seconds.

        Yields:
            One ``TransResult`` per analysis result.
        """
        if not segments:
            return

        # The analyzer normally comes from set_language(); build it on demand
        # so audio arriving before that call does not crash the worker thread.
        if self._transcrible_analysis is None:
            self._transcrible_analysis = TranscriptStabilityAnalyzer(self.source_language, self.text_separator)
        analyzer = self._transcrible_analysis

        start_time = time.perf_counter()
        for ana_result in analyzer.analysis(segments, len(audio_buffer)/self.sample_rate):
            # NOTE(review): cut_index is applied as a sample offset into the
            # raw buffer, but buffers shorter than one second are transcribed
            # with leading silence padding — confirm the analyzer's index
            # accounts for that.
            if (cut_index := ana_result.cut_index) > 0:
                self._update_audio_buffer(cut_index)

            is_partial = ana_result.partial()
            if is_partial:
                translated_context = self._translate_text(ana_result.context)
            else:
                translated_context = self._translate_text_large(ana_result.context)

            yield TransResult(
                seg_id=ana_result.seg_id,
                context=ana_result.context,
                from_=self.source_language,
                to=self.target_language,
                tran_content=translated_context,
                partial=is_partial
            )

            # Runs once the consumer asks for the next item, so the measured
            # time includes whatever the consumer did with the result.
            time_diff = time.perf_counter() - start_time
            if config.TEST:
                self._test_queue.put(DebugResult(
                    seg_id=ana_result.seg_id,
                    transcrible_time=self._transcrible_time_cost,
                    translate_time=self._translate_time_cost,
                    context=ana_result.context,
                    from_=self.source_language,
                    to=self.target_language,
                    tran_content=translated_context,
                    partial=is_partial
                ))
            log_block("🚦 Traffic times diff", round(time_diff, 2), 's')

    def _send_result_to_client(self, result: TransResult) -> None:
        """Serialise ``result`` and schedule its delivery on the event loop.

        The service is stopped when scheduling raises ``RuntimeError`` or when
        the send itself fails or is cancelled.
        """
        def _on_send_done(fut) -> None:
            # Check cancellation first: Future.exception() raises
            # CancelledError on a cancelled future.
            if fut.cancelled() or fut.exception() is not None:
                self.stop()

        try:
            message = Message(result=result, request_id=self.client_uid).model_dump_json(by_alias=True)
            coro = self.websocket.send_text(message)
            future = asyncio.run_coroutine_threadsafe(coro, self.loop)
            future.add_done_callback(_on_send_done)
        except RuntimeError:
            self.stop()
        except Exception as e:
            logger.error(f"Error sending result to client: {e}")

    def stop(self) -> None:
        """Signal every worker thread to stop."""
        self._translate_thread_stop.set()
        self._frame_processing_thread_stop.set()
        if config.TEST:
            self._test_task_stop.set()
        logger.info(f"Stopping transcription service for client: {self.client_uid}")
|
|