| from typing import Optional |
|
|
| from fastapi import APIRouter, HTTPException, BackgroundTasks, Request |
| from fastapi.responses import FileResponse |
|
|
| from voice_dialogue.tts import tts_config_registry |
| from voice_dialogue.utils.logger import logger |
| from ..core.service_factories import get_tts_audio_generator_service_definition |
| from ..schemas.tts_schemas import ( |
| TTSModelInfo, TTSModelListResponse, TTSModelLoadRequest, |
| TTSModelLoadResponse, TTSModelStatusResponse, generate_model_id |
| ) |
|
|
| router = APIRouter() |
|
|
| |
| _tts_loading_status = { |
| "status": "idle", |
| "current_model_id": None, |
| "current_character_name": None, |
| "message": None, |
| "progress": 0.0 |
| } |
|
|
|
|
| @router.get("/models", response_model=TTSModelListResponse, summary="获取TTS模型列表") |
| async def list_tts_models(fastapi_request: Request): |
| """ |
| 获取所有可用的TTS模型列表 |
| """ |
| try: |
| all_configs = tts_config_registry.get_all_configs() |
| models = [] |
|
|
| |
| current_tts_model_id = getattr(fastapi_request.app.state, "current_tts_model_id", None) |
| current_tts_character_name = getattr(fastapi_request.app.state, "current_tts_character_name", None) |
|
|
| |
| if not current_tts_model_id: |
| default_config = tts_config_registry.get_default_config_for_system() |
| if default_config: |
| current_tts_model_id = generate_model_id(default_config.tts_type.value, default_config.character_name) |
| current_tts_character_name = default_config.character_name |
| logger.info(f"使用系统默认TTS模型: {current_tts_character_name} (ID: {current_tts_model_id})") |
|
|
| for config in all_configs: |
| |
| model_id = generate_model_id(config.tts_type.value, config.character_name) |
|
|
| |
| if config.is_model_complete(): |
| |
| if current_tts_model_id == model_id: |
| status = "downloaded" |
| elif (_tts_loading_status["status"] == "loading" and |
| _tts_loading_status["current_model_id"] == model_id): |
| status = "downloading" |
| elif (_tts_loading_status["status"] == "failed" and |
| _tts_loading_status["current_model_id"] == model_id): |
| status = "failed" |
| else: |
| status = "downloaded" |
| else: |
| |
| if (_tts_loading_status["status"] == "loading" and |
| _tts_loading_status["current_model_id"] == model_id): |
| status = "downloading" |
| elif (_tts_loading_status["status"] == "failed" and |
| _tts_loading_status["current_model_id"] == model_id): |
| status = "failed" |
| else: |
| status = "not_downloaded" |
|
|
| model_info = TTSModelInfo( |
| id=model_id, |
| character_name=config.character_name, |
| cover_image=config.cover_image, |
| description=config.description, |
| file_size=config.file_size, |
| is_chinese_voice=config.is_chinese_voice, |
| status=status |
| ) |
| models.append(model_info) |
|
|
| return TTSModelListResponse( |
| total=len(models), |
| models=models, |
| current_model_id=current_tts_model_id, |
| current_character_name=current_tts_character_name |
| ) |
|
|
| except Exception as e: |
| logger.error(f"获取TTS模型列表失败: {e}", exc_info=True) |
| raise HTTPException(status_code=500, detail=f"获取TTS模型列表失败: {str(e)}") |
|
|
|
|
| @router.post("/models/load", response_model=TTSModelLoadResponse, summary="加载TTS模型") |
| async def load_tts_model( |
| request: TTSModelLoadRequest, |
| fastapi_request: Request, |
| background_tasks: BackgroundTasks, |
| ): |
| """ |
| 通过模型ID加载TTS模型 |
| """ |
| try: |
| if _tts_loading_status["status"] == "loading": |
| current_loading_model = _tts_loading_status["current_model_id"] |
| if current_loading_model == request.model_id: |
| return TTSModelLoadResponse( |
| success=True, |
| message=f"模型 {_tts_loading_status['current_character_name']} 正在加载中...", |
| model_id=request.model_id |
| ) |
| else: |
| return TTSModelLoadResponse( |
| success=False, |
| message="另一个模型正在加载中,请稍后重试", |
| model_id=request.model_id |
| ) |
|
|
| |
| config = _find_config_by_id(request.model_id) |
| if not config: |
| raise HTTPException(status_code=404, detail="模型ID不存在") |
|
|
| |
| if config.is_model_complete(): |
| |
| current_tts_model_id = getattr(fastapi_request.app.state, "current_tts_model_id", None) |
| if current_tts_model_id == request.model_id: |
| return TTSModelLoadResponse( |
| success=True, |
| message=f"模型 {config.character_name} 已是当前系统默认模型", |
| model_id=request.model_id |
| ) |
| else: |
| |
| return await _switch_tts_model(request, config, fastapi_request, background_tasks) |
| else: |
| |
| return await _download_and_load_tts_model(request, config, fastapi_request, background_tasks) |
|
|
| except HTTPException: |
| raise |
| except ValueError as e: |
| logger.warning(f"加载TTS模型失败 - 参数错误: {e}") |
| _tts_loading_status["status"] = "failed" |
| _tts_loading_status["message"] = str(e) |
| raise HTTPException(status_code=400, detail=str(e)) |
| except Exception as e: |
| logger.error(f"加载TTS模型失败: {e}", exc_info=True) |
| _tts_loading_status["status"] = "failed" |
| _tts_loading_status["message"] = str(e) |
| return TTSModelLoadResponse( |
| success=False, |
| message=f"加载模型失败: {str(e)}", |
| model_id=request.model_id |
| ) |
|
|
|
|
| async def _switch_tts_model( |
| request: TTSModelLoadRequest, |
| config, |
| fastapi_request: Request, |
| background_tasks: BackgroundTasks |
| ) -> TTSModelLoadResponse: |
| """切换到已下载的TTS模型""" |
| |
| _tts_loading_status["status"] = "loading" |
| _tts_loading_status["current_model_id"] = request.model_id |
| _tts_loading_status["current_character_name"] = config.character_name |
| _tts_loading_status["message"] = "正在切换TTS模型..." |
| _tts_loading_status["progress"] = 0.0 |
|
|
| |
| background_tasks.add_task( |
| _switch_tts_model_background, |
| config, |
| request.model_id, |
| fastapi_request |
| ) |
|
|
| return TTSModelLoadResponse( |
| success=True, |
| message=f"模型 {config.character_name} 切换请求已提交,正在后台切换...", |
| model_id=request.model_id |
| ) |
|
|
|
|
| async def _download_and_load_tts_model( |
| request: TTSModelLoadRequest, |
| config, |
| fastapi_request: Request, |
| background_tasks: BackgroundTasks |
| ) -> TTSModelLoadResponse: |
| """下载并加载TTS模型""" |
| |
| _tts_loading_status["status"] = "loading" |
| _tts_loading_status["current_model_id"] = request.model_id |
| _tts_loading_status["current_character_name"] = config.character_name |
| _tts_loading_status["message"] = "正在下载TTS模型..." |
| _tts_loading_status["progress"] = 0.0 |
|
|
| |
| background_tasks.add_task( |
| _download_and_load_tts_model_background, |
| config, |
| request.model_id, |
| fastapi_request |
| ) |
|
|
| return TTSModelLoadResponse( |
| success=True, |
| message=f"模型 {config.character_name} 下载和加载请求已提交,正在后台处理...", |
| model_id=request.model_id |
| ) |
|
|
|
|
| async def _switch_tts_model_background(config, model_id: str, fastapi_request: Request): |
| """ |
| 后台切换TTS模型的实际逻辑 |
| """ |
| try: |
| logger.info(f"开始切换TTS模型: {config.character_name}") |
|
|
| |
| service_manager = getattr(fastapi_request.app.state, "service_manager", None) |
| if not service_manager: |
| raise RuntimeError("服务管理器未初始化") |
|
|
| _tts_loading_status["progress"] = 20.0 |
| _tts_loading_status["message"] = "正在停止当前TTS服务..." |
|
|
| |
| if service_manager.is_service_running("tts_audio_generator"): |
| service_manager._stop_single_service("tts_audio_generator") |
| logger.info("已停止当前TTS服务") |
|
|
| _tts_loading_status["progress"] = 50.0 |
| _tts_loading_status["message"] = "正在启动新的TTS服务..." |
|
|
| |
| new_tts_service_def = get_tts_audio_generator_service_definition(config) |
|
|
| |
| success = service_manager.start_service(new_tts_service_def) |
| if not success: |
| raise RuntimeError("新TTS服务启动失败") |
|
|
| _tts_loading_status["progress"] = 90.0 |
| _tts_loading_status["message"] = "正在验证服务状态..." |
|
|
| |
| fastapi_request.app.state.current_tts_model_id = model_id |
| fastapi_request.app.state.current_tts_character_name = config.character_name |
|
|
| |
| _tts_loading_status["status"] = "completed" |
| _tts_loading_status["progress"] = 100.0 |
| _tts_loading_status["message"] = f"成功切换到TTS模型: {config.character_name}" |
|
|
| logger.info(f"TTS模型切换成功: {config.character_name}") |
|
|
| except Exception as e: |
| logger.error(f"后台切换TTS模型失败: {e}", exc_info=True) |
| _tts_loading_status["status"] = "failed" |
| _tts_loading_status["message"] = str(e) |
| _tts_loading_status["progress"] = 0.0 |
|
|
|
|
| async def _download_and_load_tts_model_background(config, model_id: str, fastapi_request: Request): |
| """ |
| 后台下载并加载TTS模型的实际逻辑 |
| """ |
| try: |
| logger.info(f"开始下载并加载TTS模型: {config.character_name}") |
|
|
| _tts_loading_status["progress"] = 10.0 |
| _tts_loading_status["message"] = "正在准备下载..." |
|
|
| |
| _tts_loading_status["progress"] = 30.0 |
| _tts_loading_status["message"] = "正在下载模型文件..." |
|
|
| config.download_model() |
|
|
| _tts_loading_status["progress"] = 70.0 |
| _tts_loading_status["message"] = "正在验证模型文件..." |
|
|
| |
| if not config.is_model_complete(): |
| raise RuntimeError("模型文件下载不完整") |
|
|
| |
| service_manager = getattr(fastapi_request.app.state, "service_manager", None) |
| if not service_manager: |
| raise RuntimeError("服务管理器未初始化") |
|
|
| _tts_loading_status["progress"] = 80.0 |
| _tts_loading_status["message"] = "正在停止当前TTS服务..." |
|
|
| |
| if service_manager.is_service_running("tts_audio_generator"): |
| service_manager._stop_single_service("tts_audio_generator") |
| logger.info("已停止当前TTS服务") |
|
|
| _tts_loading_status["progress"] = 90.0 |
| _tts_loading_status["message"] = "正在启动新的TTS服务..." |
|
|
| |
| new_tts_service_def = get_tts_audio_generator_service_definition(config) |
|
|
| |
| success = service_manager.start_service(new_tts_service_def) |
| if not success: |
| raise RuntimeError("新TTS服务启动失败") |
|
|
| |
| fastapi_request.app.state.current_tts_model_id = model_id |
| fastapi_request.app.state.current_tts_character_name = config.character_name |
|
|
| |
| _tts_loading_status["status"] = "completed" |
| _tts_loading_status["progress"] = 100.0 |
| _tts_loading_status["message"] = f"成功下载并加载TTS模型: {config.character_name}" |
|
|
| logger.info(f"TTS模型下载并加载成功: {config.character_name}") |
|
|
| except Exception as e: |
| logger.error(f"后台下载并加载TTS模型失败: {e}", exc_info=True) |
| _tts_loading_status["status"] = "failed" |
| _tts_loading_status["message"] = str(e) |
| _tts_loading_status["progress"] = 0.0 |
|
|
|
|
| @router.get("/models/{model_id}/status", response_model=TTSModelStatusResponse, summary="获取TTS模型状态") |
| async def get_tts_model_status(model_id: str, fastapi_request: Request): |
| """ |
| 获取指定TTS模型的状态 |
| """ |
| try: |
| config = _find_config_by_id(model_id) |
| if not config: |
| raise HTTPException(status_code=404, detail="模型ID不存在") |
|
|
| |
| current_tts_model_id = getattr(fastapi_request.app.state, "current_tts_model_id", None) |
|
|
| |
| if config.is_model_complete(): |
| if current_tts_model_id == model_id: |
| status = "downloaded" |
| progress = 100.0 |
| elif (_tts_loading_status["status"] == "loading" and |
| _tts_loading_status["current_model_id"] == model_id): |
| status = "downloading" |
| progress = _tts_loading_status["progress"] |
| elif (_tts_loading_status["status"] == "failed" and |
| _tts_loading_status["current_model_id"] == model_id): |
| status = "failed" |
| progress = 0.0 |
| else: |
| status = "downloaded" |
| progress = 100.0 |
| else: |
| if (_tts_loading_status["status"] == "loading" and |
| _tts_loading_status["current_model_id"] == model_id): |
| status = "downloading" |
| progress = _tts_loading_status["progress"] |
| elif (_tts_loading_status["status"] == "failed" and |
| _tts_loading_status["current_model_id"] == model_id): |
| status = "failed" |
| progress = 0.0 |
| else: |
| status = "not_downloaded" |
| progress = 0.0 |
|
|
| return TTSModelStatusResponse( |
| model_id=model_id, |
| status=status, |
| progress=progress |
| ) |
|
|
| except HTTPException: |
| raise |
| except Exception as e: |
| logger.error(f"获取TTS模型状态失败: {e}", exc_info=True) |
| raise HTTPException(status_code=500, detail=f"获取模型状态失败: {str(e)}") |
|
|
|
|
| def _find_config_by_id(model_id: str) -> Optional: |
| """通过模型ID找到对应的配置""" |
| all_configs = tts_config_registry.get_all_configs() |
| for config in all_configs: |
| config_id = generate_model_id(config.tts_type.value, config.character_name) |
| if config_id == model_id: |
| return config |
| return None |
|
|
|
|
| @router.get("/models/{model_id}/reference-audio", summary="获取TTS模型参考音频") |
| async def get_tts_model_reference_audio(model_id: str): |
| """ |
| 通过模型ID获取TTS模型的参考音频文件 |
| """ |
| try: |
| |
| config = _find_config_by_id(model_id) |
| if not config: |
| raise HTTPException(status_code=404, detail="模型ID不存在") |
|
|
| |
| if not config.is_model_complete(): |
| raise HTTPException(status_code=400, detail="模型尚未下载,无法获取参考音频") |
|
|
| |
| reference_audio_path = '' |
| if hasattr(config, "reference_audio_path"): |
| reference_audio_path = config.reference_audio_path |
|
|
| |
| if reference_audio_path and not reference_audio_path.exists(): |
| raise HTTPException(status_code=404, detail="参考音频文件不存在") |
|
|
| |
| return FileResponse( |
| path=str(reference_audio_path), |
| media_type="audio/wav", |
| filename=f"{config.character_name}_reference.wav" |
| ) |
|
|
| except HTTPException: |
| raise |
| except Exception as e: |
| logger.error(f"获取TTS模型参考音频失败: {e}", exc_info=True) |
| raise HTTPException(status_code=500, detail=f"获取参考音频失败: {str(e)}") |
|
|