liumaolin committed on
Commit
2534744
·
1 Parent(s): 83ef092

Integrate WebSocket support: add `/api/v1/ws` endpoint, enable real-time message handling via `websocket_message_queue`, and refactor services and models to support WebSocket-based question and answer updates.

Browse files
src/VoiceDialogue/api/app.py CHANGED
@@ -11,7 +11,7 @@ from .core.config import AppConfig
11
  from .core.lifespan import lifespan
12
  from .middleware.logging import LoggingMiddleware
13
  from .middleware.rate_limit import RateLimitMiddleware
14
- from .routes import tts_routes, asr_routes, system_routes
15
 
16
  # 配置日志
17
  logging.basicConfig(
@@ -63,9 +63,10 @@ def _register_routes(app: FastAPI):
63
  v1_router.include_router(tts_routes.router, prefix="/tts", tags=["TTS模型管理"])
64
  v1_router.include_router(asr_routes.router, prefix="/asr", tags=["ASR模型管理"])
65
  v1_router.include_router(system_routes.router, prefix="/system", tags=["系统管理"])
66
-
67
  app.include_router(v1_router)
68
 
 
 
69
  # 根路径和健康检查
70
  _register_health_routes(app)
71
 
 
11
  from .core.lifespan import lifespan
12
  from .middleware.logging import LoggingMiddleware
13
  from .middleware.rate_limit import RateLimitMiddleware
14
+ from .routes import tts_routes, asr_routes, system_routes, websocket_routes
15
 
16
  # 配置日志
17
  logging.basicConfig(
 
63
  v1_router.include_router(tts_routes.router, prefix="/tts", tags=["TTS模型管理"])
64
  v1_router.include_router(asr_routes.router, prefix="/asr", tags=["ASR模型管理"])
65
  v1_router.include_router(system_routes.router, prefix="/system", tags=["系统管理"])
 
66
  app.include_router(v1_router)
67
 
68
+ app.add_websocket_route("/api/v1/ws", websocket_routes.ws)
69
+
70
  # 根路径和健康检查
71
  _register_health_routes(app)
72
 
src/VoiceDialogue/api/core/service_factories.py CHANGED
@@ -1,6 +1,6 @@
1
  from core.constants import (
2
  transcribed_text_queue, text_input_queue, audio_output_queue,
3
- audio_frames_queue, user_voice_queue
4
  )
5
  from services.audio import EchoCancellingAudioCapture, TTSAudioGenerator, AudioStreamPlayer
6
  from services.audio.generators import BaseTTSConfig, tts_config_registry
@@ -41,7 +41,8 @@ class ServiceFactories:
41
  """创建LLM文本生成服务"""
42
  return LLMResponseGenerator(
43
  user_question_queue=transcribed_text_queue,
44
- generated_answer_queue=text_input_queue
 
45
  )
46
 
47
  @staticmethod
@@ -59,7 +60,10 @@ class ServiceFactories:
59
  @staticmethod
60
  def create_audio_player() -> AudioStreamPlayer:
61
  """创建音频播放服务"""
62
- return AudioStreamPlayer(audio_playing_queue=audio_output_queue)
 
 
 
63
 
64
 
65
  def get_core_voice_service_definitions(system_language: str, tts_config: BaseTTSConfig = None) -> list:
 
1
  from core.constants import (
2
  transcribed_text_queue, text_input_queue, audio_output_queue,
3
+ audio_frames_queue, user_voice_queue, websocket_message_queue
4
  )
5
  from services.audio import EchoCancellingAudioCapture, TTSAudioGenerator, AudioStreamPlayer
6
  from services.audio.generators import BaseTTSConfig, tts_config_registry
 
41
  """创建LLM文本生成服务"""
42
  return LLMResponseGenerator(
43
  user_question_queue=transcribed_text_queue,
44
+ generated_answer_queue=text_input_queue,
45
+ websocket_message_queue=websocket_message_queue,
46
  )
47
 
48
  @staticmethod
 
60
  @staticmethod
61
  def create_audio_player() -> AudioStreamPlayer:
62
  """创建音频播放服务"""
63
+ return AudioStreamPlayer(
64
+ audio_playing_queue=audio_output_queue,
65
+ websocket_message_queue=websocket_message_queue
66
+ )
67
 
68
 
69
  def get_core_voice_service_definitions(system_language: str, tts_config: BaseTTSConfig = None) -> list:
src/VoiceDialogue/api/routes/__init__.py CHANGED
@@ -1,3 +1,3 @@
1
- from . import tts_routes, asr_routes, system_routes
2
 
3
- __all__ = ["tts_routes", "asr_routes", "system_routes"]
 
1
+ from . import tts_routes, asr_routes, system_routes, websocket_routes
2
 
3
+ __all__ = ["tts_routes", "asr_routes", "system_routes", "websocket_routes"]
src/VoiceDialogue/api/routes/websocket_routes.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import logging

from fastapi import APIRouter, WebSocket, WebSocketDisconnect

from core.constants import websocket_message_queue, session_manager

ws = APIRouter()
logger = logging.getLogger(__name__)


@ws.websocket("/api/v1/ws")
async def websocket_endpoint(websocket: WebSocket):
    """WebSocket connection endpoint.

    Accepts the client connection, then forwards display messages taken
    from ``websocket_message_queue`` to the client as JSON, skipping
    messages whose ``session_id`` does not match the currently active
    session.
    """
    try:
        await websocket.accept()
        while True:
            # asyncio.Queue.get() suspends until an item is available; it
            # never raises queue.Empty, so the original
            # `except Empty: continue` was dead code and has been removed.
            message = await websocket_message_queue.get()

            # Drop messages that belong to a stale session.
            # NOTE(review): producers call put_nowait() on this asyncio.Queue
            # from worker threads; asyncio queues are not thread-safe —
            # confirm producers go through loop.call_soon_threadsafe.
            if message.session_id != session_manager.current_id:
                continue

            await websocket.send_json(message.model_dump())

    except WebSocketDisconnect:
        # Client closed the connection — normal termination, nothing to do.
        pass
    except Exception as e:
        logger.error(f"WebSocket连接异常: {e}")
src/VoiceDialogue/core/constants.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import multiprocessing
2
  import threading
3
  from collections import OrderedDict
@@ -36,6 +37,7 @@ user_voice_queue = multiprocessing.Queue()
36
  transcribed_text_queue = multiprocessing.Queue()
37
  text_input_queue = multiprocessing.Queue()
38
  audio_output_queue = multiprocessing.Queue()
 
39
 
40
  # ======================= 全局状态实例 =======================
41
 
 
1
+ import asyncio
2
  import multiprocessing
3
  import threading
4
  from collections import OrderedDict
 
37
  transcribed_text_queue = multiprocessing.Queue()
38
  text_input_queue = multiprocessing.Queue()
39
  audio_output_queue = multiprocessing.Queue()
40
+ websocket_message_queue = asyncio.Queue()
41
 
42
  # ======================= 全局状态实例 =======================
43
 
src/VoiceDialogue/models/__init__.py CHANGED
@@ -1 +1,7 @@
1
- from .voice_task import VoiceTask
 
 
 
 
 
 
 
1
+ from .voice_task import VoiceTask, QuestionDisplayMessage, AnswerDisplayMessage
2
+
3
+ __all__ = (
4
+ 'VoiceTask',
5
+ 'QuestionDisplayMessage',
6
+ 'AnswerDisplayMessage'
7
+ )
src/VoiceDialogue/models/voice_task.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import numpy as np
2
  from pydantic import BaseModel, Field
3
 
@@ -30,3 +32,25 @@ class VoiceTask(BaseModel):
30
 
31
  class Config:
32
  arbitrary_types_allowed = True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from enum import Enum
2
+
3
  import numpy as np
4
  from pydantic import BaseModel, Field
5
 
 
32
 
33
  class Config:
34
  arbitrary_types_allowed = True
35
+
36
+
37
class DisplayMessageType(str, Enum):
    """Kind of display message pushed to WebSocket clients."""
    QUESTION = 'question'
    ANSWER = 'answer'


class BaseDisplayMessage(BaseModel):
    """Common envelope for messages shown in the client UI."""
    # Discriminates the concrete message subtype.
    message_type: DisplayMessageType
    # Dialogue session this message belongs to.
    session_id: str
    # Voice task that produced this message.
    task_id: str


class QuestionDisplayMessage(BaseDisplayMessage):
    """Transcribed user question, pushed for display."""
    message_type: DisplayMessageType = DisplayMessageType.QUESTION
    question: str


class AnswerDisplayMessage(BaseDisplayMessage):
    """One generated answer sentence, pushed for display."""
    message_type: DisplayMessageType = DisplayMessageType.ANSWER
    # Position of this sentence within the streamed answer —
    # presumably 0-based; TODO confirm against the producer.
    answer_index: int
    answer: str
src/VoiceDialogue/services/audio/generator.py CHANGED
@@ -19,8 +19,12 @@ class TTSAudioGenerator(BaseThread):
19
  4. 将生成的音频任务放入音频队列中
20
  """
21
 
22
- def __init__(self, group=None, target=None, name=None, args=(), kwargs={}, *, daemon=None,
23
- text_input_queue: Queue, audio_output_queue: Queue, tts_config: BaseTTSConfig):
 
 
 
 
24
  """
25
  初始化TTS音频生成器
26
 
 
19
  4. 将生成的音频任务放入音频队列中
20
  """
21
 
22
+ def __init__(
23
+ self, group=None, target=None, name=None, args=(), kwargs={}, *, daemon=None,
24
+ text_input_queue: Queue,
25
+ audio_output_queue: Queue,
26
+ tts_config: BaseTTSConfig,
27
+ ):
28
  """
29
  初始化TTS音频生成器
30
 
src/VoiceDialogue/services/audio/player.py CHANGED
@@ -11,16 +11,20 @@ from core.constants import (
11
  user_still_speaking_event, voice_state_manager, dropped_audio_cache, chat_history_cache,
12
  silence_over_threshold_event
13
  )
14
- from models.voice_task import VoiceTask
15
 
16
 
17
  class AudioStreamPlayer(BaseThread):
18
  """音频流播放器 - 负责播放生成的音频并管理播放状态"""
19
 
20
- def __init__(self, group=None, target=None, name=None, args=(), kwargs={}, *, daemon=None,
21
- audio_playing_queue):
 
 
 
22
  super().__init__(group, target, name, args, kwargs, daemon=daemon)
23
  self.audio_playing_queue: Queue = audio_playing_queue
 
24
 
25
  def run(self):
26
  self.is_ready = True
@@ -56,6 +60,15 @@ class AudioStreamPlayer(BaseThread):
56
  if answer_id not in voice_state_manager.waiting_second_answer_mapping:
57
  continue
58
 
 
 
 
 
 
 
 
 
 
59
  # now = time.time()
60
  # print(
61
  # f'整体耗时: {(now - voice_task.send_time):.2f}\n'
 
11
  user_still_speaking_event, voice_state_manager, dropped_audio_cache, chat_history_cache,
12
  silence_over_threshold_event
13
  )
14
+ from models.voice_task import VoiceTask, AnswerDisplayMessage
15
 
16
 
17
  class AudioStreamPlayer(BaseThread):
18
  """音频流播放器 - 负责播放生成的音频并管理播放状态"""
19
 
20
+ def __init__(
21
+ self, group=None, target=None, name=None, args=(), kwargs={}, *, daemon=None,
22
+ audio_playing_queue: Queue,
23
+ websocket_message_queue: Queue,
24
+ ):
25
  super().__init__(group, target, name, args, kwargs, daemon=daemon)
26
  self.audio_playing_queue: Queue = audio_playing_queue
27
+ self.websocket_message_queue: Queue = websocket_message_queue
28
 
29
  def run(self):
30
  self.is_ready = True
 
60
  if answer_id not in voice_state_manager.waiting_second_answer_mapping:
61
  continue
62
 
63
+ if self.websocket_message_queue:
64
+ self.websocket_message_queue.put_nowait(
65
+ AnswerDisplayMessage(
66
+ session_id=voice_task.session_id,
67
+ task_id=task_id,
68
+ answer_index=voice_task.answer_index,
69
+ answer=voice_task.answer_sentence,
70
+ )
71
+ )
72
  # now = time.time()
73
  # print(
74
  # f'整体耗时: {(now - voice_task.send_time):.2f}\n'
src/VoiceDialogue/services/text/generator.py CHANGED
@@ -8,7 +8,7 @@ from langchain_core.chat_history import InMemoryChatMessageHistory
8
  from config import paths
9
  from core.base import BaseThread
10
  from core.constants import chat_history_cache
11
- from models.voice_task import VoiceTask
12
  from services.text.processor import preprocess_sentence_text, \
13
  create_langchain_chat_llamacpp_instance, create_langchain_pipeline, warmup_langchain_pipeline
14
 
@@ -26,14 +26,17 @@ ENGLISH_SYSTEM_PROMPT = ("You are an AI assistant skilled at simulating authenti
26
  class LLMResponseGenerator(BaseThread):
27
  """LLM 回答生成器 - 负责使用语言模型生成回答文本"""
28
 
29
- def __init__(self, group=None, target=None, name=None, args=(), kwargs={}, *, daemon=None,
30
- user_question_queue: Queue,
31
- generated_answer_queue: Queue
32
- ):
 
 
33
  super().__init__(group, target, name, args, kwargs, daemon=daemon)
34
 
35
  self.user_question_queue = user_question_queue
36
  self.generated_answer_queue = generated_answer_queue
 
37
 
38
  self.english_sentence_end_marks = {'!', '?', '.', ',', ':', ';'}
39
  self.chinese_sentence_end_marks = {',', '。', '!', '?', ':', ';', '、'}
@@ -121,6 +124,15 @@ class LLMResponseGenerator(BaseThread):
121
 
122
  user_question = voice_task.transcribed_text
123
  print(f'用户问题: {user_question}')
 
 
 
 
 
 
 
 
 
124
  voice_task.llm_start_time = time.time()
125
 
126
  system_prompt = self._get_prompt_by_language(voice_task.language)
 
8
  from config import paths
9
  from core.base import BaseThread
10
  from core.constants import chat_history_cache
11
+ from models.voice_task import VoiceTask, QuestionDisplayMessage
12
  from services.text.processor import preprocess_sentence_text, \
13
  create_langchain_chat_llamacpp_instance, create_langchain_pipeline, warmup_langchain_pipeline
14
 
 
26
  class LLMResponseGenerator(BaseThread):
27
  """LLM 回答生成器 - 负责使用语言模型生成回答文本"""
28
 
29
+ def __init__(
30
+ self, group=None, target=None, name=None, args=(), kwargs={}, *, daemon=None,
31
+ user_question_queue: Queue,
32
+ generated_answer_queue: Queue,
33
+ websocket_message_queue: Queue = None
34
+ ):
35
  super().__init__(group, target, name, args, kwargs, daemon=daemon)
36
 
37
  self.user_question_queue = user_question_queue
38
  self.generated_answer_queue = generated_answer_queue
39
+ self.websocket_message_queue = websocket_message_queue
40
 
41
  self.english_sentence_end_marks = {'!', '?', '.', ',', ':', ';'}
42
  self.chinese_sentence_end_marks = {',', '。', '!', '?', ':', ';', '、'}
 
124
 
125
  user_question = voice_task.transcribed_text
126
  print(f'用户问题: {user_question}')
127
+ if self.websocket_message_queue:
128
+ self.websocket_message_queue.put_nowait(
129
+ QuestionDisplayMessage(
130
+ session_id=voice_task.session_id,
131
+ question=user_question,
132
+ task_id=voice_task.id,
133
+ )
134
+ )
135
+
136
  voice_task.llm_start_time = time.time()
137
 
138
  system_prompt = self._get_prompt_by_language(voice_task.language)