liumaolin
committed on
Commit
·
2534744
1
Parent(s):
83ef092
Integrate WebSocket support: add `/api/v1/ws` endpoint, enable real-time message handling via `websocket_message_queue`, and refactor services and models to support WebSocket-based question and answer updates.
Browse files- src/VoiceDialogue/api/app.py +3 -2
- src/VoiceDialogue/api/core/service_factories.py +7 -3
- src/VoiceDialogue/api/routes/__init__.py +2 -2
- src/VoiceDialogue/api/routes/websocket_routes.py +34 -0
- src/VoiceDialogue/core/constants.py +2 -0
- src/VoiceDialogue/models/__init__.py +7 -1
- src/VoiceDialogue/models/voice_task.py +24 -0
- src/VoiceDialogue/services/audio/generator.py +6 -2
- src/VoiceDialogue/services/audio/player.py +16 -3
- src/VoiceDialogue/services/text/generator.py +17 -5
src/VoiceDialogue/api/app.py
CHANGED
|
@@ -11,7 +11,7 @@ from .core.config import AppConfig
|
|
| 11 |
from .core.lifespan import lifespan
|
| 12 |
from .middleware.logging import LoggingMiddleware
|
| 13 |
from .middleware.rate_limit import RateLimitMiddleware
|
| 14 |
-
from .routes import tts_routes, asr_routes, system_routes
|
| 15 |
|
| 16 |
# 配置日志
|
| 17 |
logging.basicConfig(
|
|
@@ -63,9 +63,10 @@ def _register_routes(app: FastAPI):
|
|
| 63 |
v1_router.include_router(tts_routes.router, prefix="/tts", tags=["TTS模型管理"])
|
| 64 |
v1_router.include_router(asr_routes.router, prefix="/asr", tags=["ASR模型管理"])
|
| 65 |
v1_router.include_router(system_routes.router, prefix="/system", tags=["系统管理"])
|
| 66 |
-
|
| 67 |
app.include_router(v1_router)
|
| 68 |
|
|
|
|
|
|
|
| 69 |
# 根路径和健康检查
|
| 70 |
_register_health_routes(app)
|
| 71 |
|
|
|
|
| 11 |
from .core.lifespan import lifespan
|
| 12 |
from .middleware.logging import LoggingMiddleware
|
| 13 |
from .middleware.rate_limit import RateLimitMiddleware
|
| 14 |
+
from .routes import tts_routes, asr_routes, system_routes, websocket_routes
|
| 15 |
|
| 16 |
# 配置日志
|
| 17 |
logging.basicConfig(
|
|
|
|
| 63 |
v1_router.include_router(tts_routes.router, prefix="/tts", tags=["TTS模型管理"])
|
| 64 |
v1_router.include_router(asr_routes.router, prefix="/asr", tags=["ASR模型管理"])
|
| 65 |
v1_router.include_router(system_routes.router, prefix="/system", tags=["系统管理"])
|
|
|
|
| 66 |
app.include_router(v1_router)
|
| 67 |
|
| 68 |
+
app.add_websocket_route("/api/v1/ws", websocket_routes.ws)
|
| 69 |
+
|
| 70 |
# 根路径和健康检查
|
| 71 |
_register_health_routes(app)
|
| 72 |
|
src/VoiceDialogue/api/core/service_factories.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
from core.constants import (
|
| 2 |
transcribed_text_queue, text_input_queue, audio_output_queue,
|
| 3 |
-
audio_frames_queue, user_voice_queue
|
| 4 |
)
|
| 5 |
from services.audio import EchoCancellingAudioCapture, TTSAudioGenerator, AudioStreamPlayer
|
| 6 |
from services.audio.generators import BaseTTSConfig, tts_config_registry
|
|
@@ -41,7 +41,8 @@ class ServiceFactories:
|
|
| 41 |
"""创建LLM文本生成服务"""
|
| 42 |
return LLMResponseGenerator(
|
| 43 |
user_question_queue=transcribed_text_queue,
|
| 44 |
-
generated_answer_queue=text_input_queue
|
|
|
|
| 45 |
)
|
| 46 |
|
| 47 |
@staticmethod
|
|
@@ -59,7 +60,10 @@ class ServiceFactories:
|
|
| 59 |
@staticmethod
|
| 60 |
def create_audio_player() -> AudioStreamPlayer:
|
| 61 |
"""创建音频播放服务"""
|
| 62 |
-
return AudioStreamPlayer(
|
|
|
|
|
|
|
|
|
|
| 63 |
|
| 64 |
|
| 65 |
def get_core_voice_service_definitions(system_language: str, tts_config: BaseTTSConfig = None) -> list:
|
|
|
|
| 1 |
from core.constants import (
|
| 2 |
transcribed_text_queue, text_input_queue, audio_output_queue,
|
| 3 |
+
audio_frames_queue, user_voice_queue, websocket_message_queue
|
| 4 |
)
|
| 5 |
from services.audio import EchoCancellingAudioCapture, TTSAudioGenerator, AudioStreamPlayer
|
| 6 |
from services.audio.generators import BaseTTSConfig, tts_config_registry
|
|
|
|
| 41 |
"""创建LLM文本生成服务"""
|
| 42 |
return LLMResponseGenerator(
|
| 43 |
user_question_queue=transcribed_text_queue,
|
| 44 |
+
generated_answer_queue=text_input_queue,
|
| 45 |
+
websocket_message_queue=websocket_message_queue,
|
| 46 |
)
|
| 47 |
|
| 48 |
@staticmethod
|
|
|
|
| 60 |
@staticmethod
def create_audio_player() -> AudioStreamPlayer:
    """Create the audio playback service.

    Wires the player to the shared ``audio_output_queue`` (its input) and to
    ``websocket_message_queue`` so playback events can be pushed to connected
    WebSocket clients.

    NOTE(review): ``websocket_message_queue`` is created as an ``asyncio.Queue``
    in core.constants, while the player appears to run on its own thread;
    ``asyncio.Queue`` is not thread-safe — confirm how cross-thread puts are
    synchronized with the event loop.
    """
    return AudioStreamPlayer(
        audio_playing_queue=audio_output_queue,
        websocket_message_queue=websocket_message_queue
    )
|
| 67 |
|
| 68 |
|
| 69 |
def get_core_voice_service_definitions(system_language: str, tts_config: BaseTTSConfig = None) -> list:
|
src/VoiceDialogue/api/routes/__init__.py
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
-
from . import tts_routes, asr_routes, system_routes
|
| 2 |
|
| 3 |
-
__all__ = ["tts_routes", "asr_routes", "system_routes"]
|
|
|
|
"""Route-module exports for the API routes package."""
from . import asr_routes, system_routes, tts_routes, websocket_routes

__all__ = ["tts_routes", "asr_routes", "system_routes", "websocket_routes"]
|
src/VoiceDialogue/api/routes/websocket_routes.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import logging

from fastapi import APIRouter, WebSocket, WebSocketDisconnect

from core.constants import websocket_message_queue, session_manager

# Router holding the WebSocket endpoint; registered by the app factory.
ws = APIRouter()
logger = logging.getLogger(__name__)


@ws.websocket("/api/v1/ws")
async def websocket_endpoint(websocket: WebSocket):
    """WebSocket connection endpoint.

    Accepts the connection, then forwards display messages from the shared
    ``websocket_message_queue`` to the client as JSON, dropping messages whose
    session id does not match the currently active session.

    NOTE(review): a single shared queue means each message is delivered to at
    most one connected client — confirm only one WebSocket client is expected.
    """
    try:
        await websocket.accept()
        while True:
            # asyncio.Queue.get() suspends until an item is available; it never
            # raises queue.Empty, so the original `except Empty: continue`
            # was dead code and has been removed.
            message = await websocket_message_queue.get()

            # Messages for inactive sessions are discarded, not re-queued.
            if message.session_id != session_manager.current_id:
                continue

            await websocket.send_json(message.model_dump())

    except WebSocketDisconnect:
        # Client disconnected normally; nothing to clean up.
        pass
    except Exception as e:
        # exception() keeps the traceback; lazy %-args avoid eager formatting.
        logger.exception("WebSocket连接异常: %s", e)
src/VoiceDialogue/core/constants.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
|
|
| 1 |
import multiprocessing
|
| 2 |
import threading
|
| 3 |
from collections import OrderedDict
|
|
@@ -36,6 +37,7 @@ user_voice_queue = multiprocessing.Queue()
|
|
| 36 |
transcribed_text_queue = multiprocessing.Queue()
|
| 37 |
text_input_queue = multiprocessing.Queue()
|
| 38 |
audio_output_queue = multiprocessing.Queue()
|
|
|
|
| 39 |
|
| 40 |
# ======================= 全局状态实例 =======================
|
| 41 |
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
import multiprocessing
|
| 3 |
import threading
|
| 4 |
from collections import OrderedDict
|
|
|
|
| 37 |
transcribed_text_queue = multiprocessing.Queue()
|
| 38 |
text_input_queue = multiprocessing.Queue()
|
| 39 |
audio_output_queue = multiprocessing.Queue()
|
| 40 |
+
websocket_message_queue = asyncio.Queue()
|
| 41 |
|
| 42 |
# ======================= 全局状态实例 =======================
|
| 43 |
|
src/VoiceDialogue/models/__init__.py
CHANGED
|
@@ -1 +1,7 @@
|
|
| 1 |
-
from .voice_task import VoiceTask
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Public exports of the models package."""
from .voice_task import AnswerDisplayMessage, QuestionDisplayMessage, VoiceTask

__all__ = ('VoiceTask', 'QuestionDisplayMessage', 'AnswerDisplayMessage')
|
src/VoiceDialogue/models/voice_task.py
CHANGED
|
@@ -1,3 +1,5 @@
|
|
|
|
|
|
|
|
| 1 |
import numpy as np
|
| 2 |
from pydantic import BaseModel, Field
|
| 3 |
|
|
@@ -30,3 +32,25 @@ class VoiceTask(BaseModel):
|
|
| 30 |
|
| 31 |
class Config:
|
| 32 |
arbitrary_types_allowed = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from enum import Enum
|
| 2 |
+
|
| 3 |
import numpy as np
|
| 4 |
from pydantic import BaseModel, Field
|
| 5 |
|
|
|
|
| 32 |
|
| 33 |
class Config:
|
| 34 |
arbitrary_types_allowed = True
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class DisplayMessageType(str, Enum):
    """Discriminator for the kinds of display messages sent over WebSocket."""
    QUESTION = 'question'
    ANSWER = 'answer'


class BaseDisplayMessage(BaseModel):
    """Common envelope for messages pushed to the client for display.

    session_id identifies the dialogue session the message belongs to;
    task_id identifies the originating voice task.
    """
    message_type: DisplayMessageType
    session_id: str
    task_id: str


class QuestionDisplayMessage(BaseDisplayMessage):
    """A transcribed user question to be shown in the client UI."""
    message_type: DisplayMessageType = DisplayMessageType.QUESTION
    question: str


class AnswerDisplayMessage(BaseDisplayMessage):
    """One segment of a generated answer.

    answer_index is presumably the segment's position within the full
    answer — TODO confirm against the producer in services/audio/player.py.
    """
    message_type: DisplayMessageType = DisplayMessageType.ANSWER
    answer_index: int
    answer: str
|
src/VoiceDialogue/services/audio/generator.py
CHANGED
|
@@ -19,8 +19,12 @@ class TTSAudioGenerator(BaseThread):
|
|
| 19 |
4. 将生成的音频任务放入音频队列中
|
| 20 |
"""
|
| 21 |
|
| 22 |
-
def __init__(
|
| 23 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
"""
|
| 25 |
初始化TTS音频生成器
|
| 26 |
|
|
|
|
| 19 |
4. 将生成的音频任务放入音频队列中
|
| 20 |
"""
|
| 21 |
|
| 22 |
+
def __init__(
|
| 23 |
+
self, group=None, target=None, name=None, args=(), kwargs={}, *, daemon=None,
|
| 24 |
+
text_input_queue: Queue,
|
| 25 |
+
audio_output_queue: Queue,
|
| 26 |
+
tts_config: BaseTTSConfig,
|
| 27 |
+
):
|
| 28 |
"""
|
| 29 |
初始化TTS音频生成器
|
| 30 |
|
src/VoiceDialogue/services/audio/player.py
CHANGED
|
@@ -11,16 +11,20 @@ from core.constants import (
|
|
| 11 |
user_still_speaking_event, voice_state_manager, dropped_audio_cache, chat_history_cache,
|
| 12 |
silence_over_threshold_event
|
| 13 |
)
|
| 14 |
-
from models.voice_task import VoiceTask
|
| 15 |
|
| 16 |
|
| 17 |
class AudioStreamPlayer(BaseThread):
|
| 18 |
"""音频流播放器 - 负责播放生成的音频并管理播放状态"""
|
| 19 |
|
| 20 |
-
def __init__(
|
| 21 |
-
|
|
|
|
|
|
|
|
|
|
| 22 |
super().__init__(group, target, name, args, kwargs, daemon=daemon)
|
| 23 |
self.audio_playing_queue: Queue = audio_playing_queue
|
|
|
|
| 24 |
|
| 25 |
def run(self):
|
| 26 |
self.is_ready = True
|
|
@@ -56,6 +60,15 @@ class AudioStreamPlayer(BaseThread):
|
|
| 56 |
if answer_id not in voice_state_manager.waiting_second_answer_mapping:
|
| 57 |
continue
|
| 58 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
# now = time.time()
|
| 60 |
# print(
|
| 61 |
# f'整体耗时: {(now - voice_task.send_time):.2f}\n'
|
|
|
|
| 11 |
user_still_speaking_event, voice_state_manager, dropped_audio_cache, chat_history_cache,
|
| 12 |
silence_over_threshold_event
|
| 13 |
)
|
| 14 |
+
from models.voice_task import VoiceTask, AnswerDisplayMessage
|
| 15 |
|
| 16 |
|
| 17 |
class AudioStreamPlayer(BaseThread):
|
| 18 |
"""音频流播放器 - 负责播放生成的音频并管理播放状态"""
|
| 19 |
|
| 20 |
+
def __init__(
        self, group=None, target=None, name=None, args=(), kwargs=None, *, daemon=None,
        audio_playing_queue: Queue,
        websocket_message_queue: Queue,
):
    """Initialize the audio stream player.

    :param audio_playing_queue: queue of audio tasks to play.
    :param websocket_message_queue: queue for pushing display messages to
        WebSocket clients.
    Remaining parameters are forwarded to the thread base class.
    """
    # Fix: the default was the mutable `kwargs={}`, which is shared across
    # all instances; use None as the sentinel and substitute a fresh dict.
    super().__init__(group, target, name, args, {} if kwargs is None else kwargs, daemon=daemon)
    self.audio_playing_queue: Queue = audio_playing_queue
    self.websocket_message_queue: Queue = websocket_message_queue
|
| 28 |
|
| 29 |
def run(self):
|
| 30 |
self.is_ready = True
|
|
|
|
| 60 |
if answer_id not in voice_state_manager.waiting_second_answer_mapping:
|
| 61 |
continue
|
| 62 |
|
| 63 |
+
if self.websocket_message_queue:
|
| 64 |
+
self.websocket_message_queue.put_nowait(
|
| 65 |
+
AnswerDisplayMessage(
|
| 66 |
+
session_id=voice_task.session_id,
|
| 67 |
+
task_id=task_id,
|
| 68 |
+
answer_index=voice_task.answer_index,
|
| 69 |
+
answer=voice_task.answer_sentence,
|
| 70 |
+
)
|
| 71 |
+
)
|
| 72 |
# now = time.time()
|
| 73 |
# print(
|
| 74 |
# f'整体耗时: {(now - voice_task.send_time):.2f}\n'
|
src/VoiceDialogue/services/text/generator.py
CHANGED
|
@@ -8,7 +8,7 @@ from langchain_core.chat_history import InMemoryChatMessageHistory
|
|
| 8 |
from config import paths
|
| 9 |
from core.base import BaseThread
|
| 10 |
from core.constants import chat_history_cache
|
| 11 |
-
from models.voice_task import VoiceTask
|
| 12 |
from services.text.processor import preprocess_sentence_text, \
|
| 13 |
create_langchain_chat_llamacpp_instance, create_langchain_pipeline, warmup_langchain_pipeline
|
| 14 |
|
|
@@ -26,14 +26,17 @@ ENGLISH_SYSTEM_PROMPT = ("You are an AI assistant skilled at simulating authenti
|
|
| 26 |
class LLMResponseGenerator(BaseThread):
|
| 27 |
"""LLM 回答生成器 - 负责使用语言模型生成回答文本"""
|
| 28 |
|
| 29 |
-
def __init__(
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
|
|
|
|
|
|
| 33 |
super().__init__(group, target, name, args, kwargs, daemon=daemon)
|
| 34 |
|
| 35 |
self.user_question_queue = user_question_queue
|
| 36 |
self.generated_answer_queue = generated_answer_queue
|
|
|
|
| 37 |
|
| 38 |
self.english_sentence_end_marks = {'!', '?', '.', ',', ':', ';'}
|
| 39 |
self.chinese_sentence_end_marks = {',', '。', '!', '?', ':', ';', '、'}
|
|
@@ -121,6 +124,15 @@ class LLMResponseGenerator(BaseThread):
|
|
| 121 |
|
| 122 |
user_question = voice_task.transcribed_text
|
| 123 |
print(f'用户问题: {user_question}')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 124 |
voice_task.llm_start_time = time.time()
|
| 125 |
|
| 126 |
system_prompt = self._get_prompt_by_language(voice_task.language)
|
|
|
|
| 8 |
from config import paths
|
| 9 |
from core.base import BaseThread
|
| 10 |
from core.constants import chat_history_cache
|
| 11 |
+
from models.voice_task import VoiceTask, QuestionDisplayMessage
|
| 12 |
from services.text.processor import preprocess_sentence_text, \
|
| 13 |
create_langchain_chat_llamacpp_instance, create_langchain_pipeline, warmup_langchain_pipeline
|
| 14 |
|
|
|
|
| 26 |
class LLMResponseGenerator(BaseThread):
|
| 27 |
"""LLM 回答生成器 - 负责使用语言模型生成回答文本"""
|
| 28 |
|
| 29 |
+
def __init__(
|
| 30 |
+
self, group=None, target=None, name=None, args=(), kwargs={}, *, daemon=None,
|
| 31 |
+
user_question_queue: Queue,
|
| 32 |
+
generated_answer_queue: Queue,
|
| 33 |
+
websocket_message_queue: Queue = None
|
| 34 |
+
):
|
| 35 |
super().__init__(group, target, name, args, kwargs, daemon=daemon)
|
| 36 |
|
| 37 |
self.user_question_queue = user_question_queue
|
| 38 |
self.generated_answer_queue = generated_answer_queue
|
| 39 |
+
self.websocket_message_queue = websocket_message_queue
|
| 40 |
|
| 41 |
self.english_sentence_end_marks = {'!', '?', '.', ',', ':', ';'}
|
| 42 |
self.chinese_sentence_end_marks = {',', '。', '!', '?', ':', ';', '、'}
|
|
|
|
| 124 |
|
| 125 |
user_question = voice_task.transcribed_text
|
| 126 |
print(f'用户问题: {user_question}')
|
| 127 |
+
if self.websocket_message_queue:
|
| 128 |
+
self.websocket_message_queue.put_nowait(
|
| 129 |
+
QuestionDisplayMessage(
|
| 130 |
+
session_id=voice_task.session_id,
|
| 131 |
+
question=user_question,
|
| 132 |
+
task_id=voice_task.id,
|
| 133 |
+
)
|
| 134 |
+
)
|
| 135 |
+
|
| 136 |
voice_task.llm_start_time = time.time()
|
| 137 |
|
| 138 |
system_prompt = self._get_prompt_by_language(voice_task.language)
|