File size: 3,334 Bytes
986403e
e75bac4
e0ee8f7
e75bac4
986403e
e0ee8f7
986403e
 
e0ee8f7
 
e75bac4
 
 
e0ee8f7
 
986403e
e75bac4
 
e0ee8f7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e75bac4
e0ee8f7
 
 
 
 
 
 
 
 
c2e783d
e0ee8f7
 
 
 
 
 
 
 
 
 
986403e
 
 
e0ee8f7
 
 
 
 
 
c2e783d
 
 
e0ee8f7
c2e783d
 
 
 
 
 
 
986403e
e0ee8f7
986403e
e0ee8f7
 
 
 
 
 
 
 
 
 
 
986403e
e0ee8f7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import asyncio
import logging
import threading
from typing import AsyncIterator

from google.api_core import exceptions as gcp_exceptions
from google.cloud import texttospeech

from src.config import GEMINI_TTS_MODELS, GEMINI_TTS_VOICE, GEMINI_TTS_LANGUAGE
from src.google_auth import get_google_credentials

logger = logging.getLogger(__name__)

# Resolved once at import time; every TextToSpeechClient created in this
# module is constructed with these credentials.
_credentials = get_google_credentials()

GEMINI_SAMPLE_RATE = 24000  # Cloud TTS Gemini model outputs PCM Linear16 at 24kHz


class _ModelRotator:
    def __init__(self, models: list[str]) -> None:
        self._models = models
        self._index = 0
        self._lock = threading.Lock()

    def get_rotation(self) -> list[str]:
        """Returns models starting from current index, then advances for next call."""
        with self._lock:
            n = len(self._models)
            start = self._index % n
            self._index += 1
            return self._models[start:] + self._models[:start]


# Process-wide rotator: each synthesize_stream() call takes one rotation,
# which advances the starting model for the next call.
_rotator = _ModelRotator(GEMINI_TTS_MODELS)


def _synthesize_blocking(model_name: str, text: str) -> list[bytes]:
    """Run one blocking Cloud TTS streaming request and return its audio chunks.

    Meant to run in a worker thread (see ``synthesize_stream``).

    Args:
        model_name: Cloud TTS model to synthesize with.
        text: Text to synthesize.

    Returns:
        The non-empty ``audio_content`` chunks, in the order received.

    Raises:
        Whatever the client raises, e.g. ``gcp_exceptions.ResourceExhausted``
        when the model's quota is exhausted.
    """
    config_req = texttospeech.StreamingSynthesizeRequest(
        streaming_config=texttospeech.StreamingSynthesizeConfig(
            voice=texttospeech.VoiceSelectionParams(
                name=GEMINI_TTS_VOICE,
                language_code=GEMINI_TTS_LANGUAGE,
                model_name=model_name,
            )
        )
    )

    def _requests():
        # Config request first, then the single text-input request.
        yield config_req
        yield texttospeech.StreamingSynthesizeRequest(
            input=texttospeech.StreamingSynthesisInput(text=text)
        )

    # A fresh client is created per attempt and shares its transport with
    # nothing else. Closing it on exit releases the underlying channel
    # instead of leaking one per request.
    with texttospeech.TextToSpeechClient(credentials=_credentials) as client:
        # Drain the response stream fully before the client is closed.
        all_chunks = [r.audio_content for r in client.streaming_synthesize(_requests())]

    chunks = [c for c in all_chunks if c]
    total_bytes = sum(len(c) for c in chunks)
    logger.info(
        "Gemini TTS [%s]: received %d chunks (%d empty skipped), %d bytes total, aligned=%s, first32=%s",
        model_name,
        len(chunks),
        len(all_chunks) - len(chunks),
        total_bytes,
        total_bytes % 2 == 0,  # Linear16 samples are 2 bytes each
        chunks[0][:32].hex() if chunks else "N/A",
    )
    return chunks


async def synthesize_stream(text: str) -> AsyncIterator[bytes]:
    """Calls Cloud TTS Gemini model and yields raw PCM Linear16 chunks at 24kHz.

    Rotates across GEMINI_TTS_MODELS round-robin and falls back to the next model
    on ResourceExhausted (quota exceeded).

    All audio for an attempt is collected before anything is yielded. This
    means a quota failure never emits partial audio from one model before
    falling back to another.

    Args:
        text: Text to synthesize.

    Yields:
        Non-empty raw audio chunks from the first model that succeeds.

    Raises:
        RuntimeError: if no models are configured.
        gcp_exceptions.ResourceExhausted: if every model's quota is exhausted.
            It is chained to the last underlying quota error.
        Any other exception raised by the TTS client propagates unchanged.
    """
    models_to_try = _rotator.get_rotation()
    if not models_to_try:
        raise RuntimeError("No Gemini TTS models configured (GEMINI_TTS_MODELS is empty)")

    last_exc: gcp_exceptions.ResourceExhausted | None = None

    for model in models_to_try:
        logger.info("Gemini TTS [%s]: text_len=%d, text=%r", model, len(text), text)
        try:
            chunks = await asyncio.to_thread(_synthesize_blocking, model, text)
        except gcp_exceptions.ResourceExhausted as exc:
            logger.warning("Gemini TTS quota exhausted for model '%s', trying next model", model)
            last_exc = exc
            continue

        # Yield outside the try block so an exception thrown in by the
        # consumer is never mistaken for a quota error on this model.
        for chunk in chunks:
            yield chunk
        return

    raise gcp_exceptions.ResourceExhausted(
        f"Quota exhausted on all Gemini TTS models: {', '.join(models_to_try)}"
    ) from last_exc