| """ |
| 原生Gemini TTS聊天服务扩展 |
| 继承自原始聊天服务,添加原生Gemini TTS支持(单人和多人),保持向后兼容 |
| """ |
|
|
import datetime
import re
import time
from typing import Any, Dict

from app.database.services import add_error_log, add_request_log
from app.domain.gemini_models import GeminiRequest
from app.log.logger import get_gemini_logger
from app.service.chat.gemini_chat_service import GeminiChatService, _filter_empty_parts
from app.service.tts.native.tts_response_handler import TTSResponseHandler
|
|
| logger = get_gemini_logger() |
|
|
|
|
class TTSGeminiChatService(GeminiChatService):
    """Gemini chat service with native TTS (single- and multi-speaker) support.

    Extends the original GeminiChatService while staying backward compatible:
    requests for TTS models (model name contains "tts") go through a dedicated
    pipeline that forwards the TTS-specific generation fields and records
    request/error logs; all other models are delegated to the parent class.
    """

    def __init__(self, base_url: str, key_manager):
        """Initialize the TTS chat service.

        Args:
            base_url: Base URL of the Gemini API.
            key_manager: API-key manager forwarded to the parent service.
        """
        super().__init__(base_url, key_manager)
        # Handler that knows how to unwrap TTS (audio) responses.
        self.response_handler = TTSResponseHandler()
        logger.info("TTS Gemini Chat Service initialized with multi-speaker TTS support")

    async def generate_content(
        self, model: str, request: GeminiRequest, api_key: str
    ) -> Dict[str, Any]:
        """Generate content, routing TTS models to the TTS-specific pipeline.

        Args:
            model: Target model name; TTS handling triggers when it contains "tts".
            request: Parsed Gemini request.
            api_key: Gemini API key for this call.

        Returns:
            The Gemini API response as a dict (TTS-processed when applicable).

        Raises:
            Exception: Any failure from the underlying API call is re-raised.
        """
        try:
            logger.info(f"TTS request model: {model}")
            logger.info(f"TTS request generationConfig: {request.generationConfig}")

            # Guard clause: non-TTS models keep the unchanged parent behavior.
            if "tts" not in model.lower():
                return await super().generate_content(model, request, api_key)

            logger.info("Detected TTS model, applying TTS-specific processing")
            return await self._handle_tts_request(model, request, api_key)
        except Exception as e:
            logger.error(f"TTS API call failed with error: {e}")
            raise

    async def _handle_tts_request(
        self, model: str, request: GeminiRequest, api_key: str
    ) -> Dict[str, Any]:
        """Handle a TTS-specific request with full request/error logging.

        Builds a minimal upstream payload, calls the API, and delegates
        response post-processing to the TTS response handler. On failure an
        error log is stored; a request log (success flag, status code,
        latency) is stored in every case.

        Args:
            model: TTS model name.
            request: Parsed Gemini request.
            api_key: Gemini API key for this call.

        Returns:
            The TTS-processed response dict.

        Raises:
            Exception: Re-raised after the error has been logged.
        """
        start_time = time.perf_counter()
        request_datetime = datetime.datetime.now()
        is_success = False
        status_code = None

        try:
            payload = self._build_tts_payload(request)
            logger.info(f"TTS payload before API call: {payload}")

            response = await self.api_client.generate_content(payload, model, api_key)

            # Mark success before post-processing; a handler failure below
            # flips these again in the except branch.
            is_success = True
            status_code = 200
            return self.response_handler.handle_response(response, model, False, None)
        except Exception as e:
            is_success = False
            error_msg = str(e)
            status_code = self._extract_status_code(error_msg)

            await add_error_log(
                gemini_key=api_key,
                model_name=model,
                error_type="tts-api-error",
                error_log=error_msg,
                error_code=status_code,
                request_msg=request.model_dump(exclude_none=False),
            )

            logger.error(f"TTS API call failed: {error_msg}")
            raise
        finally:
            # Always record the request outcome and latency, even on failure.
            latency_ms = int((time.perf_counter() - start_time) * 1000)
            await add_request_log(
                model_name=model,
                api_key=api_key,
                is_success=is_success,
                status_code=status_code,
                latency_ms=latency_ms,
                request_time=request_datetime,
            )

    def _build_tts_payload(self, request: GeminiRequest) -> Dict[str, Any]:
        """Build the upstream API payload for a TTS request.

        Keeps only the fields the TTS endpoint consumes (contents,
        generationConfig, optional systemInstruction) and explicitly forwards
        the TTS fields (responseModalities, speechConfig) so audio output is
        requested even when generic serialization would drop them.

        Args:
            request: Parsed Gemini request.

        Returns:
            Payload dict ready for the API client.
        """
        request_dict = request.model_dump(exclude_none=False)

        payload: Dict[str, Any] = {
            "contents": _filter_empty_parts(request_dict.get("contents", [])),
            # Normalize a serialized None into an empty dict so the TTS
            # fields can be merged in below.
            "generationConfig": request_dict.get("generationConfig") or {},
        }

        if request_dict.get("systemInstruction"):
            payload["systemInstruction"] = request_dict.get("systemInstruction")

        if request.generationConfig:
            if request.generationConfig.responseModalities:
                payload["generationConfig"]["responseModalities"] = (
                    request.generationConfig.responseModalities
                )
                logger.info(f"Added responseModalities: {request.generationConfig.responseModalities}")
            if request.generationConfig.speechConfig:
                payload["generationConfig"]["speechConfig"] = (
                    request.generationConfig.speechConfig
                )
                logger.info(f"Added speechConfig: {request.generationConfig.speechConfig}")
        else:
            logger.warning("No generationConfig found in request, TTS fields may be missing")

        return payload

    @staticmethod
    def _extract_status_code(error_msg: str) -> int:
        """Best-effort extraction of an HTTP status code from error text.

        Args:
            error_msg: Stringified exception from the upstream call.

        Returns:
            The embedded "status code NNN" value, or 500 when absent.
        """
        match = re.search(r"status code (\d+)", error_msg)
        return int(match.group(1)) if match else 500
|
|