liumaolin commited on
Commit
fb6d02a
·
1 Parent(s): 87a7384

Enhance TTS model handling: add dynamic status tracking, model downloading, and default system configuration initialization with API updates to manage active and default TTS models effectively.

Browse files
src/VoiceDialogue/api/core/lifespan.py CHANGED
@@ -4,10 +4,12 @@ from contextlib import asynccontextmanager
4
 
5
  from fastapi import FastAPI
6
 
 
7
  from utils import get_system_language
8
  from .config import TTSConfigInitializer
9
  from .service_factories import get_core_voice_service_definitions
10
  from .service_manager import ServiceManager
 
11
 
12
  logger = logging.getLogger(__name__)
13
 
@@ -33,6 +35,19 @@ class LifespanManager:
33
  tts_config = TTSConfigInitializer.initialize()
34
  self._update_app_state(tts_config)
35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  # 获取服务定义
37
  service_definitions = get_core_voice_service_definitions(system_language)
38
 
@@ -45,6 +60,8 @@ class LifespanManager:
45
  "system_running": True,
46
  "system_language": system_language,
47
  "current_asr_language": system_language,
 
 
48
  })
49
 
50
  # 记录启动信息
 
4
 
5
  from fastapi import FastAPI
6
 
7
+ from services.audio.audio_generator import tts_config_registry
8
  from utils import get_system_language
9
  from .config import TTSConfigInitializer
10
  from .service_factories import get_core_voice_service_definitions
11
  from .service_manager import ServiceManager
12
+ from ..schemas.tts_schemas import generate_model_id
13
 
14
  logger = logging.getLogger(__name__)
15
 
 
35
  tts_config = TTSConfigInitializer.initialize()
36
  self._update_app_state(tts_config)
37
 
38
+ default_tts_config = tts_config_registry.get_default_config_for_system()
39
+ current_tts_model_id = None
40
+ current_tts_character_name = None
41
+
42
+ if default_tts_config:
43
+ current_tts_model_id = generate_model_id(
44
+ default_tts_config.tts_type.value,
45
+ default_tts_config.character_name
46
+ )
47
+ current_tts_character_name = default_tts_config.character_name
48
+ logger.info(f"系统默认TTS模型: {current_tts_character_name} (ID: {current_tts_model_id})")
49
+
50
+
51
  # 获取服务定义
52
  service_definitions = get_core_voice_service_definitions(system_language)
53
 
 
60
  "system_running": True,
61
  "system_language": system_language,
62
  "current_asr_language": system_language,
63
+ "current_tts_model_id": current_tts_model_id,
64
+ "current_tts_character_name": current_tts_character_name,
65
  })
66
 
67
  # 记录启动信息
src/VoiceDialogue/api/core/service_factories.py CHANGED
@@ -144,6 +144,16 @@ def get_asr_worker_service_definition(language: str) -> ServiceDefinition:
144
  )
145
 
146
 
 
 
 
 
 
 
 
 
 
 
147
  def get_service_health_checkers() -> dict:
148
  """获取服务健康检查器映射"""
149
  return {
 
144
  )
145
 
146
 
147
+ def get_tts_audio_generator_service_definition(tts_config: BaseTTSConfig = None) -> ServiceDefinition:
148
+ """获取TTS音频生成服务定义"""
149
+ return ServiceDefinition(
150
+ name="tts_audio_generator",
151
+ factory=lambda: ServiceFactories.create_tts_audio_generator(tts_config),
152
+ dependencies=["llm_generator"],
153
+ startup_timeout=45
154
+ )
155
+
156
+
157
  def get_service_health_checkers() -> dict:
158
  """获取服务健康检查器映射"""
