import asyncio
import logging
import threading
from typing import AsyncIterator

from google.api_core import exceptions as gcp_exceptions
from google.cloud import texttospeech

from src.config import GEMINI_TTS_MODELS, GEMINI_TTS_VOICE, GEMINI_TTS_LANGUAGE
from src.google_auth import get_google_credentials

logger = logging.getLogger(__name__)

_credentials = get_google_credentials()

GEMINI_SAMPLE_RATE = 24000  # Cloud TTS Gemini model outputs PCM Linear16 at 24kHz


class _ModelRotator:
    """Hands out the model list in round-robin order, guarded by a lock."""

    def __init__(self, models: list[str]) -> None:
        # Copy so that later mutation of the caller's list cannot change the
        # rotation underneath us.
        self._models = list(models)
        self._index = 0
        self._lock = threading.Lock()

    def get_rotation(self) -> list[str]:
        """Returns models starting from current index, then advances for next call.

        Returns an empty list when no models are configured (instead of
        failing on a modulo by zero).
        """
        with self._lock:
            n = len(self._models)
            if n == 0:
                return []
            start = self._index % n
            # Keep the index within [0, n) rather than letting it grow forever.
            self._index = (start + 1) % n
            return self._models[start:] + self._models[:start]


_rotator = _ModelRotator(GEMINI_TTS_MODELS)


def _collect_audio(text: str, model_name: str) -> list[bytes]:
    """Runs one blocking streaming-synthesis call and returns its audio chunks.

    Sends a config request followed by a single text-input request, gathers
    every response's ``audio_content`` and drops the empty ones. Intended to
    be run in a worker thread because the client call blocks.

    Raises whatever the Cloud TTS client raises, including
    ``google.api_core.exceptions.ResourceExhausted``.
    """
    config_req = texttospeech.StreamingSynthesizeRequest(
        streaming_config=texttospeech.StreamingSynthesizeConfig(
            voice=texttospeech.VoiceSelectionParams(
                name=GEMINI_TTS_VOICE,
                language_code=GEMINI_TTS_LANGUAGE,
                model_name=model_name,
            )
        )
    )

    def _requests():
        yield config_req
        yield texttospeech.StreamingSynthesizeRequest(
            input=texttospeech.StreamingSynthesisInput(text=text)
        )

    # The client is created per call, so close its transport when done instead
    # of leaking one channel per synthesis attempt. The response stream must be
    # fully consumed inside the ``with`` block.
    with texttospeech.TextToSpeechClient(credentials=_credentials) as client:
        all_chunks = [
            response.audio_content
            for response in client.streaming_synthesize(_requests())
        ]

    chunks = [c for c in all_chunks if c]
    total_bytes = sum(len(c) for c in chunks)
    logger.info(
        "Gemini TTS [%s]: received %d chunks (%d empty skipped), %d bytes total, aligned=%s, first32=%s",
        model_name,
        len(chunks),
        len(all_chunks) - len(chunks),
        total_bytes,
        total_bytes % 2 == 0,  # 16-bit samples: an odd byte count means a torn sample
        chunks[0][:32].hex() if chunks else "N/A",
    )
    return chunks


async def synthesize_stream(text: str) -> AsyncIterator[bytes]:
    """Calls Cloud TTS Gemini model and yields raw PCM Linear16 chunks at 24kHz.

    Rotates across GEMINI_TTS_MODELS round-robin and falls back to the next
    model on ResourceExhausted (quota exceeded).

    Note that the audio for one model attempt is collected completely in a
    worker thread before the first chunk is yielded; this is what makes the
    fallback safe, since no partial audio is ever emitted for a failed model.

    Raises:
        RuntimeError: if no models are configured.
        google.api_core.exceptions.ResourceExhausted: if every model's quota
            is exhausted; the server's original error is attached as
            ``__cause__``.
        Exception: any other error from the TTS client propagates unchanged.
    """
    models_to_try = _rotator.get_rotation()
    if not models_to_try:
        raise RuntimeError(
            "No Gemini TTS models configured (GEMINI_TTS_MODELS is empty)"
        )

    last_exc: Exception | None = None

    for model in models_to_try:
        logger.info("Gemini TTS [%s]: text_len=%d, text=%r", model, len(text), text)

        # Only the synthesis call is guarded: an exception thrown into the
        # generator while it is suspended at ``yield`` must not be mistaken
        # for a quota error and trigger a second synthesis.
        try:
            chunks = await asyncio.to_thread(_collect_audio, text, model)
        except gcp_exceptions.ResourceExhausted as exc:
            logger.warning(
                "Gemini TTS quota exhausted for model '%s', trying next model", model
            )
            last_exc = gcp_exceptions.ResourceExhausted(f"Quota exhausted on {model}")
            # Keep the server-provided details reachable for debugging.
            last_exc.__cause__ = exc
            continue

        for chunk in chunks:
            yield chunk
        return

    # Reaching here means every model raised ResourceExhausted, so last_exc is set.
    raise last_exc  # type: ignore[misc]