liumaolin commited on
Commit
15ab478
·
1 Parent(s): 1816130

添加设置相关的API路由

Browse files

- 支持用户自定义提示词的获取、更新和重置功能
- 同时更新应用数据存储路径的获取逻辑。

src/voice_dialogue/api/app.py CHANGED
@@ -11,7 +11,7 @@ from .core.config import AppConfig
11
  from .core.lifespan import lifespan
12
  from .middleware.logging import LoggingMiddleware
13
  from .middleware.rate_limit import RateLimitMiddleware
14
- from .routes import tts_routes, asr_routes, system_routes, websocket_routes
15
 
16
 
17
  def create_app() -> FastAPI:
@@ -56,6 +56,7 @@ def _register_routes(app: FastAPI):
56
  v1_router.include_router(tts_routes.router, prefix="/tts", tags=["TTS模型管理"])
57
  v1_router.include_router(asr_routes.router, prefix="/asr", tags=["ASR模型管理"])
58
  v1_router.include_router(system_routes.router, prefix="/system", tags=["系统管理"])
 
59
  app.include_router(v1_router)
60
 
61
  app.add_websocket_route("/api/v1/ws", websocket_routes.ws)
 
11
  from .core.lifespan import lifespan
12
  from .middleware.logging import LoggingMiddleware
13
  from .middleware.rate_limit import RateLimitMiddleware
14
+ from .routes import tts_routes, asr_routes, system_routes, websocket_routes, settings_routes
15
 
16
 
17
  def create_app() -> FastAPI:
 
56
  v1_router.include_router(tts_routes.router, prefix="/tts", tags=["TTS模型管理"])
57
  v1_router.include_router(asr_routes.router, prefix="/asr", tags=["ASR模型管理"])
58
  v1_router.include_router(system_routes.router, prefix="/system", tags=["系统管理"])
59
+ v1_router.include_router(settings_routes.router, prefix="/settings", tags=["设置管理"])
60
  app.include_router(v1_router)
61
 
62
  app.add_websocket_route("/api/v1/ws", websocket_routes.ws)
src/voice_dialogue/api/routes/__init__.py CHANGED
@@ -1,3 +1,3 @@
1
- from . import tts_routes, asr_routes, system_routes, websocket_routes
2
 
3
- __all__ = ["tts_routes", "asr_routes", "system_routes", "websocket_routes"]
 
1
+ from . import tts_routes, asr_routes, system_routes, websocket_routes, settings_routes
2
 
3
+ __all__ = ["tts_routes", "asr_routes", "system_routes", "websocket_routes", "settings_routes"]
src/voice_dialogue/api/routes/settings_routes.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """设置相关的API路由"""
2
+ from typing import Optional
3
+
4
+ from fastapi import APIRouter, HTTPException
5
+ from pydantic import BaseModel, Field
6
+
7
+ from voice_dialogue.config.llm_config import CHINESE_SYSTEM_PROMPT, ENGLISH_SYSTEM_PROMPT
8
+ from voice_dialogue.config.user_config import (
9
+ get_user_prompts, save_user_prompts, get_prompt, reset_prompts_to_default
10
+ )
11
+
12
+ router = APIRouter()
13
+
14
+
15
+ class PromptsResponse(BaseModel):
16
+ """获取 Prompts 的响应模型"""
17
+ chinese_prompt: str = Field(..., description="中文系统提示词")
18
+ english_prompt: str = Field(..., description="英文系统提示词")
19
+ is_custom: bool = Field(..., description="是否为用户自定义")
20
+
21
+
22
+ class UpdatePromptsRequest(BaseModel):
23
+ """更新 Prompts 的请求模型"""
24
+ chinese_prompt: Optional[str] = Field(None, description="中文系统提示词")
25
+ english_prompt: Optional[str] = Field(None, description="英文系统提示词")
26
+
27
+
28
+ class DefaultPromptsResponse(BaseModel):
29
+ """默认 Prompts 的响应模型"""
30
+ chinese_prompt: str = Field(..., description="默认中文系统提示词")
31
+ english_prompt: str = Field(..., description="默认英文系统提示词")
32
+
33
+
34
+ @router.get("/settings/prompts", response_model=PromptsResponse, summary="获取当前生效的 Prompt")
35
+ async def get_current_prompts():
36
+ """
37
+ 获取当前系统中正在使用的中文和英文系统 Prompt
38
+ 会融合用户自定义设置和系统默认值
39
+ """
40
+ user_prompts = get_user_prompts()
41
+ is_custom = bool(user_prompts) # 如果有用户自定义配置,则认为是自定义的
42
+
43
+ return PromptsResponse(
44
+ chinese_prompt=get_prompt("zh"),
45
+ english_prompt=get_prompt("en"),
46
+ is_custom=is_custom
47
+ )
48
+
49
+
50
+ @router.get("/settings/prompts/default", response_model=DefaultPromptsResponse, summary="获取默认 Prompt")
51
+ async def get_default_prompts():
52
+ """获取系统默认的 Prompt"""
53
+ return DefaultPromptsResponse(
54
+ chinese_prompt=CHINESE_SYSTEM_PROMPT,
55
+ english_prompt=ENGLISH_SYSTEM_PROMPT
56
+ )
57
+
58
+
59
+ @router.post("/settings/prompts", summary="更新并保存用户的 Prompt 设置")
60
+ async def update_user_prompts(request: UpdatePromptsRequest):
61
+ """
62
+ 更新用户自定义的 Prompt
63
+ 只更新请求体中提供的字段,未提供的字段将保持不变
64
+ """
65
+ try:
66
+ # 获取当前用户配置
67
+ current_prompts = get_user_prompts()
68
+
69
+ # 构建更新数据
70
+ update_data = request.model_dump(exclude_unset=True)
71
+
72
+ if not update_data:
73
+ raise HTTPException(status_code=400, detail="请求体不能为空")
74
+
75
+ # 更新配置
76
+ current_prompts.update(update_data)
77
+
78
+ # 保存配置
79
+ if not save_user_prompts(current_prompts):
80
+ raise HTTPException(status_code=500, detail="保存配置失败")
81
+
82
+ return {"message": "用户 Prompt 更新成功", "updated_fields": list(update_data.keys())}
83
+
84
+ except HTTPException:
85
+ raise
86
+ except Exception as e:
87
+ raise HTTPException(status_code=500, detail=f"更新 Prompt 失败: {str(e)}")
88
+
89
+
90
+ @router.delete("/settings/prompts", summary="重置 Prompt 为默认值")
91
+ async def reset_prompts():
92
+ """重置用户自定义的 Prompt 为系统默认值"""
93
+ try:
94
+ if not reset_prompts_to_default():
95
+ raise HTTPException(status_code=500, detail="重置失败")
96
+
97
+ return {"message": "Prompt 已重置为默认值"}
98
+
99
+ except HTTPException:
100
+ raise
101
+ except Exception as e:
102
+ raise HTTPException(status_code=500, detail=f"重置 Prompt 失败: {str(e)}")
src/voice_dialogue/config/paths.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import sys
2
  from pathlib import Path