159
  return {
src/VoiceDialogue/api/routes/tts_routes.py CHANGED
@@ -1,38 +1,76 @@
1
  import logging
2
  from typing import Optional
3
 
4
- from fastapi import APIRouter, HTTPException, BackgroundTasks
5
 
6
  from services.audio.audio_generator import tts_config_registry
 
7
  from ..schemas.tts_schemas import (
8
  TTSModelInfo, TTSModelListResponse, TTSModelLoadRequest,
9
- TTSModelLoadResponse, TTSModelStatusResponse, TTSModelDeleteResponse,
10
- generate_model_id
11
  )
12
 
13
  logger = logging.getLogger(__name__)
14
  router = APIRouter()
15
 
 
 
 
 
 
 
 
 
 
16
 
17
  @router.get("/models", response_model=TTSModelListResponse, summary="获取TTS模型列表")
18
- async def list_tts_models():
19
  """
20
  获取所有可用的TTS模型列表
21
- 只返回BaseTTSConfig中的基础字段,每个模型分配唯一ID
22
  """
23
  try:
24
  all_configs = tts_config_registry.get_all_configs()
25
  models = []
26
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  for config in all_configs:
28
  # 生成唯一ID,但不暴露具体的TTS类型
29
  model_id = generate_model_id(config.tts_type.value, config.character_name)
30
 
31
  # 检查模型状态
32
  if config.is_model_complete():
33
- status = "downloaded"
 
 
 
 
 
 
 
 
 
 
34
  else:
35
- status = "not_downloaded"
 
 
 
 
 
 
 
 
36
 
37
  model_info = TTSModelInfo(
38
  id=model_id,
@@ -46,8 +84,10 @@ async def list_tts_models():
46
  models.append(model_info)
47
 
48
  return TTSModelListResponse(
 
49
  models=models,
50
- total=len(models)
 
51
  )
52
 
53
  except Exception as e:
@@ -56,37 +96,63 @@ async def list_tts_models():
56
 
57
 
58
  @router.post("/models/load", response_model=TTSModelLoadResponse, summary="加载TTS模型")
59
- async def load_tts_model(request: TTSModelLoadRequest, background_tasks: BackgroundTasks):
 
 
 
 
60
  """
61
- 通过模型ID加载TTS模型,不暴露具体的TTS类型
62
  """
63
  try:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  # 通过ID找到对应的配置
65
  config = _find_config_by_id(request.model_id)
66
  if not config:
67
  raise HTTPException(status_code=404, detail="模型ID不存在")
68
 
69
- # 检查模型是否已存在
70
  if config.is_model_complete():
71
- return TTSModelLoadResponse(
72
- success=True,
73
- message=f"模型 {config.character_name} 已经加载完成",
74
- model_id=request.model_id
75
- )
76
-
77
- # 后台下载模型
78
- background_tasks.add_task(_download_model_task, config, request.model_id)
79
-
80
- return TTSModelLoadResponse(
81
- success=True,
82
- message=f"模型 {config.character_name} 开始下载",
83
- model_id=request.model_id
84
- )
85
 
86
  except HTTPException:
87
  raise
 
 
 
 
 
88
  except Exception as e:
89
  logger.error(f"加载TTS模型失败: {e}", exc_info=True)
 
 
90
  return TTSModelLoadResponse(
91
  success=False,
92
  message=f"加载模型失败: {str(e)}",
@@ -94,62 +160,234 @@ async def load_tts_model(request: TTSModelLoadRequest, background_tasks: Backgro
94
  )
95
 
96
 
97
- @router.get("/models/{model_id}/status", response_model=TTSModelStatusResponse, summary="获取TTS模型状态")
98
- async def get_tts_model_status(model_id: str):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
  """
100
- 获取指定TTS模型的状态
101
  """
102
  try:
103
- config = _find_config_by_id(model_id)
104
- if not config:
105
- raise HTTPException(status_code=404, detail="模型ID不存在")
106
 
107
- # 检查模型状态
108
- if config.is_model_complete():
109
- status = "downloaded"
110
- else:
111
- status = "not_downloaded"
112
 
113
- return TTSModelStatusResponse(
114
- model_id=model_id,
115
- status=status
116
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
 
118
- except HTTPException:
119
- raise
120
  except Exception as e:
121
- logger.error(f"获取TTS模型状态失败: {e}", exc_info=True)
122
- raise HTTPException(status_code=500, detail=f"获取模型状态失败: {str(e)}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
 
124
 
125
- @router.delete("/models/{model_id}", response_model=TTSModelDeleteResponse, summary="删除TTS模型")
126
- async def delete_tts_model(model_id: str):
127
  """
128
- 删除指定的TTS模型
129
  """
130
  try:
131
  config = _find_config_by_id(model_id)
132
  if not config:
133
  raise HTTPException(status_code=404, detail="模型ID不存在")
134
 
135
- # 删除模型文件
136
- config.delete_model()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
 
138
- return TTSModelDeleteResponse(
139
- success=True,
140
- message=f"模型 {config.character_name} 删除成功",
141
- model_id=model_id
142
  )
143
 
144
  except HTTPException:
145
  raise
146
  except Exception as e:
147
- logger.error(f"删除TTS模型失败: {e}", exc_info=True)
148
- return TTSModelDeleteResponse(
149
- success=False,
150
- message=f"删除模型失败: {str(e)}",
151
- model_id=model_id
152
- )
153
 
154
 
155
  def _find_config_by_id(model_id: str) -> Optional:
@@ -160,13 +398,3 @@ def _find_config_by_id(model_id: str) -> Optional:
160
  if config_id == model_id:
161
  return config
162
  return None
163
-
164
-
165
- async def _download_model_task(config, model_id: str):
166
- """后台下载模型任务"""
167
- try:
168
- logger.info(f"开始下载模型: {config.character_name}")
169
- config.download_model()
170
- logger.info(f"模型下载完成: {config.character_name}")
171
- except Exception as e:
172
- logger.error(f"模型下载失败: {e}", exc_info=True)
 
1
  import logging
2
  from typing import Optional
3
 
4
+ from fastapi import APIRouter, HTTPException, BackgroundTasks, Request
5
 
6
  from services.audio.audio_generator import tts_config_registry
7
+ from ..core.service_factories import get_tts_audio_generator_service_definition
8
  from ..schemas.tts_schemas import (
9
  TTSModelInfo, TTSModelListResponse, TTSModelLoadRequest,
10
+ TTSModelLoadResponse, TTSModelStatusResponse, generate_model_id
 
11
  )
12
 
13
  logger = logging.getLogger(__name__)
14
  router = APIRouter()
15
 
16
+ # TTS模型加载状态管理
17
+ _tts_loading_status = {
18
+ "status": "idle", # idle, loading, completed, failed
19
+ "current_model_id": None,
20
+ "current_character_name": None,
21
+ "message": None,
22
+ "progress": 0.0
23
+ }
24
+
25
 
26
  @router.get("/models", response_model=TTSModelListResponse, summary="获取TTS模型列表")
27
+ async def list_tts_models(fastapi_request: Request):
28
  """
29
  获取所有可用的TTS模型列表
 
30
  """
31
  try:
32
  all_configs = tts_config_registry.get_all_configs()
33
  models = []
34
 
35
+ # 获取当前系统默认加载的TTS模型信息
36
+ current_tts_model_id = getattr(fastapi_request.app.state, "current_tts_model_id", None)
37
+ current_tts_character_name = getattr(fastapi_request.app.state, "current_tts_character_name", None)
38
+
39
+ # 如果没有从请求状态获取到当前模型,尝试从系统默认配置获取
40
+ if not current_tts_model_id:
41
+ default_config = tts_config_registry.get_default_config_for_system()
42
+ if default_config:
43
+ current_tts_model_id = generate_model_id(default_config.tts_type.value, default_config.character_name)
44
+ current_tts_character_name = default_config.character_name
45
+ logger.info(f"使用系统默认TTS模型: {current_tts_character_name} (ID: {current_tts_model_id})")
46
+
47
  for config in all_configs:
48
  # 生成唯一ID,但不暴露具体的TTS类型
49
  model_id = generate_model_id(config.tts_type.value, config.character_name)
50
 
51
  # 检查模型状态
52
  if config.is_model_complete():
53
+ # 如果是当前系统加载的模型,或者是正在加载的模型,优先显示正确状态
54
+ if current_tts_model_id == model_id:
55
+ status = "downloaded" # 系统已加载的模型
56
+ elif (_tts_loading_status["status"] == "loading" and
57
+ _tts_loading_status["current_model_id"] == model_id):
58
+ status = "downloading"
59
+ elif (_tts_loading_status["status"] == "failed" and
60
+ _tts_loading_status["current_model_id"] == model_id):
61
+ status = "failed"
62
+ else:
63
+ status = "downloaded"
64
  else:
65
+ # 模型未完整下载
66
+ if (_tts_loading_status["status"] == "loading" and
67
+ _tts_loading_status["current_model_id"] == model_id):
68
+ status = "downloading"
69
+ elif (_tts_loading_status["status"] == "failed" and
70
+ _tts_loading_status["current_model_id"] == model_id):
71
+ status = "failed"
72
+ else:
73
+ status = "not_downloaded"
74
 
75
  model_info = TTSModelInfo(
76
  id=model_id,
 
84
  models.append(model_info)
85
 
86
  return TTSModelListResponse(
87
+ total=len(models),
88
  models=models,
89
+ current_model_id=current_tts_model_id,
90
+ current_character_name=current_tts_character_name
91
  )
92
 
93
  except Exception as e:
 
96
 
97
 
98
  @router.post("/models/load", response_model=TTSModelLoadResponse, summary="加载TTS模型")
99
+ async def load_tts_model(
100
+ request: TTSModelLoadRequest,
101
+ fastapi_request: Request,
102
+ background_tasks: BackgroundTasks,
103
+ ):
104
  """
105
+ 通过模型ID加载TTS模型
106
  """
107
  try:
108
+ if _tts_loading_status["status"] == "loading":
109
+ current_loading_model = _tts_loading_status["current_model_id"]
110
+ if current_loading_model == request.model_id:
111
+ return TTSModelLoadResponse(
112
+ success=True,
113
+ message=f"模型 {_tts_loading_status['current_character_name']} 正在加载中...",
114
+ model_id=request.model_id
115
+ )
116
+ else:
117
+ return TTSModelLoadResponse(
118
+ success=False,
119
+ message="另一个模型正在加载中,请稍后重试",
120
+ model_id=request.model_id
121
+ )
122
+
123
  # 通过ID找到对应的配置
124
  config = _find_config_by_id(request.model_id)
125
  if not config:
126
  raise HTTPException(status_code=404, detail="模型ID不存在")
127
 
128
+ # 检查模型是否已经完整下载
129
  if config.is_model_complete():
130
+ # 检查是否是当前系统已加载的模型
131
+ current_tts_model_id = getattr(fastapi_request.app.state, "current_tts_model_id", None)
132
+ if current_tts_model_id == request.model_id:
133
+ return TTSModelLoadResponse(
134
+ success=True,
135
+ message=f"模型 {config.character_name} 已是当前系统默认模型",
136
+ model_id=request.model_id
137
+ )
138
+ else:
139
+ # 需要切换到新的模型
140
+ return await _switch_tts_model(request, config, fastapi_request, background_tasks)
141
+ else:
142
+ # 模型未下载,需要先下载再加载
143
+ return await _download_and_load_tts_model(request, config, fastapi_request, background_tasks)
144
 
145
  except HTTPException:
146
  raise
147
+ except ValueError as e:
148
+ logger.warning(f"加载TTS模型失败 - 参数错误: {e}")
149
+ _tts_loading_status["status"] = "failed"
150
+ _tts_loading_status["message"] = str(e)
151
+ raise HTTPException(status_code=400, detail=str(e))
152
  except Exception as e:
153
  logger.error(f"加载TTS模型失败: {e}", exc_info=True)
154
+ _tts_loading_status["status"] = "failed"
155
+ _tts_loading_status["message"] = str(e)
156
  return TTSModelLoadResponse(
157
  success=False,
158
  message=f"加载模型失败: {str(e)}",
 
160
  )
161
 
162
 
163
+ async def _switch_tts_model(
164
+ request: TTSModelLoadRequest,
165
+ config,
166
+ fastapi_request: Request,
167
+ background_tasks: BackgroundTasks
168
+ ) -> TTSModelLoadResponse:
169
+ """切换到已下载的TTS模型"""
170
+ # 更新状态为加载中
171
+ _tts_loading_status["status"] = "loading"
172
+ _tts_loading_status["current_model_id"] = request.model_id
173
+ _tts_loading_status["current_character_name"] = config.character_name
174
+ _tts_loading_status["message"] = "正在切换TTS模型..."
175
+ _tts_loading_status["progress"] = 0.0
176
+
177
+ # 在后台执行模型切换任务
178
+ background_tasks.add_task(
179
+ _switch_tts_model_background,
180
+ config,
181
+ request.model_id,
182
+ fastapi_request
183
+ )
184
+
185
+ return TTSModelLoadResponse(
186
+ success=True,
187
+ message=f"模型 {config.character_name} 切换请求已提交,正在后台切换...",
188
+ model_id=request.model_id
189
+ )
190
+
191
+
192
+ async def _download_and_load_tts_model(
193
+ request: TTSModelLoadRequest,
194
+ config,
195
+ fastapi_request: Request,
196
+ background_tasks: BackgroundTasks
197
+ ) -> TTSModelLoadResponse:
198
+ """下载并加载TTS模型"""
199
+ # 更新状态为加载中
200
+ _tts_loading_status["status"] = "loading"
201
+ _tts_loading_status["current_model_id"] = request.model_id
202
+ _tts_loading_status["current_character_name"] = config.character_name
203
+ _tts_loading_status["message"] = "正在下载TTS模型..."
204
+ _tts_loading_status["progress"] = 0.0
205
+
206
+ # 在后台执行下载和加载任务
207
+ background_tasks.add_task(
208
+ _download_and_load_tts_model_background,
209
+ config,
210
+ request.model_id,
211
+ fastapi_request
212
+ )
213
+
214
+ return TTSModelLoadResponse(
215
+ success=True,
216
+ message=f"模型 {config.character_name} 下载和加载请求已提交,正在后台处理...",
217
+ model_id=request.model_id
218
+ )
219
+
220
+
221
+ async def _switch_tts_model_background(config, model_id: str, fastapi_request: Request):
222
  """
223
+ 后台切换TTS模型的实际逻辑
224
  """
225
  try:
226
+ logger.info(f"开始切换TTS模型: {config.character_name}")
 
 
227
 
228
+ # 获取服务管理器
229
+ service_manager = getattr(fastapi_request.app.state, "service_manager", None)
230
+ if not service_manager:
231
+ raise RuntimeError("服务管理器未初始化")
 
232
 
233
+ _tts_loading_status["progress"] = 20.0
234
+ _tts_loading_status["message"] = "正在停止当前TTS服务..."
235
+
236
+ # 停止当前的TTS服务
237
+ if service_manager.is_service_running("tts_audio_generator"):
238
+ service_manager._stop_single_service("tts_audio_generator")
239
+ logger.info("已停止当前TTS服务")
240
+
241
+ _tts_loading_status["progress"] = 50.0
242
+ _tts_loading_status["message"] = "正在启动新的TTS服务..."
243
+
244
+ # 使用新配置创建TTS服务定义
245
+ new_tts_service_def = get_tts_audio_generator_service_definition(config)
246
+
247
+ # 启动新的TTS服务
248
+ success = service_manager.start_service(new_tts_service_def)
249
+ if not success:
250
+ raise RuntimeError("新TTS服务启动失败")
251
+
252
+ _tts_loading_status["progress"] = 90.0
253
+ _tts_loading_status["message"] = "正在验证服务状态..."
254
+
255
+ # 更新请求状态中的当前模型信息
256
+ fastapi_request.app.state.current_tts_model_id = model_id
257
+ fastapi_request.app.state.current_tts_character_name = config.character_name
258
+
259
+ # 更新状态为完成
260
+ _tts_loading_status["status"] = "completed"
261
+ _tts_loading_status["progress"] = 100.0
262
+ _tts_loading_status["message"] = f"成功切换到TTS模型: {config.character_name}"
263
+
264
+ logger.info(f"TTS模型切换成功: {config.character_name}")
265
 
 
 
266
  except Exception as e:
267
+ logger.error(f"后台切换TTS模型失败: {e}", exc_info=True)
268
+ _tts_loading_status["status"] = "failed"
269
+ _tts_loading_status["message"] = str(e)
270
+ _tts_loading_status["progress"] = 0.0
271
+
272
+
273
+ async def _download_and_load_tts_model_background(config, model_id: str, fastapi_request: Request):
274
+ """
275
+ 后台下载并加载TTS模型的实际逻辑
276
+ """
277
+ try:
278
+ logger.info(f"开始下载并加载TTS模型: {config.character_name}")
279
+
280
+ _tts_loading_status["progress"] = 10.0
281
+ _tts_loading_status["message"] = "正在准备下载..."
282
+
283
+ # 执行实际的模型下载
284
+ _tts_loading_status["progress"] = 30.0
285
+ _tts_loading_status["message"] = "正在下载模型文件..."
286
+
287
+ config.download_model()
288
+
289
+ _tts_loading_status["progress"] = 70.0
290
+ _tts_loading_status["message"] = "正在验证模型文件..."
291
+
292
+ # 验证模型是否下载成功
293
+ if not config.is_model_complete():
294
+ raise RuntimeError("模型文件下载不完整")
295
+
296
+ # 获取服务管理器
297
+ service_manager = getattr(fastapi_request.app.state, "service_manager", None)
298
+ if not service_manager:
299
+ raise RuntimeError("服务管理器未初始化")
300
+
301
+ _tts_loading_status["progress"] = 80.0
302
+ _tts_loading_status["message"] = "正在停止当前TTS服务..."
303
+
304
+ # 停止当前的TTS服务
305
+ if service_manager.is_service_running("tts_audio_generator"):
306
+ service_manager._stop_single_service("tts_audio_generator")
307
+ logger.info("已停止当前TTS服务")
308
+
309
+ _tts_loading_status["progress"] = 90.0
310
+ _tts_loading_status["message"] = "正在启动新的TTS服务..."
311
+
312
+ # 使用新配置创建TTS服务定义
313
+ new_tts_service_def = get_tts_audio_generator_service_definition(config)
314
+
315
+ # 启动新的TTS服务
316
+ success = service_manager.start_service(new_tts_service_def)
317
+ if not success:
318
+ raise RuntimeError("新TTS服务启动失败")
319
+
320
+ # 更新请求状态中的当前模型信息
321
+ fastapi_request.app.state.current_tts_model_id = model_id
322
+ fastapi_request.app.state.current_tts_character_name = config.character_name
323
+
324
+ # 更新状态为完成
325
+ _tts_loading_status["status"] = "completed"
326
+ _tts_loading_status["progress"] = 100.0
327
+ _tts_loading_status["message"] = f"成功下载并加载TTS模型: {config.character_name}"
328
+
329
+ logger.info(f"TTS模型下载并加载成功: {config.character_name}")
330
+
331
+ except Exception as e:
332
+ logger.error(f"后台下载并加载TTS模型失败: {e}", exc_info=True)
333
+ _tts_loading_status["status"] = "failed"
334
+ _tts_loading_status["message"] = str(e)
335
+ _tts_loading_status["progress"] = 0.0
336
 
337
 
338
+ @router.get("/models/{model_id}/status", response_model=TTSModelStatusResponse, summary="获取TTS模型状态")
339
+ async def get_tts_model_status(model_id: str, fastapi_request: Request):
340
  """
341
+ 获取指定TTS模型的状态
342
  """
343
  try:
344
  config = _find_config_by_id(model_id)
345
  if not config:
346
  raise HTTPException(status_code=404, detail="模型ID不存在")
347
 
348
+ # 获取当前系统加载的模型
349
+ current_tts_model_id = getattr(fastapi_request.app.state, "current_tts_model_id", None)
350
+
351
+ # 检查模型状态
352
+ if config.is_model_complete():
353
+ if current_tts_model_id == model_id:
354
+ status = "downloaded" # 当前系统加载的模型
355
+ progress = 100.0
356
+ elif (_tts_loading_status["status"] == "loading" and
357
+ _tts_loading_status["current_model_id"] == model_id):
358
+ status = "downloading"
359
+ progress = _tts_loading_status["progress"]
360
+ elif (_tts_loading_status["status"] == "failed" and
361
+ _tts_loading_status["current_model_id"] == model_id):
362
+ status = "failed"
363
+ progress = 0.0
364
+ else:
365
+ status = "downloaded"
366
+ progress = 100.0
367
+ else:
368
+ if (_tts_loading_status["status"] == "loading" and
369
+ _tts_loading_status["current_model_id"] == model_id):
370
+ status = "downloading"
371
+ progress = _tts_loading_status["progress"]
372
+ elif (_tts_loading_status["status"] == "failed" and
373
+ _tts_loading_status["current_model_id"] == model_id):
374
+ status = "failed"
375
+ progress = 0.0
376
+ else:
377
+ status = "not_downloaded"
378
+ progress = 0.0
379
 
380
+ return TTSModelStatusResponse(
381
+ model_id=model_id,
382
+ status=status,
383
+ progress=progress
384
  )
385
 
386
  except HTTPException:
387
  raise
388
  except Exception as e:
389
+ logger.error(f"获取TTS模型状态失败: {e}", exc_info=True)
390
+ raise HTTPException(status_code=500, detail=f"获取模型状态失败: {str(e)}")
 
 
 
 
391
 
392
 
393
  def _find_config_by_id(model_id: str) -> Optional:
 
398
  if config_id == model_id:
399
  return config
400
  return None
 
 
 
 
 
 
 
 
 
 
src/VoiceDialogue/api/schemas/tts_schemas.py CHANGED
@@ -16,8 +16,10 @@ class TTSModelInfo(BaseModel):
16
 
17
  class TTSModelListResponse(BaseModel):
18
  """TTS模型列表响应"""
19
- models: List[TTSModelInfo] = Field(..., description="TTS模型列表")
20
  total: int = Field(..., description="模型总数")
 
 
 
21
 
22
 
23
  class TTSModelLoadRequest(BaseModel):
 
16
 
17
  class TTSModelListResponse(BaseModel):
18
  """TTS模型列表响应"""
 
19
  total: int = Field(..., description="模型总数")
20
+ models: List[TTSModelInfo] = Field(..., description="TTS模型列表")
21
+ current_model_id: Optional[str] = Field(None, description="当前使用的模型ID")
22
+ current_character_name: Optional[str] = Field(None, description="当前使用的角色名称")
23
 
24
 
25
  class TTSModelLoadRequest(BaseModel):