Demo-Voice-Agent-Service / src /tts /gemini_client.py
ishaq101's picture
fixing error stt and tts, empty chunk audio
c2e783d
import asyncio
import logging
import threading
from typing import AsyncIterator
from google.api_core import exceptions as gcp_exceptions
from google.cloud import texttospeech
from src.config import GEMINI_TTS_MODELS, GEMINI_TTS_VOICE, GEMINI_TTS_LANGUAGE
from src.google_auth import get_google_credentials
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Credentials are resolved once at import time; the TTS client built in
# synthesize_stream() is constructed with this same object on every call.
_credentials = get_google_credentials()
GEMINI_SAMPLE_RATE = 24000 # Cloud TTS Gemini model outputs PCM Linear16 at 24kHz
class _ModelRotator:
def __init__(self, models: list[str]) -> None:
self._models = models
self._index = 0
self._lock = threading.Lock()
def get_rotation(self) -> list[str]:
"""Returns models starting from current index, then advances for next call."""
with self._lock:
n = len(self._models)
start = self._index % n
self._index += 1
return self._models[start:] + self._models[:start]
# Shared rotator instance: each synthesize_stream() call asks it for the model
# order to try, so consecutive calls start from different models.
_rotator = _ModelRotator(GEMINI_TTS_MODELS)
def _collect_chunks(text: str, model_name: str) -> list[bytes]:
    """Synthesize *text* with *model_name* and return the non-empty audio chunks.

    This is a blocking call (it drains the whole streaming response), so it is
    meant to be run in a worker thread rather than on the event loop.
    """
    config_req = texttospeech.StreamingSynthesizeRequest(
        streaming_config=texttospeech.StreamingSynthesizeConfig(
            voice=texttospeech.VoiceSelectionParams(
                name=GEMINI_TTS_VOICE,
                language_code=GEMINI_TTS_LANGUAGE,
                model_name=model_name,
            )
        )
    )

    def _requests():
        # The streaming API expects the config message first, then the input.
        yield config_req
        yield texttospeech.StreamingSynthesizeRequest(
            input=texttospeech.StreamingSynthesisInput(text=text)
        )

    # The client is created per call and is not shared, so closing its
    # transport on exit is safe and avoids leaking a channel per request.
    with texttospeech.TextToSpeechClient(credentials=_credentials) as client:
        all_chunks = [r.audio_content for r in client.streaming_synthesize(_requests())]

    # Drop empty payloads so consumers never receive zero-length audio chunks.
    chunks = [c for c in all_chunks if c]
    total_bytes = sum(len(c) for c in chunks)
    logger.info(
        "Gemini TTS [%s]: received %d chunks (%d empty skipped), %d bytes total, aligned=%s, first32=%s",
        model_name,
        len(chunks),
        len(all_chunks) - len(chunks),
        total_bytes,
        total_bytes % 2 == 0,  # an even byte count means whole 16-bit samples
        chunks[0][:32].hex() if chunks else "N/A",
    )
    return chunks


async def synthesize_stream(text: str) -> AsyncIterator[bytes]:
    """Call the Cloud TTS Gemini model and yield raw PCM Linear16 chunks at 24kHz.

    Models from GEMINI_TTS_MODELS are tried in the round-robin order supplied
    by the module-level rotator; when one fails with ResourceExhausted (quota
    exceeded) the next model is tried. Any other error propagates immediately.

    The full response for a model is collected in a worker thread before the
    first chunk is yielded, so audio is delivered only after synthesis ends.

    Args:
        text: The text to synthesize.

    Yields:
        Non-empty audio chunks in the order the service returned them.

    Raises:
        google.api_core.exceptions.ResourceExhausted: If every model was out
            of quota; the original service error is attached as ``__cause__``.
        RuntimeError: If the rotator supplied no models to try.
    """
    # (model, original exception) of the most recent quota failure, if any.
    last_failure: tuple[str, Exception] | None = None

    for model in _rotator.get_rotation():
        logger.info("Gemini TTS [%s]: text_len=%d, text=%r", model, len(text), text)
        # Keep the try body limited to the synthesis call: an exception thrown
        # into this generator at a yield must not be mistaken for a quota error
        # and trigger a second synthesis.
        try:
            chunks = await asyncio.to_thread(_collect_chunks, text, model)
        except gcp_exceptions.ResourceExhausted as exc:
            logger.warning("Gemini TTS quota exhausted for model '%s', trying next model", model)
            last_failure = (model, exc)
            continue

        for chunk in chunks:
            yield chunk
        return

    if last_failure is None:
        raise RuntimeError("No Gemini TTS models available to try")

    failed_model, cause = last_failure
    # Same exception type and message as before, now chained to the real
    # service error so its details are not lost.
    raise gcp_exceptions.ResourceExhausted(f"Quota exhausted on {failed_model}") from cause