liumaolin commited on
Commit
8f823b0
·
1 Parent(s): d91a26b

Introduce initial API structure for VoiceDialogue: add dependencies, middleware, and routes for ASR, TTS, system, and voice modules.

Browse files
src/VoiceDialogue/api/__init__.py ADDED
File without changes
src/VoiceDialogue/api/app.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from contextlib import asynccontextmanager
3
+ from typing import Dict, Any
4
+
5
+ from fastapi import FastAPI, HTTPException, APIRouter
6
+ from fastapi.middleware.cors import CORSMiddleware
7
+
8
+ from .middleware.logging import LoggingMiddleware
9
+ from .middleware.rate_limit import RateLimitMiddleware
10
+ from .routes import tts_routes, asr_routes
11
+
12
+ # 配置日志
13
+ logging.basicConfig(
14
+ level=logging.INFO,
15
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
16
+ )
17
+ logger = logging.getLogger(__name__)
18
+
19
+ # 全局状态存储
20
+ app_state: Dict[str, Any] = {}
21
+
22
+
23
+ @asynccontextmanager
24
+ async def lifespan(app: FastAPI):
25
+ """应用启动和关闭的生命周期管理"""
26
+ # 启动时的初始化
27
+ logger.info("正在启动VoiceDialogue API服务...")
28
+
29
+ # 初始化TTS配置注册表
30
+ try:
31
+ from services.audio.audio_generator import tts_config_registry
32
+ logger.info(f"已加载 {len(tts_config_registry.get_all_configs())} 个TTS配置")
33
+ app_state["tts_configs_loaded"] = True
34
+ except Exception as e:
35
+ logger.error(f"TTS配置加载失败: {e}")
36
+ app_state["tts_configs_loaded"] = False
37
+
38
+ app_state["system_running"] = True
39
+ logger.info("VoiceDialogue API服务启动完成")
40
+ yield
41
+
42
+ # 关闭时的清理
43
+ logger.info("正在关闭VoiceDialogue API服务...")
44
+ app_state["system_running"] = False
45
+ logger.info("VoiceDialogue API服务已关闭")
46
+
47
+
48
+ # 创建FastAPI应用
49
+ app = FastAPI(
50
+ title="VoiceDialogue API",
51
+ description="""
52
+ 语音对话系统的HTTP API接口
53
+
54
+ ## 功能特性
55
+
56
+ * **TTS模型管理**: 查看、加载、删除TTS模型
57
+ * **模型状态监控**: 实时监控模型下载和加载状态
58
+ * **RESTful API**: 标准的REST接口设计
59
+ * **自动文档**: 自动生成的API文档和测试界面
60
+
61
+ ## 使用方法
62
+
63
+ 1. 查看所有可用的TTS模型: `GET /api/v1/tts/models`
64
+ 2. 加载指定模型: `POST /api/v1/tts/models/load`
65
+ 3. 查看模型状态: `GET /api/v1/tts/models/{model_id}/status`
66
+ 4. 删除模型: `DELETE /api/v1/tts/models/{model_id}`
67
+ """,
68
+ version="1.0.0",
69
+ docs_url="/docs",
70
+ redoc_url="/redoc",
71
+ lifespan=lifespan,
72
+ )
73
+
74
+ # 添加CORS中间件
75
+ app.add_middleware(
76
+ CORSMiddleware,
77
+ allow_origins=["*"], # 生产环境中应该设置具体的域名
78
+ allow_credentials=True,
79
+ allow_methods=["*"],
80
+ allow_headers=["*"],
81
+ )
82
+
83
+ # 添加自定义中间件
84
+ app.add_middleware(LoggingMiddleware)
85
+ app.add_middleware(RateLimitMiddleware)
86
+
87
+ # 注册路由
88
+ v1_router = APIRouter(prefix="/api/v1")
89
+ # v1_router.include_router(voice_routes.router, prefix="/voice", tags=["语音处理"])
90
+ # v1_router.include_router(system_routes.router, prefix="/system", tags=["系统控制"])
91
+ v1_router.include_router(tts_routes.router, prefix="/tts", tags=["TTS模型管理"])
92
+ v1_router.include_router(asr_routes.router, prefix="/asr", tags=["ASR模型管理"])
93
+
94
+ app.include_router(v1_router)
95
+
96
+
97
+ @app.get("/")
98
+ async def root():
99
+ """根路径健康检查"""
100
+ return {
101
+ "message": "欢迎使用VoiceDialogue API",
102
+ "status": "running",
103
+ "version": "1.0.0",
104
+ "docs_url": "/docs",
105
+ "redoc_url": "/redoc"
106
+ }
107
+
108
+
109
+ @app.get("/health")
110
+ async def health_check():
111
+ """健康检查端点"""
112
+ return {
113
+ "status": "healthy",
114
+ "tts_configs_loaded": app_state.get("tts_configs_loaded", False),
115
+ "system_running": app_state.get("system_running", False),
116
+ "available_models": len(app_state.get("available_models", []))
117
+ }
118
+
119
+
120
+ # 全局异常处理器
121
+ @app.exception_handler(Exception)
122
+ async def global_exception_handler(request, exc):
123
+ logger.error(f"未处理的异常: {exc}", exc_info=True)
124
+ return HTTPException(
125
+ status_code=500,
126
+ detail="内部服务器错误"
127
+ )
src/VoiceDialogue/api/dependencies/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ from .audio_deps import decode_audio_data, encode_audio_data, validate_audio_format
2
+ from .model_deps import get_language_model, get_voice_model, ensure_model_loaded
3
+
4
+ __all__ = [
5
+ "decode_audio_data", "encode_audio_data", "validate_audio_format",
6
+ "get_language_model", "get_voice_model", "ensure_model_loaded"
7
+ ]
src/VoiceDialogue/api/dependencies/audio_deps.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+
3
+ import numpy as np
4
+ from fastapi import HTTPException, Depends
5
+
6
+
7
+ def decode_audio_data(audio_data: str) -> np.ndarray:
8
+ """解码Base64音频数据"""
9
+ try:
10
+ # 解码Base64数据
11
+ decoded_data = base64.b64decode(audio_data)
12
+
13
+ # 转换为numpy数组 (假设是16-bit PCM格式)
14
+ audio_array = np.frombuffer(decoded_data, dtype=np.int16)
15
+
16
+ # 转换为float32格式,范围[-1, 1]
17
+ audio_array = audio_array.astype(np.float32) / 32768.0
18
+
19
+ return audio_array
20
+ except Exception as e:
21
+ raise HTTPException(
22
+ status_code=400,
23
+ detail=f"音频数据解码失败: {str(e)}"
24
+ )
25
+
26
+
27
+ def encode_audio_data(audio_array: np.ndarray, sample_rate: int = 16000) -> str:
28
+ """编码音频数据为Base64"""
29
+ try:
30
+ # 转换为16-bit PCM格式
31
+ audio_int16 = (audio_array * 32767).astype(np.int16)
32
+
33
+ # 转换为字节
34
+ audio_bytes = audio_int16.tobytes()
35
+
36
+ # Base64编码
37
+ encoded_data = base64.b64encode(audio_bytes).decode('utf-8')
38
+
39
+ return encoded_data
40
+ except Exception as e:
41
+ raise HTTPException(
42
+ status_code=500,
43
+ detail=f"音频数据编码失败: {str(e)}"
44
+ )
45
+
46
+
47
+ def validate_audio_format(audio_array: np.ndarray) -> bool:
48
+ """验证音频格式"""
49
+ if len(audio_array) == 0:
50
+ raise HTTPException(
51
+ status_code=400,
52
+ detail="音频数据为空"
53
+ )
54
+
55
+ if len(audio_array) > 16000 * 30: # 30秒限制
56
+ raise HTTPException(
57
+ status_code=400,
58
+ detail="音频时长超过30秒限制"
59
+ )
60
+
61
+ return True
src/VoiceDialogue/api/dependencies/model_deps.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from typing import Optional, Dict, Any
3
+
4
+ from fastapi import HTTPException, Depends
5
+
6
+ logger = logging.getLogger(__name__)
7
+
8
+ # 模拟的全局模型状态
9
+ _loaded_models: Dict[str, Any] = {}
10
+
11
+
12
+ def get_language_model(model_name: Optional[str] = None):
13
+ """获取语言模型依赖"""
14
+ try:
15
+ # 这里应该从实际的模型注册表中获取
16
+ from ...models.language_model import language_model_registry
17
+
18
+ if model_name:
19
+ # 根据名称查找特定模型
20
+ for model in language_model_registry:
21
+ if model.name == model_name:
22
+ return model
23
+ raise HTTPException(
24
+ status_code=404,
25
+ detail=f"未找到名为 {model_name} 的语言模型"
26
+ )
27
+ else:
28
+ # 返回默认模型 (14B)
29
+ return language_model_registry[-2]
30
+ except ImportError:
31
+ raise HTTPException(
32
+ status_code=500,
33
+ detail="语言模型模块导入失败"
34
+ )
35
+
36
+
37
+ def get_voice_model(speaker_name: str = "沈逸"):
38
+ """获取语音模型依赖"""
39
+ try:
40
+ from services.audio.audio_generator.voice_model import voice_model_registry
41
+
42
+ speaker_mapping = {
43
+ '罗翔': 0,
44
+ '马保国': 1,
45
+ '沈逸': 2,
46
+ '杨幂': 3,
47
+ '周杰伦': 4,
48
+ '马云': 5,
49
+ }
50
+
51
+ index = speaker_mapping.get(speaker_name, 2) # 默认沈逸
52
+
53
+ if index < len(voice_model_registry):
54
+ return voice_model_registry[index]
55
+ else:
56
+ raise HTTPException(
57
+ status_code=404,
58
+ detail=f"未找到语音角色: {speaker_name}"
59
+ )
60
+ except ImportError:
61
+ raise HTTPException(
62
+ status_code=500,
63
+ detail="语音模型模块导入失败"
64
+ )
65
+
66
+
67
+ def ensure_model_loaded(model):
68
+ """确保模型已加载"""
69
+ try:
70
+ if not hasattr(model, 'is_loaded') or not model.is_loaded:
71
+ model.download_model()
72
+ _loaded_models[model.name] = model
73
+ return model
74
+ except Exception as e:
75
+ logger.error(f"模型加载失败: {e}")
76
+ raise HTTPException(
77
+ status_code=500,
78
+ detail=f"模型加载失败: {str(e)}"
79
+ )
src/VoiceDialogue/api/middleware/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from .logging import LoggingMiddleware
2
+ from .rate_limit import RateLimitMiddleware
3
+
4
+ __all__ = ["LoggingMiddleware", "RateLimitMiddleware"]
src/VoiceDialogue/api/middleware/logging.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import time
3
+
4
+ from fastapi import Request, Response
5
+ from starlette.middleware.base import BaseHTTPMiddleware
6
+
7
+ logger = logging.getLogger(__name__)
8
+
9
+
10
+ class LoggingMiddleware(BaseHTTPMiddleware):
11
+ """请求日志中间件"""
12
+
13
+ async def dispatch(self, request: Request, call_next):
14
+ start_time = time.time()
15
+
16
+ # 记录请求信息
17
+ logger.info(
18
+ f"请求开始: {request.method} {request.url.path} "
19
+ f"客户端: {request.client.host if request.client else 'unknown'}"
20
+ )
21
+
22
+ # 处理请求
23
+ response = await call_next(request)
24
+
25
+ # 计算处理时间
26
+ process_time = time.time() - start_time
27
+
28
+ # 记录响应信息
29
+ logger.info(
30
+ f"请求完成: {request.method} {request.url.path} "
31
+ f"状态码: {response.status_code} "
32
+ f"处理时间: {process_time:.3f}s"
33
+ )
34
+
35
+ # 添加处理时间到响应头
36
+ response.headers["X-Process-Time"] = str(process_time)
37
+
38
+ return response
src/VoiceDialogue/api/middleware/rate_limit.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ from collections import defaultdict
3
+
4
+ from fastapi import Request, HTTPException
5
+ from starlette.middleware.base import BaseHTTPMiddleware
6
+
7
+
8
+ class RateLimitMiddleware(BaseHTTPMiddleware):
9
+ """API限流中间件"""
10
+
11
+ def __init__(self, app, calls_per_minute: int = 60):
12
+ super().__init__(app)
13
+ self.calls_per_minute = calls_per_minute
14
+ self.calls = defaultdict(list)
15
+
16
+ async def dispatch(self, request: Request, call_next):
17
+ client_ip = request.client.host if request.client else "unknown"
18
+ current_time = time.time()
19
+
20
+ # 清理过期的调用记录
21
+ minute_ago = current_time - 60
22
+ self.calls[client_ip] = [
23
+ call_time for call_time in self.calls[client_ip]
24
+ if call_time > minute_ago
25
+ ]
26
+
27
+ # 检查是否超过限制
28
+ if len(self.calls[client_ip]) >= self.calls_per_minute:
29
+ raise HTTPException(
30
+ status_code=429,
31
+ detail=f"请求频率过高,每分钟最多允许 {self.calls_per_minute} 次请求"
32
+ )
33
+
34
+ # 记录本次调用
35
+ self.calls[client_ip].append(current_time)
36
+
37
+ # 处理请求
38
+ response = await call_next(request)
39
+ return response
src/VoiceDialogue/api/routes/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from . import tts_routes, asr_routes
2
+
3
+ __all__ = ["tts_routes", "asr_routes"]
src/VoiceDialogue/api/routes/asr_routes.py ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+
3
+ from fastapi import APIRouter, HTTPException
4
+
5
+ from services.speech.asr import asr_manager
6
+ from ..schemas.asr_schemas import (
7
+ SupportedLanguagesResponse, ASRInstanceRequest, ASRInstanceResponse, LanguageMappingRequest,
8
+ LanguageMappingResponse, ASRValidationRequest, ASRValidationResponse, CleanupResponse
9
+ )
10
+
11
+ logger = logging.getLogger(__name__)
12
+ router = APIRouter()
13
+
14
+
15
+ @router.get("/languages", response_model=SupportedLanguagesResponse, summary="获取支持的识别语言")
16
+ async def get_supported_languages():
17
+ """
18
+ 获取系统支持的语音识别语言列表,包括语言映射和可用引擎
19
+ """
20
+ try:
21
+ available_languages = asr_manager.get_available_languages()
22
+ language_mappings = asr_manager._language_to_asr_mapping
23
+ asr_engines = list(asr_manager.list_registered_asr().keys())
24
+
25
+ return SupportedLanguagesResponse(
26
+ languages=available_languages,
27
+ language_mappings=language_mappings,
28
+ asr_engines=asr_engines
29
+ )
30
+ except Exception as e:
31
+ logger.error(f"获取支持语言列表失败: {e}", exc_info=True)
32
+ raise HTTPException(status_code=500, detail=f"获取支持语言列表失败: {str(e)}")
33
+
34
+
35
+ @router.post("/instance/create", response_model=ASRInstanceResponse, summary="创建ASR实例")
36
+ async def create_asr_instance(request: ASRInstanceRequest):
37
+ """
38
+ 根据指定语言创建新的ASR实例
39
+ """
40
+ try:
41
+ # 获取最优的ASR引擎
42
+ asr_type = asr_manager._get_asr_type_for_language(request.language)
43
+
44
+ # 创建实例
45
+ instance = asr_manager.create_asr(request.language)
46
+
47
+ return ASRInstanceResponse(
48
+ success=True,
49
+ message=f"成功创建ASR实例",
50
+ language=request.language,
51
+ asr_type=asr_type,
52
+ instance_id=f"{asr_type}_{request.language}"
53
+ )
54
+ except ValueError as e:
55
+ logger.warning(f"创建ASR实例失败 - 参数错误: {e}")
56
+ raise HTTPException(status_code=400, detail=str(e))
57
+ except Exception as e:
58
+ logger.error(f"创建ASR实例失败: {e}", exc_info=True)
59
+ raise HTTPException(status_code=500, detail=f"创建ASR实例失败: {str(e)}")
60
+
61
+
62
+ @router.post("/instance/get-or-create", response_model=ASRInstanceResponse, summary="获取或创建ASR实例")
63
+ async def get_or_create_asr_instance(request: ASRInstanceRequest):
64
+ """
65
+ 获取现有的ASR实例,如果不存在则创建新实例(单例模式)
66
+ """
67
+ try:
68
+ # 获取最优的ASR引擎
69
+ asr_type = asr_manager._get_asr_type_for_language(request.language)
70
+
71
+ # 获取或创建实例
72
+ instance = asr_manager.get_or_create_asr(request.language)
73
+
74
+ return ASRInstanceResponse(
75
+ success=True,
76
+ message=f"成功获取ASR实例",
77
+ language=request.language,
78
+ asr_type=asr_type,
79
+ instance_id=f"{asr_type}_{request.language}"
80
+ )
81
+ except ValueError as e:
82
+ logger.warning(f"获取ASR实例失败 - 参数错误: {e}")
83
+ raise HTTPException(status_code=400, detail=str(e))
84
+ except Exception as e:
85
+ logger.error(f"获取ASR实例失败: {e}", exc_info=True)
86
+ raise HTTPException(status_code=500, detail=f"获取ASR实例失败: {str(e)}")
87
+
88
+
89
+ @router.post("/mapping", response_model=LanguageMappingResponse, summary="配置语言映射")
90
+ async def set_language_mapping(request: LanguageMappingRequest):
91
+ """
92
+ 设置特定语言使用的ASR引擎
93
+ """
94
+ try:
95
+ asr_manager.set_language_mapping(request.language, request.asr_type)
96
+
97
+ return LanguageMappingResponse(
98
+ success=True,
99
+ message=f"成功设置语言映射: {request.language} -> {request.asr_type}",
100
+ updated_mapping=asr_manager._language_to_asr_mapping.copy()
101
+ )
102
+ except ValueError as e:
103
+ logger.warning(f"设置语言映射失败 - 参数错误: {e}")
104
+ raise HTTPException(status_code=400, detail=str(e))
105
+ except Exception as e:
106
+ logger.error(f"设置语言映射失败: {e}", exc_info=True)
107
+ raise HTTPException(status_code=500, detail=f"设置语言映射失败: {str(e)}")
108
+
109
+
110
+ @router.post("/validate", response_model=ASRValidationResponse, summary="验证语言支持")
111
+ async def validate_language_support(request: ASRValidationRequest):
112
+ """
113
+ 验证指定语言是否被支持,并返回相关信息
114
+ """
115
+ try:
116
+ is_supported = asr_manager.validate_language_support(request.language)
117
+ optimal_asr = asr_manager.get_optimal_asr_for_language(request.language)
118
+
119
+ # 查找支持该语言的所有ASR引擎
120
+ available_asrs = []
121
+ supported_langs = asr_manager.get_supported_languages()
122
+ for asr_key, languages in supported_langs.items():
123
+ if request.language in languages:
124
+ available_asrs.append(asr_key)
125
+
126
+ return ASRValidationResponse(
127
+ language=request.language,
128
+ is_supported=is_supported,
129
+ optimal_asr=optimal_asr,
130
+ available_asrs=available_asrs
131
+ )
132
+ except Exception as e:
133
+ logger.error(f"验证语言支持失败: {e}", exc_info=True)
134
+ raise HTTPException(status_code=500, detail=f"验证语言支持失败: {str(e)}")
135
+
136
+
137
+ @router.get("/validate/{language}", response_model=ASRValidationResponse, summary="验证语言支持(GET)")
138
+ async def validate_language_support_get(language: str):
139
+ """
140
+ 通过GET方法验证指定语言是否被支持
141
+ """
142
+ request = ASRValidationRequest(language=language)
143
+ return await validate_language_support(request)
144
+
145
+
146
+ @router.delete("/cleanup", response_model=CleanupResponse, summary="清理ASR实例")
147
+ async def cleanup_asr_instances():
148
+ """
149
+ 清理所有活动的ASR实例,释放资源
150
+ """
151
+ try:
152
+ # 记录清理前的实例数量
153
+ stats = asr_manager.get_asr_statistics()
154
+ instances_count = stats['active_instances_count']
155
+
156
+ # 执行清理
157
+ asr_manager.cleanup()
158
+
159
+ return CleanupResponse(
160
+ success=True,
161
+ message="成功清理所有ASR实例",
162
+ cleared_instances_count=instances_count
163
+ )
164
+ except Exception as e:
165
+ logger.error(f"清理ASR实例失败: {e}", exc_info=True)
166
+ raise HTTPException(status_code=500, detail=f"清理ASR实例失败: {str(e)}")
167
+
168
+
169
+ @router.get("/health", summary="ASR服务健康检查")
170
+ async def asr_health_check():
171
+ """
172
+ ASR服务的健康检查接口
173
+ """
174
+ try:
175
+ stats = asr_manager.get_asr_statistics()
176
+
177
+ # 检查是否有已注册的ASR引擎
178
+ is_healthy = stats['registered_asr_count'] > 0
179
+
180
+ return {
181
+ "healthy": is_healthy,
182
+ "message": "ASR服务正常" if is_healthy else "没有可用的ASR引擎",
183
+ "registered_engines": stats['registered_asr_count'],
184
+ "active_instances": stats['active_instances_count'],
185
+ "supported_languages": len(stats['supported_languages'])
186
+ }
187
+ except Exception as e:
188
+ logger.error(f"ASR健康检查失败: {e}", exc_info=True)
189
+ return {
190
+ "healthy": False,
191
+ "message": f"ASR服务异常: {str(e)}",
192
+ "error": str(e)
193
+ }
src/VoiceDialogue/api/routes/system_routes.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import logging
3
+ import time
4
+
5
+ from fastapi import APIRouter, HTTPException, BackgroundTasks
6
+
7
+ from ..schemas.system_schemas import (
8
+ SystemStatusResponse, SystemConfig,
9
+ SystemStartRequest, SystemResponse
10
+ )
11
+
12
+ logger = logging.getLogger(__name__)
13
+ router = APIRouter()
14
+
15
+ # 全局系统状态
16
+ _system_status = {
17
+ "status": "stopped",
18
+ "start_time": None,
19
+ "config": None,
20
+ "active_sessions": 0
21
+ }
22
+
23
+
24
+ # @router.get("/status", response_model=SystemStatusResponse, summary="获取系统状态")
25
+ # async def get_system_status():
26
+ # """
27
+ # 获取系统整体状态,不包含语言模型信息
28
+ # """
29
+ # try:
30
+ # # 获取TTS模型统计
31
+ # all_configs = tts_config_registry.get_all_configs()
32
+ # downloaded_count = sum(1 for config in all_configs if config.is_model_complete())
33
+ #
34
+ # # 获取TTS引擎状态
35
+ # available_engines = list(tts_manager.list_registered_tts().keys())
36
+ #
37
+ # status = SystemStatusResponse(
38
+ # system_status="running",
39
+ # tts_models_total=len(all_configs),
40
+ # tts_models_downloaded=downloaded_count,
41
+ # available_tts_engines=available_engines,
42
+ # memory_usage=_get_memory_usage(),
43
+ # disk_usage=_get_disk_usage()
44
+ # )
45
+ #
46
+ # return status
47
+ #
48
+ # except Exception as e:
49
+ # logger.error(f"获取系统状态失败: {e}", exc_info=True)
50
+ # raise HTTPException(status_code=500, detail=f"获取系统状态失败: {str(e)}")
51
+
52
+
53
+ @router.post("/start", response_model=SystemResponse, summary="启动系统")
54
+ async def start_system(
55
+ request: SystemStartRequest,
56
+ background_tasks: BackgroundTasks
57
+ ):
58
+ """
59
+ 启动语音对话系统
60
+ """
61
+ try:
62
+ if _system_status["status"] in ["running", "starting"]:
63
+ return SystemResponse(
64
+ success=False,
65
+ message="系统已经在运行中或正在启动"
66
+ )
67
+
68
+ # 更新状态
69
+ _system_status["status"] = "starting"
70
+ _system_status["config"] = request.config
71
+
72
+ # 在后台启动系统
73
+ background_tasks.add_task(
74
+ _start_system_background,
75
+ request.config
76
+ )
77
+
78
+ return SystemResponse(
79
+ success=True,
80
+ message="系统启动请求已提交,正在后台启动..."
81
+ )
82
+
83
+ except Exception as e:
84
+ logger.error(f"系统启动失败: {e}", exc_info=True)
85
+ _system_status["status"] = "stopped"
86
+ raise HTTPException(status_code=500, detail=f"系统启动失败: {str(e)}")
87
+
88
+
89
+ @router.post("/stop", response_model=SystemResponse, summary="停止系统")
90
+ async def stop_system():
91
+ """
92
+ 停止语音对话系统
93
+ """
94
+ try:
95
+ if _system_status["status"] == "stopped":
96
+ return SystemResponse(
97
+ success=False,
98
+ message="系统已经停止"
99
+ )
100
+
101
+ # 更新状态
102
+ _system_status["status"] = "stopping"
103
+
104
+ # 模拟停止过程
105
+ await asyncio.sleep(1)
106
+
107
+ _system_status["status"] = "stopped"
108
+ _system_status["start_time"] = None
109
+ _system_status["config"] = None
110
+ _system_status["active_sessions"] = 0
111
+
112
+ return SystemResponse(
113
+ success=True,
114
+ message="系统已成功停止"
115
+ )
116
+
117
+ except Exception as e:
118
+ logger.error(f"系统停止失败: {e}", exc_info=True)
119
+ raise HTTPException(status_code=500, detail=f"系统停止失败: {str(e)}")
120
+
121
+
122
+ @router.post("/restart", response_model=SystemResponse, summary="重启系统")
123
+ async def restart_system(
124
+ request: SystemStartRequest,
125
+ background_tasks: BackgroundTasks
126
+ ):
127
+ """
128
+ 重启语音对话系统
129
+ """
130
+ try:
131
+ # 先停止
132
+ if _system_status["status"] != "stopped":
133
+ await stop_system()
134
+
135
+ # 再启动
136
+ return await start_system(request, background_tasks)
137
+
138
+ except Exception as e:
139
+ logger.error(f"系统重启失败: {e}", exc_info=True)
140
+ raise HTTPException(status_code=500, detail=f"系统重启失败: {str(e)}")
141
+
142
+
143
+ async def _start_system_background(config: SystemConfig):
144
+ """
145
+ 后台启动系统的实际逻辑
146
+ """
147
+ try:
148
+ logger.info("开始启动语音对话系统...")
149
+
150
+ # 模拟启动过程
151
+ await asyncio.sleep(2)
152
+
153
+ # 这里应该调用实际的系统启动逻辑
154
+ # 类似于原来main.py中的launch_system函数
155
+
156
+ _system_status["status"] = "running"
157
+ _system_status["start_time"] = time.time()
158
+
159
+ logger.info("语音对话系统启动成功")
160
+
161
+ except Exception as e:
162
+ logger.error(f"后台启动系统失败: {e}", exc_info=True)
163
+ _system_status["status"] = "stopped"
src/VoiceDialogue/api/routes/tts_routes.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from typing import Optional
3
+
4
+ from fastapi import APIRouter, HTTPException, BackgroundTasks
5
+
6
+ from services.audio.audio_generator import tts_config_registry
7
+ from ..schemas.tts_schemas import (
8
+ TTSModelInfo, TTSModelListResponse, TTSModelLoadRequest,
9
+ TTSModelLoadResponse, TTSModelStatusResponse, TTSModelDeleteResponse,
10
+ generate_model_id
11
+ )
12
+
13
+ logger = logging.getLogger(__name__)
14
+ router = APIRouter()
15
+
16
+
17
+ @router.get("/models", response_model=TTSModelListResponse, summary="获取TTS模型列表")
18
+ async def list_tts_models():
19
+ """
20
+ 获取所有可用的TTS模型列表
21
+ 只返回BaseTTSConfig中的基础字段,每个模型分配唯一ID
22
+ """
23
+ try:
24
+ all_configs = tts_config_registry.get_all_configs()
25
+ models = []
26
+
27
+ for config in all_configs:
28
+ # 生成唯一ID,但不暴露具体的TTS类型
29
+ model_id = generate_model_id(config.tts_type.value, config.character_name)
30
+
31
+ # 检查模型状态
32
+ if config.is_model_complete():
33
+ status = "downloaded"
34
+ else:
35
+ status = "not_downloaded"
36
+
37
+ model_info = TTSModelInfo(
38
+ id=model_id,
39
+ character_name=config.character_name,
40
+ cover_image=config.cover_image,
41
+ description=config.description,
42
+ file_size=config.file_size,
43
+ is_chinese_voice=config.is_chinese_voice,
44
+ status=status
45
+ )
46
+ models.append(model_info)
47
+
48
+ return TTSModelListResponse(
49
+ models=models,
50
+ total=len(models)
51
+ )
52
+
53
+ except Exception as e:
54
+ logger.error(f"获取TTS模型列表失败: {e}", exc_info=True)
55
+ raise HTTPException(status_code=500, detail=f"获取TTS模型列表失败: {str(e)}")
56
+
57
+
58
+ @router.post("/models/load", response_model=TTSModelLoadResponse, summary="加载TTS模型")
59
+ async def load_tts_model(request: TTSModelLoadRequest, background_tasks: BackgroundTasks):
60
+ """
61
+ 通过模型ID加载TTS模型,不暴露具体的TTS类型
62
+ """
63
+ try:
64
+ # 通过ID找到对应的配置
65
+ config = _find_config_by_id(request.model_id)
66
+ if not config:
67
+ raise HTTPException(status_code=404, detail="模型ID不存在")
68
+
69
+ # 检查模型是否已存在
70
+ if config.is_model_complete():
71
+ return TTSModelLoadResponse(
72
+ success=True,
73
+ message=f"模型 {config.character_name} 已经加载完成",
74
+ model_id=request.model_id
75
+ )
76
+
77
+ # 后台下载模型
78
+ background_tasks.add_task(_download_model_task, config, request.model_id)
79
+
80
+ return TTSModelLoadResponse(
81
+ success=True,
82
+ message=f"模型 {config.character_name} 开始下载",
83
+ model_id=request.model_id
84
+ )
85
+
86
+ except HTTPException:
87
+ raise
88
+ except Exception as e:
89
+ logger.error(f"加载TTS模型失败: {e}", exc_info=True)
90
+ return TTSModelLoadResponse(
91
+ success=False,
92
+ message=f"加载模型失败: {str(e)}",
93
+ model_id=request.model_id
94
+ )
95
+
96
+
97
+ @router.get("/models/{model_id}/status", response_model=TTSModelStatusResponse, summary="获取TTS模型状态")
98
+ async def get_tts_model_status(model_id: str):
99
+ """
100
+ 获取指定TTS模型的状态
101
+ """
102
+ try:
103
+ config = _find_config_by_id(model_id)
104
+ if not config:
105
+ raise HTTPException(status_code=404, detail="模型ID不存在")
106
+
107
+ # 检查模型状态
108
+ if config.is_model_complete():
109
+ status = "downloaded"
110
+ else:
111
+ status = "not_downloaded"
112
+
113
+ return TTSModelStatusResponse(
114
+ model_id=model_id,
115
+ status=status
116
+ )
117
+
118
+ except HTTPException:
119
+ raise
120
+ except Exception as e:
121
+ logger.error(f"获取TTS模型状态失败: {e}", exc_info=True)
122
+ raise HTTPException(status_code=500, detail=f"获取模型状态失败: {str(e)}")
123
+
124
+
125
+ @router.delete("/models/{model_id}", response_model=TTSModelDeleteResponse, summary="删除TTS模型")
126
+ async def delete_tts_model(model_id: str):
127
+ """
128
+ 删除指定的TTS模型
129
+ """
130
+ try:
131
+ config = _find_config_by_id(model_id)
132
+ if not config:
133
+ raise HTTPException(status_code=404, detail="模型ID不存在")
134
+
135
+ # 删除模型文件
136
+ config.delete_model()
137
+
138
+ return TTSModelDeleteResponse(
139
+ success=True,
140
+ message=f"模型 {config.character_name} 删除成功",
141
+ model_id=model_id
142
+ )
143
+
144
+ except HTTPException:
145
+ raise
146
+ except Exception as e:
147
+ logger.error(f"删除TTS模型失败: {e}", exc_info=True)
148
+ return TTSModelDeleteResponse(
149
+ success=False,
150
+ message=f"删除模型失败: {str(e)}",
151
+ model_id=model_id
152
+ )
153
+
154
+
155
+ def _find_config_by_id(model_id: str) -> Optional:
156
+ """通过模型ID找到对应的配置"""
157
+ all_configs = tts_config_registry.get_all_configs()
158
+ for config in all_configs:
159
+ config_id = generate_model_id(config.tts_type.value, config.character_name)
160
+ if config_id == model_id:
161
+ return config
162
+ return None
163
+
164
+
165
+ async def _download_model_task(config, model_id: str):
166
+ """后台下载模型任务"""
167
+ try:
168
+ logger.info(f"开始下载模型: {config.character_name}")
169
+ config.download_model()
170
+ logger.info(f"模型下载完成: {config.character_name}")
171
+ except Exception as e:
172
+ logger.error(f"模型下载失败: {e}", exc_info=True)
src/VoiceDialogue/api/routes/voice_routes.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import logging
3
+ import time
4
+
5
+ import numpy as np
6
+ from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks
7
+
8
+ from ..dependencies.audio_deps import decode_audio_data, encode_audio_data, validate_audio_format
9
+ from ..dependencies.model_deps import get_language_model, get_voice_model
10
+ from ..schemas.voice_schemas import (
11
+ VoiceInput, TextInput, VoiceResponse,
12
+ TTSRequest, TTSResponse, ASRRequest, ASRResponse
13
+ )
14
+
15
+ logger = logging.getLogger(__name__)
16
+ router = APIRouter()
17
+
18
+
19
+ @router.post("/chat", response_model=VoiceResponse, summary="语音对话")
20
+ async def voice_chat(
21
+ voice_input: VoiceInput,
22
+ language_model=Depends(get_language_model),
23
+ voice_model=Depends(get_voice_model)
24
+ ):
25
+ """
26
+ 完整的语音对话处理:语音识别 -> 文本生成 -> 语音合成
27
+ """
28
+ start_time = time.time()
29
+
30
+ try:
31
+ # 1. 解码音频数据
32
+ audio_array = decode_audio_data(voice_input.audio_data)
33
+ validate_audio_format(audio_array)
34
+
35
+ # 2. 语音识别 (ASR)
36
+ # 这里应该集成实际的ASR服务
37
+ transcribed_text = await perform_asr(audio_array, voice_input.language)
38
+
39
+ # 3. 文本生成 (LLM)
40
+ generated_text = await generate_response(transcribed_text, language_model)
41
+
42
+ # 4. 语音合成 (TTS)
43
+ audio_response = await synthesize_speech(generated_text, voice_model)
44
+
45
+ processing_time = time.time() - start_time
46
+
47
+ return VoiceResponse(
48
+ transcribed_text=transcribed_text,
49
+ generated_text=generated_text,
50
+ audio_data=audio_response,
51
+ processing_time=processing_time
52
+ )
53
+
54
+ except Exception as e:
55
+ logger.error(f"语音对话处理失败: {e}", exc_info=True)
56
+ raise HTTPException(status_code=500, detail=f"语音对话处理失败: {str(e)}")
57
+
58
+
59
+ @router.post("/text-chat", response_model=VoiceResponse, summary="文本对话")
60
+ async def text_chat(
61
+ text_input: TextInput,
62
+ language_model=Depends(get_language_model),
63
+ voice_model=Depends(get_voice_model)
64
+ ):
65
+ """
66
+ 文本对话处理:文本生成 -> 语音合成
67
+ """
68
+ start_time = time.time()
69
+
70
+ try:
71
+ # 1. 文本生成 (LLM)
72
+ generated_text = await generate_response(text_input.text, language_model)
73
+
74
+ # 2. 语音合成 (TTS)
75
+ audio_response = await synthesize_speech(generated_text, voice_model)
76
+
77
+ processing_time = time.time() - start_time
78
+
79
+ return VoiceResponse(
80
+ transcribed_text=text_input.text,
81
+ generated_text=generated_text,
82
+ audio_data=audio_response,
83
+ processing_time=processing_time
84
+ )
85
+
86
+ except Exception as e:
87
+ logger.error(f"文本对话处理失败: {e}", exc_info=True)
88
+ raise HTTPException(status_code=500, detail=f"文本对话处理失败: {str(e)}")
89
+
90
+
91
+ @router.post("/asr", response_model=ASRResponse, summary="语音识别")
92
+ async def speech_to_text(asr_request: ASRRequest):
93
+ """
94
+ 语音识别服务
95
+ """
96
+ try:
97
+ # 解码音频数据
98
+ audio_array = decode_audio_data(asr_request.audio_data)
99
+ validate_audio_format(audio_array)
100
+
101
+ # 执行语音识别
102
+ transcribed_text = await perform_asr(audio_array, asr_request.language)
103
+
104
+ return ASRResponse(
105
+ transcribed_text=transcribed_text,
106
+ confidence=0.95 # 这里应该返回实际的置信度
107
+ )
108
+
109
+ except Exception as e:
110
+ logger.error(f"语音识别失败: {e}", exc_info=True)
111
+ raise HTTPException(status_code=500, detail=f"语音识别失败: {str(e)}")
112
+
113
+
114
+ @router.post("/tts", response_model=TTSResponse, summary="文本转语音")
115
+ async def text_to_speech(
116
+ tts_request: TTSRequest,
117
+ voice_model=Depends(get_voice_model)
118
+ ):
119
+ """
120
+ 文本转语音服务
121
+ """
122
+ try:
123
+ # 执行语音合成
124
+ audio_data = await synthesize_speech(tts_request.text, voice_model)
125
+
126
+ # 计算音频时长 (这里是估算)
127
+ duration = len(tts_request.text) * 0.1 # 大概每个字符0.1秒
128
+
129
+ return TTSResponse(
130
+ audio_data=audio_data,
131
+ duration=duration
132
+ )
133
+
134
+ except Exception as e:
135
+ logger.error(f"语音合成失败: {e}", exc_info=True)
136
+ raise HTTPException(status_code=500, detail=f"语音合成失败: {str(e)}")
137
+
138
+
139
+ # 辅助函数
140
+ async def perform_asr(audio_array: np.ndarray, language: str) -> str:
141
+ """执行语音识别"""
142
+ # 这里应该集成实际的ASR服务
143
+ # 模拟处理
144
+ await asyncio.sleep(0.1)
145
+ return "这是识别出的文本内容"
146
+
147
+
148
+ async def generate_response(text: str, language_model) -> str:
149
+ """生成文本响应"""
150
+ # 这里应该集成实际的LLM服务
151
+ # 模拟处理
152
+ await asyncio.sleep(0.5)
153
+ return f"针对「{text}」的AI回答:这是一个很好的问题,让我来为您详细解答..."
154
+
155
+
156
+ async def synthesize_speech(text: str, voice_model) -> str:
157
+ """合成语音"""
158
+ # 这里应该集成实际的TTS服务
159
+ # 模拟返回Base64编码的音频数据
160
+ await asyncio.sleep(0.3)
161
+ # 创建一个简单的音频数组作为示例
162
+ dummy_audio = np.random.randn(16000).astype(np.float32) * 0.1 # 1秒的随机音频
163
+ return encode_audio_data(dummy_audio)
src/VoiceDialogue/api/schemas/__init__.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .system_schemas import (
2
+ SystemStatusResponse, SystemConfig,
3
+ SystemStartRequest, SystemResponse
4
+ )
5
+ from .voice_schemas import (
6
+ VoiceInput, TextInput, VoiceResponse,
7
+ TTSRequest, TTSResponse, ASRRequest, ASRResponse
8
+ )
9
+
10
+ __all__ = [
11
+ "VoiceInput", "TextInput", "VoiceResponse",
12
+ "TTSRequest", "TTSResponse", "ASRRequest", "ASRResponse",
13
+ "ModelInfo", "ModelListResponse",
14
+ "ModelLoadRequest", "ModelLoadResponse",
15
+ "SystemStatusResponse", "SystemConfig",
16
+ "SystemStartRequest", "SystemResponse"
17
+ ]
src/VoiceDialogue/api/schemas/asr_schemas.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Literal, List, Dict, Optional
2
+ from pydantic import BaseModel, Field
3
+
4
+
5
+ class SupportedLanguagesResponse(BaseModel):
6
+ """支持的语言响应模式"""
7
+ languages: List[str] = Field(..., description="支持的语言列表")
8
+ language_mappings: Dict[str, str] = Field(..., description="语言到ASR引擎的映射")
9
+ asr_engines: List[str] = Field(..., description="可用的ASR引擎列表")
10
+
11
+
12
+ class ASRRegistryResponse(BaseModel):
13
+ """ASR注册表响应模式"""
14
+ registered_asr_types: List[str] = Field(..., description="已注册的ASR类型")
15
+ supported_languages_by_engine: Dict[str, List[str]] = Field(..., description="各引擎支持的语言")
16
+ total_registered_count: int = Field(..., description="注册的ASR引擎总数")
17
+
18
+
19
+ class ASRStatisticsResponse(BaseModel):
20
+ """ASR统计信息响应模式"""
21
+ registered_asr_count: int = Field(..., description="已注册的ASR引擎数量")
22
+ active_instances_count: int = Field(..., description="活动实例数量")
23
+ supported_languages: List[str] = Field(..., description="支持的语言列表")
24
+ language_mappings: Dict[str, str] = Field(..., description="语言映射配置")
25
+ registered_asr_types: List[str] = Field(..., description="已注册的ASR类型")
26
+
27
+
28
+ class ASRInstanceRequest(BaseModel):
29
+ """ASR实例请求模式"""
30
+ language: Literal["zh", "en", "auto"] = Field(..., description="目标语言")
31
+
32
+
33
+ class ASRInstanceResponse(BaseModel):
34
+ """ASR实例响应模式"""
35
+ success: bool = Field(..., description="操作是否成功")
36
+ message: str = Field(..., description="操作结果消息")
37
+ language: str = Field(..., description="语言类型")
38
+ asr_type: str = Field(..., description="使用的ASR引擎类型")
39
+ instance_id: Optional[str] = Field(None, description="实例标识符")
40
+
41
+
42
+ class LanguageMappingRequest(BaseModel):
43
+ """语言映射配置请求模式"""
44
+ language: str = Field(..., description="语言代码")
45
+ asr_type: str = Field(..., description="ASR引擎类型")
46
+
47
+
48
+ class LanguageMappingResponse(BaseModel):
49
+ """语言映射配置响应模式"""
50
+ success: bool = Field(..., description="操作是否成功")
51
+ message: str = Field(..., description="操作结果消息")
52
+ updated_mapping: Dict[str, str] = Field(..., description="更新后的映射关系")
53
+
54
+
55
+ class ASRValidationRequest(BaseModel):
56
+ """ASR语言验证请求模式"""
57
+ language: str = Field(..., description="要验证的语言代码")
58
+
59
+
60
+ class ASRValidationResponse(BaseModel):
61
+ """ASR语言验证响应模式"""
62
+ language: str = Field(..., description="语言代码")
63
+ is_supported: bool = Field(..., description="是否支持")
64
+ optimal_asr: Optional[str] = Field(None, description="最优ASR引擎")
65
+ available_asrs: List[str] = Field(default_factory=list, description="支持该语言的ASR引擎列表")
66
+
67
+
68
+ class CleanupResponse(BaseModel):
69
+ """清理操作响应模式"""
70
+ success: bool = Field(..., description="清理是否成功")
71
+ message: str = Field(..., description="清理结果消息")
72
+ cleared_instances_count: int = Field(..., description="清理的实例数量")
src/VoiceDialogue/api/schemas/system_schemas.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Literal
2
+
3
+ from pydantic import BaseModel, Field
4
+
5
+
6
+ class SystemStatusResponse(BaseModel):
7
+ """系统状态"""
8
+ status: Literal['running', 'stopped', 'starting', 'stopping'] = Field(..., description="系统状态")
9
+ uptime: Optional[float] = Field(None, description="运行时间(秒)")
10
+ active_sessions: int = Field(default=0, description="活跃会话数")
11
+ memory_usage: Optional[float] = Field(None, description="内存使用率")
12
+
13
+
14
+ class SystemConfig(BaseModel):
15
+ """系统配置"""
16
+ user_language: Literal['zh', 'en'] = Field(default='zh', description="用户语言")
17
+ system_prompt: str = Field(..., description="系统提示词")
18
+ tts_speaker: str = Field(default='沈逸', description="TTS语音角色")
19
+ llm_model: Literal['7B', '14B'] = Field(default='14B', description="语言模型规模")
20
+
21
+
22
+ class SystemStartRequest(BaseModel):
23
+ """系统启动请求"""
24
+ config: SystemConfig = Field(..., description="系统配置")
25
+
26
+
27
+ class SystemResponse(BaseModel):
28
+ """系统响应"""
29
+ success: bool = Field(..., description="操作是否成功")
30
+ message: str = Field(..., description="响应消息")
31
+ status: Optional[SystemStatusResponse] = Field(None, description="系统状态")
src/VoiceDialogue/api/schemas/tts_schemas.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional, Literal
2
+ from pydantic import BaseModel, Field
3
+ import hashlib
4
+
5
+
6
+ class TTSModelInfo(BaseModel):
7
+ """TTS模型基础信息"""
8
+ id: str = Field(..., description="模型唯一标识符")
9
+ character_name: str = Field(..., description="角色名称")
10
+ cover_image: str = Field(..., description="封面图片URL")
11
+ description: str = Field(..., description="模型描述")
12
+ file_size: str = Field(..., description="文件大小")
13
+ is_chinese_voice: bool = Field(..., description="是否为中文语音")
14
+ status: Literal['not_downloaded', 'downloading', 'downloaded', 'failed'] = Field(..., description="模型状态")
15
+
16
+
17
+ class TTSModelListResponse(BaseModel):
18
+ """TTS模型列表响应"""
19
+ models: List[TTSModelInfo] = Field(..., description="TTS模型列表")
20
+ total: int = Field(..., description="模型总数")
21
+
22
+
23
+ class TTSModelLoadRequest(BaseModel):
24
+ """TTS模型加载请求"""
25
+ model_id: str = Field(..., description="要加载的模型ID")
26
+
27
+
28
+ class TTSModelLoadResponse(BaseModel):
29
+ """TTS模型加载响应"""
30
+ success: bool = Field(..., description="是否加载成功")
31
+ message: str = Field(..., description="响应消息")
32
+ model_id: str = Field(..., description="模型ID")
33
+
34
+
35
+ class TTSModelStatusResponse(BaseModel):
36
+ """TTS模型状态响应"""
37
+ model_id: str = Field(..., description="模型ID")
38
+ status: Literal['not_downloaded', 'downloading', 'downloaded', 'failed'] = Field(..., description="模型状态")
39
+ progress: Optional[float] = Field(None, description="下载进度(0-100)")
40
+
41
+
42
+ class TTSModelDeleteResponse(BaseModel):
43
+ """TTS模型删除响应"""
44
+ success: bool = Field(..., description="是否删除成功")
45
+ message: str = Field(..., description="响应消息")
46
+ model_id: str = Field(..., description="模型ID")
47
+
48
+
49
+ def generate_model_id(tts_type: str, character_name: str) -> str:
50
+ """生成模型唯一ID"""
51
+ combined = f"{tts_type}:{character_name}"
52
+ return hashlib.md5(combined.encode()).hexdigest()[:16]
src/VoiceDialogue/api/schemas/voice_schemas.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datetime import datetime
2
+ from typing import Optional, Literal
3
+
4
+ from pydantic import BaseModel, Field
5
+
6
+
7
+ class VoiceInput(BaseModel):
8
+ """语音输入请求模式"""
9
+ audio_data: str = Field(..., description="Base64编码的音频数据")
10
+ language: Literal['zh', 'en'] = Field(default='zh', description="语音语言")
11
+
12
+
13
+ class TextInput(BaseModel):
14
+ """文本输入请求模式"""
15
+ text: str = Field(..., description="输入文本", min_length=1, max_length=1000)
16
+ language: Literal['zh', 'en'] = Field(default='zh', description="文本语言")
17
+
18
+
19
+ class VoiceResponse(BaseModel):
20
+ """语音响应模式"""
21
+ transcribed_text: Optional[str] = Field(None, description="转录的文本")
22
+ generated_text: str = Field(..., description="生成的回答文本")
23
+ audio_data: str = Field(..., description="Base64编码的音频响应")
24
+ processing_time: float = Field(..., description="处理时间(秒)")
25
+ timestamp: datetime = Field(default_factory=datetime.now, description="响应时间戳")
26
+
27
+
28
+ class TTSRequest(BaseModel):
29
+ """文本转语音请求模式"""
30
+ text: str = Field(..., description="要转换的文本", min_length=1, max_length=1000)
31
+ speaker: str = Field(default='沈逸', description="语音角色")
32
+
33
+
34
+ class TTSResponse(BaseModel):
35
+ """文本转语音响应模式"""
36
+ audio_data: str = Field(..., description="Base64编码的音频数据")
37
+ duration: float = Field(..., description="音频时长(秒)")
38
+
39
+
40
+ class ASRRequest(BaseModel):
41
+ """语音识别请求模式"""
42
+ audio_data: str = Field(..., description="Base64编码的音频数据")
43
+ language: Literal['zh', 'en'] = Field(default='zh', description="语音语言")
44
+
45
+
46
+ class ASRResponse(BaseModel):
47
+ """语音识别响应模式"""
48
+ transcribed_text: str = Field(..., description="识别出的文本")
49
+ confidence: float = Field(..., description="识别置信度")
src/VoiceDialogue/api/server.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 独立的API服务器启动脚本
3
+ 可以直接运行此脚本启动API服务器,无需通过main.py
4
+ """
5
+
6
+ import sys
7
+ from pathlib import Path
8
+
9
+ import uvicorn
10
+
11
+ # 添加项目根目录到Python路径
12
+ project_root = Path(__file__).parent.parent
13
+ sys.path.insert(0, str(project_root))
14
+
15
+ # 加载第三方库
16
+ from config.paths import load_third_party
17
+
18
+ load_third_party()
19
+
20
+
21
+ def run_server(host: str = "0.0.0.0", port: int = 8000, reload: bool = False):
22
+ """运行API服务器"""
23
+ print(f"""
24
+ {"=" * 80}
25
+ VoiceDialogue API Server
26
+ {"=" * 80}
27
+ 服务器地址: http://{host}:{port}
28
+ API文档: http://{host}:{port}/docs
29
+ ReDoc文档: http://{host}:{port}/redoc
30
+ 热重载: {'启用' if reload else '禁用'}
31
+ {"=" * 80}
32
+ """)
33
+
34
+ uvicorn.run(
35
+ "api.app:app",
36
+ host=host,
37
+ port=port,
38
+ reload=reload,
39
+ log_level="info",
40
+ access_log=True
41
+ )
42
+
43
+
44
+ if __name__ == "__main__":
45
+ import argparse
46
+
47
+ parser = argparse.ArgumentParser(description="VoiceDialogue API服务器")
48
+ parser.add_argument("--host", default="0.0.0.0", help="服务器主机地址")
49
+ parser.add_argument("--port", "-p", type=int, default=8000, help="服务器端口")
50
+ parser.add_argument("--reload", action="store_true", help="启用热重载")
51
+
52
+ args = parser.parse_args()
53
+ run_server(args.host, args.port, args.reload)
src/VoiceDialogue/main.py CHANGED
@@ -1,8 +1,11 @@
 
1
  import time
2
  import typing
3
  from multiprocessing import Queue
4
  from pathlib import Path
5
 
 
 
6
  from config.paths import load_third_party
7
 
8
  load_third_party()
@@ -121,32 +124,141 @@ def launch_system(
121
  thread.join()
122
 
123
 
124
- def main():
125
  """