3
 
@@ -26,6 +27,27 @@ AUDIO_RESOURCES_PATH = ASSETS_PATH / "audio"
26
  FRONTEND_ASSETS_PATH = ASSETS_PATH / "www"
27
 
28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  def load_third_party():
30
  # 添加第三方库到 Python 路径
31
  if THIRD_PARTY_PATH.exists() and str(THIRD_PARTY_PATH) not in sys.path:
 
1
+ import os
2
  import sys
3
  from pathlib import Path
4
 
 
27
  FRONTEND_ASSETS_PATH = ASSETS_PATH / "www"
28
 
29
 
30
+ # 用户数据路径 - 根据操作系统选择合适的目录
31
+ def get_app_data_path() -> Path:
32
+ """获取应用数据存储路径"""
33
+ app_name = "Voice Dialogue"
34
+
35
+ if sys.platform == "darwin": # macOS
36
+ base_path = Path.home() / "Library" / "Application Support"
37
+ elif sys.platform == "win32": # Windows
38
+ base_path = Path(os.environ.get("APPDATA", Path.home() / "AppData" / "Roaming"))
39
+ else: # Linux and others
40
+ base_path = Path.home() / ".config"
41
+
42
+ return base_path / app_name
43
+
44
+
45
+ APP_DATA_PATH = get_app_data_path()
46
+ if not APP_DATA_PATH.exists():
47
+ APP_DATA_PATH.mkdir(parents=True, exist_ok=True)
48
+ USER_PROMPTS_PATH = APP_DATA_PATH / "user_prompts.json"
49
+
50
+
51
  def load_third_party():
52
  # 添加第三方库到 Python 路径
53
  if THIRD_PARTY_PATH.exists() and str(THIRD_PARTY_PATH) not in sys.path:
src/voice_dialogue/config/user_config.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """用户配置管理模块"""
2
+ import json
3
+ from typing import Dict
4
+
5
+ from voice_dialogue.utils.logger import logger
6
+ from .llm_config import CHINESE_SYSTEM_PROMPT, ENGLISH_SYSTEM_PROMPT
7
+ from .paths import USER_PROMPTS_PATH
8
+
9
+
10
+ def get_user_prompts() -> Dict[str, str]:
11
+ """
12
+ 加载用户自定义的 prompts
13
+
14
+ Returns:
15
+ Dict[str, str]: 用户自定义的 prompts,如果文件不存在或解析失败则返回空字典
16
+ """
17
+ if not USER_PROMPTS_PATH.exists():
18
+ logger.info(f"用户配置文件不存在: {USER_PROMPTS_PATH}")
19
+ return {}
20
+
21
+ try:
22
+ with open(USER_PROMPTS_PATH, 'r', encoding='utf-8') as f:
23
+ user_prompts = json.load(f)
24
+ logger.info("成功加载用户自定义 prompts")
25
+ return user_prompts
26
+ except (json.JSONDecodeError, IOError) as e:
27
+ logger.error(f"无法加载用户 prompt 配置文件: {e}")
28
+ return {}
29
+
30
+
31
+ def save_user_prompts(prompts: Dict[str, str]) -> bool:
32
+ """
33
+ 保存用户自定义的 prompts 到 JSON 文件
34
+
35
+ Args:
36
+ prompts: 要保存的 prompts 字典
37
+
38
+ Returns:
39
+ bool: 保存是否成功
40
+ """
41
+ try:
42
+ # 确保目录存在
43
+ if not USER_PROMPTS_PATH.parent.exists():
44
+ USER_PROMPTS_PATH.parent.mkdir(parents=True, exist_ok=True)
45
+
46
+ with open(USER_PROMPTS_PATH, 'w', encoding='utf-8') as f:
47
+ json.dump(prompts, f, ensure_ascii=False, indent=4)
48
+ logger.info(f"用户 prompts 已保存到: {USER_PROMPTS_PATH}")
49
+ return True
50
+ except IOError as e:
51
+ logger.error(f"无法保存用户 prompt 配置文件: {e}")
52
+ return False
53
+
54
+
55
+ def get_prompt(language: str) -> str:
56
+ """
57
+ 获取指定语言的 prompt
58
+ 优先从用户配置中获取,如果未配置,则返回默认值
59
+
60
+ Args:
61
+ language: 语言代码,"zh" 表示中文,其他表示英文
62
+
63
+ Returns:
64
+ str: 对应语言的系统提示词
65
+ """
66
+ user_prompts = get_user_prompts()
67
+
68
+ if language == "zh":
69
+ return user_prompts.get("chinese_prompt", CHINESE_SYSTEM_PROMPT)
70
+ else:
71
+ return user_prompts.get("english_prompt", ENGLISH_SYSTEM_PROMPT)
72
+
73
+
74
+ def reset_prompts_to_default() -> bool:
75
+ """
76
+ 重置 prompts 为默认值
77
+
78
+ Returns:
79
+ bool: 重置是否成功
80
+ """
81
+ try:
82
+ if USER_PROMPTS_PATH.exists():
83
+ USER_PROMPTS_PATH.unlink()
84
+ logger.info("用户自定义 prompts 已重置为默认值")
85
+ return True
86
+ except IOError as e:
87
+ logger.error(f"重置 prompts 失败: {e}")
88
+ return False
src/voice_dialogue/services/text/generator.py CHANGED
@@ -9,6 +9,7 @@ from voice_dialogue.config import paths
9
  from voice_dialogue.config.llm_config import (
10
  get_llm_model_params, get_apple_silicon_summary, CHINESE_SYSTEM_PROMPT, ENGLISH_SYSTEM_PROMPT
11
  )
 
12
  from voice_dialogue.core.base import BaseThread
13
  from voice_dialogue.core.constants import chat_history_cache
14
  from voice_dialogue.models.voice_task import VoiceTask, QuestionDisplayMessage
@@ -38,10 +39,7 @@ class LLMResponseGenerator(BaseThread):
38
 
39
  def _get_prompt_by_language(self, language: str) -> str:
40
  """根据语言获取对应的 prompt"""
41
- if language == "zh":
42
- return CHINESE_SYSTEM_PROMPT
43
- else:
44
- return ENGLISH_SYSTEM_PROMPT
45
 
46
  def get_session_history(self, session_id: str) -> InMemoryChatMessageHistory:
47
  message_history = InMemoryChatMessageHistory()
@@ -193,7 +191,9 @@ class LLMResponseGenerator(BaseThread):
193
  self.model_instance = create_langchain_chat_llamacpp_instance(
194
  local_model_path=model_path, model_params=model_params
195
  )
196
- pipeline = create_langchain_pipeline(self.model_instance, CHINESE_SYSTEM_PROMPT, self.get_session_history)
 
 
197
  warmup_langchain_pipeline(pipeline)
198
 
199
  self.is_ready = True
 
9
  from voice_dialogue.config.llm_config import (
10
  get_llm_model_params, get_apple_silicon_summary, CHINESE_SYSTEM_PROMPT, ENGLISH_SYSTEM_PROMPT
11
  )
12
+ from voice_dialogue.config.user_config import get_prompt # 修改导入
13
  from voice_dialogue.core.base import BaseThread
14
  from voice_dialogue.core.constants import chat_history_cache
15
  from voice_dialogue.models.voice_task import VoiceTask, QuestionDisplayMessage
 
39
 
40
  def _get_prompt_by_language(self, language: str) -> str:
41
  """根据语言获取对应的 prompt"""
42
+ return get_prompt(language)
 
 
 
43
 
44
  def get_session_history(self, session_id: str) -> InMemoryChatMessageHistory:
45
  message_history = InMemoryChatMessageHistory()
 
191
  self.model_instance = create_langchain_chat_llamacpp_instance(
192
  local_model_path=model_path, model_params=model_params
193
  )
194
+ # 使用默认中文 prompt 进行 warmup
195
+ prompt = get_prompt("zh")
196
+ pipeline = create_langchain_pipeline(self.model_instance, prompt, self.get_session_history)
197
  warmup_langchain_pipeline(pipeline)
198
 
199
  self.is_ready = True