Spaces:
Sleeping
Sleeping
| import asyncio | |
| import logging | |
| import threading | |
| from typing import AsyncIterator | |
| from google.api_core import exceptions as gcp_exceptions | |
| from google.cloud import texttospeech | |
| from src.config import GEMINI_TTS_MODELS, GEMINI_TTS_VOICE, GEMINI_TTS_LANGUAGE | |
| from src.google_auth import get_google_credentials | |
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Credentials are resolved once, at import time, and passed to every TTS client
# this module creates.
_credentials = get_google_credentials()

# The Cloud TTS Gemini model returns raw PCM Linear16 audio sampled at 24 kHz.
GEMINI_SAMPLE_RATE = 24_000
| class _ModelRotator: | |
| def __init__(self, models: list[str]) -> None: | |
| self._models = models | |
| self._index = 0 | |
| self._lock = threading.Lock() | |
| def get_rotation(self) -> list[str]: | |
| """Returns models starting from current index, then advances for next call.""" | |
| with self._lock: | |
| n = len(self._models) | |
| start = self._index % n | |
| self._index += 1 | |
| return self._models[start:] + self._models[:start] | |
# One shared rotator for the whole process, so consecutive synthesize_stream
# calls start their model fallback order from different models.
_rotator: _ModelRotator = _ModelRotator(GEMINI_TTS_MODELS)
def _synthesize_with_model(text: str, model_name: str) -> list[bytes]:
    """Run one blocking Cloud TTS streaming call and return its non-empty audio chunks.

    Sends the streaming config (voice, language, model) as the first request and
    the full input text as the second, then drains the response stream. Responses
    with empty ``audio_content`` are dropped. The client is created for this call
    only and used as a context manager, so its transport is closed afterwards
    instead of leaking one channel per synthesis attempt.

    Args:
        text: Text to synthesize.
        model_name: Cloud TTS model to request.

    Returns:
        The non-empty audio chunks, in the order received.

    Raises:
        google.api_core.exceptions.ResourceExhausted: when the model's quota is
            exhausted, as raised by the client call; other client errors also
            propagate unchanged.
    """
    config_request = texttospeech.StreamingSynthesizeRequest(
        streaming_config=texttospeech.StreamingSynthesizeConfig(
            voice=texttospeech.VoiceSelectionParams(
                name=GEMINI_TTS_VOICE,
                language_code=GEMINI_TTS_LANGUAGE,
                model_name=model_name,
            )
        )
    )

    def _requests():
        # Config goes in the first request; the text follows in the next one.
        yield config_request
        yield texttospeech.StreamingSynthesizeRequest(
            input=texttospeech.StreamingSynthesisInput(text=text)
        )

    with texttospeech.TextToSpeechClient(credentials=_credentials) as client:
        all_chunks = [r.audio_content for r in client.streaming_synthesize(_requests())]

    chunks = [c for c in all_chunks if c]
    total_bytes = sum(len(c) for c in chunks)
    logger.info(
        "Gemini TTS [%s]: received %d chunks (%d empty skipped), %d bytes total, aligned=%s, first32=%s",
        model_name,
        len(chunks),
        len(all_chunks) - len(chunks),
        total_bytes,
        total_bytes % 2 == 0,  # Linear16 samples are 2 bytes; odd totals indicate truncation
        chunks[0][:32].hex() if chunks else "N/A",
    )
    return chunks


async def synthesize_stream(text: str) -> AsyncIterator[bytes]:
    """Call a Cloud TTS Gemini model and yield raw PCM Linear16 chunks at 24 kHz.

    Models from GEMINI_TTS_MODELS are tried round-robin (the starting model
    advances on every call). A model that raises ResourceExhausted (quota
    exceeded) falls back to the next one; any other error propagates immediately.

    All audio for a model is collected before anything is yielded. This way a
    fallback can never follow partially emitted audio from an earlier model.

    Args:
        text: Text to synthesize.

    Yields:
        Non-empty audio chunks from the first model that succeeds.

    Raises:
        ValueError: if the rotation contains no models.
        google.api_core.exceptions.ResourceExhausted: if every model's quota is
            exhausted; chained from the last model's original error.
    """
    models_to_try = _rotator.get_rotation()
    if not models_to_try:
        raise ValueError("No Gemini TTS models configured (GEMINI_TTS_MODELS is empty)")

    last_exc: gcp_exceptions.ResourceExhausted | None = None
    for model in models_to_try:
        logger.info("Gemini TTS [%s]: text_len=%d, text=%r", model, len(text), text)
        try:
            # The gRPC call blocks, so run it off the event loop.
            chunks = await asyncio.to_thread(_synthesize_with_model, text, model)
        except gcp_exceptions.ResourceExhausted as exc:
            logger.warning("Gemini TTS quota exhausted for model '%s', trying next model", model)
            last_exc = exc
            continue
        # Only the RPC is guarded above. An exception thrown into this generator
        # while it is suspended at a yield must not trigger a model fallback.
        for chunk in chunks:
            yield chunk
        return

    raise gcp_exceptions.ResourceExhausted(
        f"Quota exhausted on all Gemini TTS models: {', '.join(models_to_try)}"
    ) from last_exc