flozi00's picture
chatterbox only
0849031
"""
Main TTS Engine for Phone Announcements.
This engine provides a unified interface for generating phone announcements
using different TTS backends. It handles:
- Backend management (loading, switching, unloading)
- Audio generation with sensible defaults
- Post-processing (background music, fades, normalization)
- Caching for efficiency
"""
import hashlib
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Type, Union

import numpy as np
from loguru import logger

from .audio_processor import AudioProcessingConfig, AudioProcessor
from .backends.base import BackendConfig, TTSBackend, TTSResult
from .backends.chatterbox_backend import ChatterboxBackend
from .cache import AudioCache, CacheConfig
@dataclass
class EngineConfig:
    """Settings bundle consumed by ``TTSEngine``.

    Every field has a default, so ``EngineConfig()`` is a usable
    configuration on its own.
    """

    # --- Backend selection -------------------------------------------------
    # Registry name of the backend the engine starts with.
    default_backend: str = "chatterbox"
    # Compute device handed to the backend: "auto", "cuda", "mps" or "cpu".
    device: str = "auto"

    # --- Generation --------------------------------------------------------
    # Language code used when a call does not specify one
    # (German, as the announcements target German phone systems).
    default_language: str = "de"

    # --- Audio post-processing ---------------------------------------------
    # When True, ``default_music`` is mixed in for calls that name no track.
    add_background_music: bool = False
    # Name/path of the fallback background track, if any.
    default_music: Optional[str] = None
    # Gain applied to the background track, in dB.
    music_volume_db: float = -20.0
    # Fade durations in milliseconds.
    fade_in_ms: int = 500
    fade_out_ms: int = 500

    # --- Caching -----------------------------------------------------------
    # Master switch for the audio cache.
    enable_cache: bool = True
    # Directory for the local cache; None defers to the cache's own default.
    local_cache_dir: Optional[str] = None
    # Hugging Face repo id backing the cache, if any.
    hf_cache_repo: Optional[str] = None