126
- 主程序入口函数
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
 
128
- 配置并启动语音对话系统的默认设置。当前配置:
129
- - 用户语言:中文 ('zh')
130
- - TTS说话人:沈逸
131
 
132
- 该函数可以根据需要修改默认配置,或者扩展为支持命令行参数。
 
 
 
 
 
 
 
 
133
 
134
- Returns:
135
- None
136
 
137
- Example:
138
- 直接运行脚本:
139
- $ python main.py
140
 
141
- 系统将使用默认配置启动语音对话服务
142
- """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
 
144
- user_language: typing.Literal['zh', 'en'] = 'zh'
145
 
146
- # '罗翔', '马保国', '沈逸', '杨幂', '周杰伦', '马云'
147
- tts_speaker = '沈逸'
148
 
149
- launch_system(user_language, tts_speaker)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
150
 
151
 
152
  if __name__ == '__main__':
 
1
+ import argparse
2
  import time
3
  import typing
4
  from multiprocessing import Queue
5
  from pathlib import Path
6
 
7
+ import uvicorn
8
+
9
  from config.paths import load_third_party
10
 
11
  load_third_party()
 
124
  thread.join()
125
 
126
 
127
+ def launch_api_server(host: str = "0.0.0.0", port: int = 8000, reload: bool = False):
128
  """
