liumaolin
commited on
Commit
·
15ab478
1
Parent(s):
1816130
添加设置相关的API路由
Browse files- 支持用户自定义提示词的获取、更新和重置功能
- 同时更新应用数据存储路径的获取逻辑。
src/voice_dialogue/api/app.py
CHANGED
|
@@ -11,7 +11,7 @@ from .core.config import AppConfig
|
|
| 11 |
from .core.lifespan import lifespan
|
| 12 |
from .middleware.logging import LoggingMiddleware
|
| 13 |
from .middleware.rate_limit import RateLimitMiddleware
|
| 14 |
-
from .routes import tts_routes, asr_routes, system_routes, websocket_routes
|
| 15 |
|
| 16 |
|
| 17 |
def create_app() -> FastAPI:
|
|
@@ -56,6 +56,7 @@ def _register_routes(app: FastAPI):
|
|
| 56 |
v1_router.include_router(tts_routes.router, prefix="/tts", tags=["TTS模型管理"])
|
| 57 |
v1_router.include_router(asr_routes.router, prefix="/asr", tags=["ASR模型管理"])
|
| 58 |
v1_router.include_router(system_routes.router, prefix="/system", tags=["系统管理"])
|
|
|
|
| 59 |
app.include_router(v1_router)
|
| 60 |
|
| 61 |
app.add_websocket_route("/api/v1/ws", websocket_routes.ws)
|
|
|
|
| 11 |
from .core.lifespan import lifespan
|
| 12 |
from .middleware.logging import LoggingMiddleware
|
| 13 |
from .middleware.rate_limit import RateLimitMiddleware
|
| 14 |
+
from .routes import tts_routes, asr_routes, system_routes, websocket_routes, settings_routes
|
| 15 |
|
| 16 |
|
| 17 |
def create_app() -> FastAPI:
|
|
|
|
| 56 |
v1_router.include_router(tts_routes.router, prefix="/tts", tags=["TTS模型管理"])
|
| 57 |
v1_router.include_router(asr_routes.router, prefix="/asr", tags=["ASR模型管理"])
|
| 58 |
v1_router.include_router(system_routes.router, prefix="/system", tags=["系统管理"])
|
| 59 |
+
v1_router.include_router(settings_routes.router, prefix="/settings", tags=["设置管理"])
|
| 60 |
app.include_router(v1_router)
|
| 61 |
|
| 62 |
app.add_websocket_route("/api/v1/ws", websocket_routes.ws)
|
src/voice_dialogue/api/routes/__init__.py
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
-
from . import tts_routes, asr_routes, system_routes, websocket_routes
|
| 2 |
|
| 3 |
-
__all__ = ["tts_routes", "asr_routes", "system_routes", "websocket_routes"]
|
|
|
|
| 1 |
+
from . import tts_routes, asr_routes, system_routes, websocket_routes, settings_routes
|
| 2 |
|
| 3 |
+
__all__ = ["tts_routes", "asr_routes", "system_routes", "websocket_routes", "settings_routes"]
|
src/voice_dialogue/api/routes/settings_routes.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""设置相关的API路由"""
|
| 2 |
+
from typing import Optional
|
| 3 |
+
|
| 4 |
+
from fastapi import APIRouter, HTTPException
|
| 5 |
+
from pydantic import BaseModel, Field
|
| 6 |
+
|
| 7 |
+
from voice_dialogue.config.llm_config import CHINESE_SYSTEM_PROMPT, ENGLISH_SYSTEM_PROMPT
|
| 8 |
+
from voice_dialogue.config.user_config import (
|
| 9 |
+
get_user_prompts, save_user_prompts, get_prompt, reset_prompts_to_default
|
| 10 |
+
)
|
| 11 |
+
|
| 12 |
+
router = APIRouter()
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class PromptsResponse(BaseModel):
|
| 16 |
+
"""获取 Prompts 的响应模型"""
|
| 17 |
+
chinese_prompt: str = Field(..., description="中文系统提示词")
|
| 18 |
+
english_prompt: str = Field(..., description="英文系统提示词")
|
| 19 |
+
is_custom: bool = Field(..., description="是否为用户自定义")
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class UpdatePromptsRequest(BaseModel):
|
| 23 |
+
"""更新 Prompts 的请求模型"""
|
| 24 |
+
chinese_prompt: Optional[str] = Field(None, description="中文系统提示词")
|
| 25 |
+
english_prompt: Optional[str] = Field(None, description="英文系统提示词")
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class DefaultPromptsResponse(BaseModel):
|
| 29 |
+
"""默认 Prompts 的响应模型"""
|
| 30 |
+
chinese_prompt: str = Field(..., description="默认中文系统提示词")
|
| 31 |
+
english_prompt: str = Field(..., description="默认英文系统提示词")
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
@router.get("/settings/prompts", response_model=PromptsResponse, summary="获取当前生效的 Prompt")
|
| 35 |
+
async def get_current_prompts():
|
| 36 |
+
"""
|
| 37 |
+
获取当前系统中正在使用的中文和英文系统 Prompt
|
| 38 |
+
会融合用户自定义设置和系统默认值
|
| 39 |
+
"""
|
| 40 |
+
user_prompts = get_user_prompts()
|
| 41 |
+
is_custom = bool(user_prompts) # 如果有用户自定义配置,则认为是自定义的
|
| 42 |
+
|
| 43 |
+
return PromptsResponse(
|
| 44 |
+
chinese_prompt=get_prompt("zh"),
|
| 45 |
+
english_prompt=get_prompt("en"),
|
| 46 |
+
is_custom=is_custom
|
| 47 |
+
)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
@router.get("/settings/prompts/default", response_model=DefaultPromptsResponse, summary="获取默认 Prompt")
|
| 51 |
+
async def get_default_prompts():
|
| 52 |
+
"""获取系统默认的 Prompt"""
|
| 53 |
+
return DefaultPromptsResponse(
|
| 54 |
+
chinese_prompt=CHINESE_SYSTEM_PROMPT,
|
| 55 |
+
english_prompt=ENGLISH_SYSTEM_PROMPT
|
| 56 |
+
)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
@router.post("/settings/prompts", summary="更新并保存用户的 Prompt 设置")
|
| 60 |
+
async def update_user_prompts(request: UpdatePromptsRequest):
|
| 61 |
+
"""
|
| 62 |
+
更新用户自定义的 Prompt
|
| 63 |
+
只更新请求体中提供的字段,未提供的字段将保持不变
|
| 64 |
+
"""
|
| 65 |
+
try:
|
| 66 |
+
# 获取当前用户配置
|
| 67 |
+
current_prompts = get_user_prompts()
|
| 68 |
+
|
| 69 |
+
# 构建更新数据
|
| 70 |
+
update_data = request.model_dump(exclude_unset=True)
|
| 71 |
+
|
| 72 |
+
if not update_data:
|
| 73 |
+
raise HTTPException(status_code=400, detail="请求体不能为空")
|
| 74 |
+
|
| 75 |
+
# 更新配置
|
| 76 |
+
current_prompts.update(update_data)
|
| 77 |
+
|
| 78 |
+
# 保存配置
|
| 79 |
+
if not save_user_prompts(current_prompts):
|
| 80 |
+
raise HTTPException(status_code=500, detail="保存配置失败")
|
| 81 |
+
|
| 82 |
+
return {"message": "用户 Prompt 更新成功", "updated_fields": list(update_data.keys())}
|
| 83 |
+
|
| 84 |
+
except HTTPException:
|
| 85 |
+
raise
|
| 86 |
+
except Exception as e:
|
| 87 |
+
raise HTTPException(status_code=500, detail=f"更新 Prompt 失败: {str(e)}")
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
@router.delete("/settings/prompts", summary="重置 Prompt 为默认值")
|
| 91 |
+
async def reset_prompts():
|
| 92 |
+
"""重置用户自定义的 Prompt 为系统默认值"""
|
| 93 |
+
try:
|
| 94 |
+
if not reset_prompts_to_default():
|
| 95 |
+
raise HTTPException(status_code=500, detail="重置失败")
|
| 96 |
+
|
| 97 |
+
return {"message": "Prompt 已重置为默认值"}
|
| 98 |
+
|
| 99 |
+
except HTTPException:
|
| 100 |
+
raise
|
| 101 |
+
except Exception as e:
|
| 102 |
+
raise HTTPException(status_code=500, detail=f"重置 Prompt 失败: {str(e)}")
|
src/voice_dialogue/config/paths.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
|
|
| 1 |
import sys
|
| 2 |
from pathlib import Path
|
| 3 |
|
|
@@ -26,6 +27,27 @@ AUDIO_RESOURCES_PATH = ASSETS_PATH / "audio"
|
|
| 26 |
FRONTEND_ASSETS_PATH = ASSETS_PATH / "www"
|
| 27 |
|
| 28 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
def load_third_party():
|
| 30 |
# 添加第三方库到 Python 路径
|
| 31 |
if THIRD_PARTY_PATH.exists() and str(THIRD_PARTY_PATH) not in sys.path:
|
|
|
|
| 1 |
+
import os
|
| 2 |
import sys
|
| 3 |
from pathlib import Path
|
| 4 |
|
|
|
|
| 27 |
FRONTEND_ASSETS_PATH = ASSETS_PATH / "www"
|
| 28 |
|
| 29 |
|
| 30 |
+
# 用户数据路径 - 根据操作系统选择合适的目录
|
| 31 |
+
def get_app_data_path() -> Path:
|
| 32 |
+
"""获取应用数据存储路径"""
|
| 33 |
+
app_name = "Voice Dialogue"
|
| 34 |
+
|
| 35 |
+
if sys.platform == "darwin": # macOS
|
| 36 |
+
base_path = Path.home() / "Library" / "Application Support"
|
| 37 |
+
elif sys.platform == "win32": # Windows
|
| 38 |
+
base_path = Path(os.environ.get("APPDATA", Path.home() / "AppData" / "Roaming"))
|
| 39 |
+
else: # Linux and others
|
| 40 |
+
base_path = Path.home() / ".config"
|
| 41 |
+
|
| 42 |
+
return base_path / app_name
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
APP_DATA_PATH = get_app_data_path()
|
| 46 |
+
if not APP_DATA_PATH.exists():
|
| 47 |
+
APP_DATA_PATH.mkdir(parents=True, exist_ok=True)
|
| 48 |
+
USER_PROMPTS_PATH = APP_DATA_PATH / "user_prompts.json"
|
| 49 |
+
|
| 50 |
+
|
| 51 |
def load_third_party():
|
| 52 |
# 添加第三方库到 Python 路径
|
| 53 |
if THIRD_PARTY_PATH.exists() and str(THIRD_PARTY_PATH) not in sys.path:
|
src/voice_dialogue/config/user_config.py
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""用户配置管理模块"""
|
| 2 |
+
import json
|
| 3 |
+
from typing import Dict
|
| 4 |
+
|
| 5 |
+
from voice_dialogue.utils.logger import logger
|
| 6 |
+
from .llm_config import CHINESE_SYSTEM_PROMPT, ENGLISH_SYSTEM_PROMPT
|
| 7 |
+
from .paths import USER_PROMPTS_PATH
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def get_user_prompts() -> Dict[str, str]:
|
| 11 |
+
"""
|
| 12 |
+
加载用户自定义的 prompts
|
| 13 |
+
|
| 14 |
+
Returns:
|
| 15 |
+
Dict[str, str]: 用户自定义的 prompts,如果文件不存在或解析失败则返回空字典
|
| 16 |
+
"""
|
| 17 |
+
if not USER_PROMPTS_PATH.exists():
|
| 18 |
+
logger.info(f"用户配置文件不存在: {USER_PROMPTS_PATH}")
|
| 19 |
+
return {}
|
| 20 |
+
|
| 21 |
+
try:
|
| 22 |
+
with open(USER_PROMPTS_PATH, 'r', encoding='utf-8') as f:
|
| 23 |
+
user_prompts = json.load(f)
|
| 24 |
+
logger.info("成功加载用户自定义 prompts")
|
| 25 |
+
return user_prompts
|
| 26 |
+
except (json.JSONDecodeError, IOError) as e:
|
| 27 |
+
logger.error(f"无法加载用户 prompt 配置文件: {e}")
|
| 28 |
+
return {}
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def save_user_prompts(prompts: Dict[str, str]) -> bool:
|
| 32 |
+
"""
|
| 33 |
+
保存用户自定义的 prompts 到 JSON 文件
|
| 34 |
+
|
| 35 |
+
Args:
|
| 36 |
+
prompts: 要保存的 prompts 字典
|
| 37 |
+
|
| 38 |
+
Returns:
|
| 39 |
+
bool: 保存是否成功
|
| 40 |
+
"""
|
| 41 |
+
try:
|
| 42 |
+
# 确保目录存在
|
| 43 |
+
if not USER_PROMPTS_PATH.parent.exists():
|
| 44 |
+
USER_PROMPTS_PATH.parent.mkdir(parents=True, exist_ok=True)
|
| 45 |
+
|
| 46 |
+
with open(USER_PROMPTS_PATH, 'w', encoding='utf-8') as f:
|
| 47 |
+
json.dump(prompts, f, ensure_ascii=False, indent=4)
|
| 48 |
+
logger.info(f"用户 prompts 已保存到: {USER_PROMPTS_PATH}")
|
| 49 |
+
return True
|
| 50 |
+
except IOError as e:
|
| 51 |
+
logger.error(f"无法保存用户 prompt 配置文件: {e}")
|
| 52 |
+
return False
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def get_prompt(language: str) -> str:
|
| 56 |
+
"""
|
| 57 |
+
获取指定语言的 prompt
|
| 58 |
+
优先从用户配置中获取,如果未配置,则返回默认值
|
| 59 |
+
|
| 60 |
+
Args:
|
| 61 |
+
language: 语言代码,"zh" 表示中文,其他表示英文
|
| 62 |
+
|
| 63 |
+
Returns:
|
| 64 |
+
str: 对应语言的系统提示词
|
| 65 |
+
"""
|
| 66 |
+
user_prompts = get_user_prompts()
|
| 67 |
+
|
| 68 |
+
if language == "zh":
|
| 69 |
+
return user_prompts.get("chinese_prompt", CHINESE_SYSTEM_PROMPT)
|
| 70 |
+
else:
|
| 71 |
+
return user_prompts.get("english_prompt", ENGLISH_SYSTEM_PROMPT)
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def reset_prompts_to_default() -> bool:
|
| 75 |
+
"""
|
| 76 |
+
重置 prompts 为默认值
|
| 77 |
+
|
| 78 |
+
Returns:
|
| 79 |
+
bool: 重置是否成功
|
| 80 |
+
"""
|
| 81 |
+
try:
|
| 82 |
+
if USER_PROMPTS_PATH.exists():
|
| 83 |
+
USER_PROMPTS_PATH.unlink()
|
| 84 |
+
logger.info("用户自定义 prompts 已重置为默认值")
|
| 85 |
+
return True
|
| 86 |
+
except IOError as e:
|
| 87 |
+
logger.error(f"重置 prompts 失败: {e}")
|
| 88 |
+
return False
|
src/voice_dialogue/services/text/generator.py
CHANGED
|
@@ -9,6 +9,7 @@ from voice_dialogue.config import paths
|
|
| 9 |
from voice_dialogue.config.llm_config import (
|
| 10 |
get_llm_model_params, get_apple_silicon_summary, CHINESE_SYSTEM_PROMPT, ENGLISH_SYSTEM_PROMPT
|
| 11 |
)
|
|
|
|
| 12 |
from voice_dialogue.core.base import BaseThread
|
| 13 |
from voice_dialogue.core.constants import chat_history_cache
|
| 14 |
from voice_dialogue.models.voice_task import VoiceTask, QuestionDisplayMessage
|
|
@@ -38,10 +39,7 @@ class LLMResponseGenerator(BaseThread):
|
|
| 38 |
|
| 39 |
def _get_prompt_by_language(self, language: str) -> str:
|
| 40 |
"""根据语言获取对应的 prompt"""
|
| 41 |
-
|
| 42 |
-
return CHINESE_SYSTEM_PROMPT
|
| 43 |
-
else:
|
| 44 |
-
return ENGLISH_SYSTEM_PROMPT
|
| 45 |
|
| 46 |
def get_session_history(self, session_id: str) -> InMemoryChatMessageHistory:
|
| 47 |
message_history = InMemoryChatMessageHistory()
|
|
@@ -193,7 +191,9 @@ class LLMResponseGenerator(BaseThread):
|
|
| 193 |
self.model_instance = create_langchain_chat_llamacpp_instance(
|
| 194 |
local_model_path=model_path, model_params=model_params
|
| 195 |
)
|
| 196 |
-
|
|
|
|
|
|
|
| 197 |
warmup_langchain_pipeline(pipeline)
|
| 198 |
|
| 199 |
self.is_ready = True
|
|
|
|
| 9 |
from voice_dialogue.config.llm_config import (
|
| 10 |
get_llm_model_params, get_apple_silicon_summary, CHINESE_SYSTEM_PROMPT, ENGLISH_SYSTEM_PROMPT
|
| 11 |
)
|
| 12 |
+
from voice_dialogue.config.user_config import get_prompt # 修改导入
|
| 13 |
from voice_dialogue.core.base import BaseThread
|
| 14 |
from voice_dialogue.core.constants import chat_history_cache
|
| 15 |
from voice_dialogue.models.voice_task import VoiceTask, QuestionDisplayMessage
|
|
|
|
| 39 |
|
| 40 |
def _get_prompt_by_language(self, language: str) -> str:
|
| 41 |
"""根据语言获取对应的 prompt"""
|
| 42 |
+
return get_prompt(language)
|
|
|
|
|
|
|
|
|
|
| 43 |
|
| 44 |
def get_session_history(self, session_id: str) -> InMemoryChatMessageHistory:
|
| 45 |
message_history = InMemoryChatMessageHistory()
|
|
|
|
| 191 |
self.model_instance = create_langchain_chat_llamacpp_instance(
|
| 192 |
local_model_path=model_path, model_params=model_params
|
| 193 |
)
|
| 194 |
+
# 使用默认中文 prompt 进行 warmup
|
| 195 |
+
prompt = get_prompt("zh")
|
| 196 |
+
pipeline = create_langchain_pipeline(self.model_instance, prompt, self.get_session_history)
|
| 197 |
warmup_langchain_pipeline(pipeline)
|
| 198 |
|
| 199 |
self.is_ready = True
|