class TTSEngine:
    """
    Main TTS Engine for generating phone announcements.

    Backend instances are created lazily on first use and kept per engine;
    the name -> class registry is shared by all engines.

    Usage:
        # Simple usage with defaults
        engine = TTSEngine()
        audio = engine.generate("Welcome to our service.")

        # With voice cloning
        audio = engine.generate(
            "Welcome to our service.",
            voice_audio="path/to/reference.wav"
        )

        # Switch backend
        engine.set_backend("chatterbox")
        audio = engine.generate("Welcome to our service.", language="en")
    """

    # Registry of available backends (class-level, shared across instances)
    _backend_registry: dict[str, Type[TTSBackend]] = {
        "chatterbox": ChatterboxBackend,
    }

    def __init__(self, config: Optional[EngineConfig] = None):
        """
        Create an engine.

        Args:
            config: Engine settings; a default ``EngineConfig`` is used when
                omitted.
        """
        self.config = config or EngineConfig()

        # Backend instances are created on demand by ``_get_backend``.
        self._backends: dict[str, TTSBackend] = {}
        self._current_backend_name: str = self.config.default_backend

        # Audio processor
        self._processor = AudioProcessor(
            AudioProcessingConfig(
                music_volume_db=self.config.music_volume_db,
                fade_in_ms=self.config.fade_in_ms,
                fade_out_ms=self.config.fade_out_ms,
            )
        )

        # Cache
        self._cache = AudioCache(
            CacheConfig(
                enabled=self.config.enable_cache,
                local_cache_dir=self.config.local_cache_dir,
                hf_repo_id=self.config.hf_cache_repo,
            )
        )

    @classmethod
    def register_backend(cls, name: str, backend_class: Type[TTSBackend]) -> None:
        """Register a new backend class under ``name`` (replaces any existing entry)."""
        cls._backend_registry[name] = backend_class
        logger.info(f"Registered backend: {name}")

    @classmethod
    def available_backends(cls) -> list[str]:
        """List available backend names."""
        return list(cls._backend_registry)

    def _check_registered(self, name: str) -> None:
        """Raise ``ValueError`` if ``name`` is not a registered backend."""
        if name not in self._backend_registry:
            available = ", ".join(self._backend_registry.keys())
            raise ValueError(f"Unknown backend '{name}'. Available: {available}")

    def _get_backend(self, name: Optional[str] = None) -> TTSBackend:
        """
        Get or create a backend instance.

        Args:
            name: Backend name; the current backend is used when omitted.

        Raises:
            ValueError: If the name is not registered.
        """
        name = name or self._current_backend_name
        self._check_registered(name)
        if name not in self._backends:
            backend_config = BackendConfig(device=self.config.device)
            self._backends[name] = self._backend_registry[name](backend_config)
        return self._backends[name]

    @property
    def current_backend(self) -> TTSBackend:
        """Get the current active backend (instantiating it if needed)."""
        return self._get_backend()

    def set_backend(self, name: str) -> None:
        """
        Switch to a different backend.

        Raises:
            ValueError: If the name is not registered.
        """
        self._check_registered(name)
        self._current_backend_name = name
        logger.info(f"Switched to backend: {name}")

    def load_backend(self, name: Optional[str] = None) -> None:
        """Pre-load a backend's model (no-op if already loaded)."""
        backend = self._get_backend(name)
        if not backend.is_loaded:
            backend.load()

    def unload_backend(self, name: Optional[str] = None) -> None:
        """
        Unload a backend's model to free memory.

        A backend that was never instantiated has nothing to unload, so it is
        not created just to be inspected.

        Raises:
            ValueError: If the name is not registered.
        """
        name = name or self._current_backend_name
        self._check_registered(name)
        backend = self._backends.get(name)
        if backend is not None and backend.is_loaded:
            backend.unload()

    def get_supported_languages(self, backend: Optional[str] = None) -> dict[str, str]:
        """Get supported languages for a backend (current backend by default)."""
        return self._get_backend(backend).supported_languages

    @staticmethod
    def _cache_voice_id(
        voice_audio: Optional[str],
        language: str,
        music_path: Optional[str],
        backend_kwargs: dict,
    ) -> str:
        """
        Build the voice identifier used as part of the cache key.

        The identifier keeps a readable label (reference file stem, "custom"
        for a reference that is not a local file, "default" when no reference
        is given) and appends a short digest of everything else that changes
        the rendered audio: the exact reference string, the language, the
        background track and the backend-specific parameters. Without the
        digest, different URLs, languages or tracks would share one entry.
        """
        if not voice_audio:
            label = "default"
        elif os.path.exists(voice_audio):
            label = Path(voice_audio).stem
        else:
            label = "custom"

        # kwargs keys are unique strings, so sorting never compares values.
        parts = (
            voice_audio or "",
            language,
            music_path or "",
            repr(sorted(backend_kwargs.items())),
        )
        fingerprint = "\x1f".join(str(part) for part in parts)
        digest = hashlib.sha256(
            fingerprint.encode("utf-8", errors="replace")
        ).hexdigest()[:12]
        return f"{label}-{digest}"

    @staticmethod
    def _synthesize(
        backend: TTSBackend,
        text: str,
        language: str,
        voice_audio: Optional[str],
        split_sentences: bool,
        max_chars_per_chunk: int,
        backend_kwargs: dict,
    ) -> TTSResult:
        """Run the backend, using sentence splitting for texts above the chunk limit."""
        if split_sentences and len(text) > max_chars_per_chunk:
            logger.info(f"Text is {len(text)} chars, splitting into sentences")
            return backend.generate_long(
                text=text,
                language=language,
                voice_audio_path=voice_audio,
                max_chars_per_chunk=max_chars_per_chunk,
                **backend_kwargs,
            )
        return backend.generate(
            text=text, language=language, voice_audio_path=voice_audio, **backend_kwargs
        )

    def generate(
        self,
        text: str,
        language: Optional[str] = None,
        voice_audio: Optional[str] = None,
        background_music: Optional[str] = None,
        output_path: Optional[str] = None,
        use_cache: bool = True,
        split_sentences: bool = True,
        max_chars_per_chunk: int = 250,
        **kwargs,
    ) -> Union[bytes, str, tuple[int, np.ndarray]]:
        """
        Generate a phone announcement.

        Args:
            text: Text to synthesize
            language: Language code (default: the configured default language)
            voice_audio: Path/URL to reference audio for voice cloning
            background_music: Name/path of background music file
            output_path: Optional path to save output file
            use_cache: Whether to use caching (default: True)
            split_sentences: Auto-split long text into sentences (default: True)
            max_chars_per_chunk: Max chars per chunk when splitting (default: 250)
            **kwargs: Additional backend-specific parameters

        Returns:
            - If output_path: path to saved file
            - If no output_path and no background music (explicit or
              configured default): tuple(sample_rate, audio_array) for Gradio
            - Otherwise: MP3 bytes
        """
        language = language or self.config.default_language
        backend = self.current_backend

        # Resolve the background track first: it is part of the cache key and
        # decides whether this call is post-processed at all.
        fallback_music = (
            self.config.default_music if self.config.add_background_music else None
        )
        music_path = (background_music or fallback_music) or None
        needs_processing = bool(music_path or output_path)

        cache_active = use_cache and self._cache.config.enabled
        voice_id = self._cache_voice_id(voice_audio, language, music_path, kwargs)

        # Only post-processed output is ever cached, so a raw (tuple) request
        # must not be answered with cached bytes.
        if cache_active and needs_processing:
            cached = self._cache.get(text, voice_id, backend.name)
            if cached:
                logger.info("Using cached audio")
                if output_path:
                    Path(output_path).write_bytes(cached)
                    return output_path
                return cached

        # Generate audio (use sentence splitting for long texts)
        logger.info(f"Generating TTS: backend={backend.name}, lang={language}")
        result = self._synthesize(
            backend,
            text,
            language,
            voice_audio,
            split_sentences,
            max_chars_per_chunk,
            kwargs,
        )

        if not needs_processing:
            # Return raw audio for Gradio (sample_rate, audio_array)
            return (result.sample_rate, result.audio)

        processed = self._processor.process(
            audio=result.audio,
            sample_rate=result.sample_rate,
            output_path=output_path,
            background_music_path=music_path,
        )

        # Only in-memory bytes can be stored; a saved-file result is skipped.
        if cache_active and isinstance(processed, bytes):
            duration = len(result.audio) / result.sample_rate
            self._cache.set(text, voice_id, backend.name, processed, duration)
        return processed

    def generate_raw(
        self,
        text: str,
        language: Optional[str] = None,
        voice_audio: Optional[str] = None,
        split_sentences: bool = True,
        max_chars_per_chunk: int = 250,
        **kwargs,
    ) -> TTSResult:
        """
        Generate raw audio without post-processing or caching.

        Args:
            text: Text to synthesize
            language: Language code (default from config)
            voice_audio: Path/URL to reference audio for voice cloning
            split_sentences: Auto-split long text into sentences (default: True)
            max_chars_per_chunk: Max chars per chunk when splitting (default: 250)
            **kwargs: Additional backend-specific parameters

        Returns:
            TTSResult with audio array and sample rate
        """
        language = language or self.config.default_language
        return self._synthesize(
            self.current_backend,
            text,
            language,
            voice_audio,
            split_sentences,
            max_chars_per_chunk,
            kwargs,
        )

    def list_background_music(self) -> list[str]:
        """List available background music files."""
        return self._processor.list_available_music()

    def clear_cache(self) -> int:
        """Clear the local audio cache. Returns number of files deleted."""
        return self._cache.clear_local()