129
+ 启动API服务器
130
+
131
+ Args:
132
+ host (str): 服务器主机地址,默认为 "0.0.0.0"
133
+ port (int): 服务器端口,默认为 8000
134
+ reload (bool): 是否启用热重载,默认为 False
135
+ """
136
+ print(f'{"=" * 80}\n正在启动API服务器...\n{"=" * 80}')
137
+ print(f"服务器地址: http://{host}:{port}")
138
+ print(f"API文档: http://{host}:{port}/docs")
139
+ print(f"热重载: {'启用' if reload else '禁用'}")
140
+ print(f'{"=" * 80}')
141
+
142
+ # 导入并启动FastAPI应用
143
+ uvicorn.run(
144
+ "api.app:app",
145
+ host=host,
146
+ port=port,
147
+ reload=reload,
148
+ log_level="info"
149
+ )
150
 
 
 
 
151
 
152
+ def create_argument_parser():
153
+ """创建命令行参数解析器"""
154
+ parser = argparse.ArgumentParser(
155
+ description="VoiceDialogue - 语音对话系统",
156
+ formatter_class=argparse.RawDescriptionHelpFormatter,
157
+ epilog="""
158
+ 示例用法:
159
+ # 启动命令行模式(默认)
160
+ python main.py
161
 
162
+ # 启动命令行模式并指定参数
163
+ python main.py --mode cli --language zh --speaker 沈逸
164
 
165
+ # 启动API服务器
166
+ python main.py --mode api
 
167
 
168
+ # 启动API服务器并指定端口
169
+ python main.py --mode api --port 9000
170
+
171
+ # 启动API服务器并启用热重载(开发模式)
172
+ python main.py --mode api --port 8000 --reload
173
+
174
+ 支持的说话人:
175
+ 罗翔, 马保国, 沈逸, 杨幂, 周杰伦, 马云
176
+ """
177
+ )
178
+
179
+ # 运行模式选择
180
+ parser.add_argument(
181
+ '--mode', '-m',
182
+ choices=['cli', 'api'],
183
+ default='cli',
184
+ help='运行模式: cli=命令行模式, api=API服务器模式 (默认: cli)'
185
+ )
186
+
187
+ # 命令行模式参数
188
+ cli_group = parser.add_argument_group('命令行模式参数')
189
+ cli_group.add_argument(
190
+ '--language', '-l',
191
+ choices=['zh', 'en'],
192
+ default='zh',
193
+ help='用户语言: zh=中文, en=英文 (默认: zh)'
194
+ )
195
+ cli_group.add_argument(
196
+ '--speaker', '-s',
197
+ choices=['罗翔', '马保国', '沈逸', '杨幂', '周杰伦', '马云'],
198
+ default='沈逸',
199
+ help='TTS说话人 (默认: 沈逸)'
200
+ )
201
+
202
+ # API服务器模式参数
203
+ api_group = parser.add_argument_group('API服务器模式参数')
204
+ api_group.add_argument(
205
+ '--host',
206
+ default='0.0.0.0',
207
+ help='服务器主机地址 (默认: 0.0.0.0)'
208
+ )
209
+ api_group.add_argument(
210
+ '--port', '-p',
211
+ type=int,
212
+ default=8000,
213
+ help='服务器端口 (默认: 8000)'
214
+ )
215
+ api_group.add_argument(
216
+ '--reload',
217
+ action='store_true',
218
+ help='启用热重载(开发模式)'
219
+ )
220
 
221
+ return parser
222
 
 
 
223
 
224
+ def main():
225
+ """
226
+ 主程序入口函数
227
+
228
+ 根据命令行参数选择启动模式:
229
+ - cli: 启动命令行语音对话系统
230
+ - api: 启动HTTP API服务器
231
+ """
232
+ parser = create_argument_parser()
233
+ args = parser.parse_args()
234
+
235
+ print(f"""
236
+ {"=" * 80}
237
+ VoiceDialogue - 语音对话系统
238
+ {"=" * 80}
239
+ 运行模式: {args.mode.upper()}
240
+ {"=" * 80}
241
+ """)
242
+
243
+ try:
244
+ if args.mode == 'cli':
245
+ print(f"语言设置: {args.language}")
246
+ print(f"说话人: {args.speaker}")
247
+ print("正在启动命令行语音对话系统...")
248
+ launch_system(args.language, args.speaker)
249
+
250
+ elif args.mode == 'api':
251
+ launch_api_server(
252
+ host=args.host,
253
+ port=args.port,
254
+ reload=args.reload
255
+ )
256
+
257
+ except KeyboardInterrupt:
258
+ print("\n程序被用户中断")
259
+ except Exception as e:
260
+ print(f"程序运行出错: {e}")
261
+ raise
262
 
263
 
264
  if __name__ == '__main__':