Spaces:
Sleeping
Sleeping
| """ | |
| Main TTS Engine for Phone Announcements. | |
| This engine provides a unified interface for generating phone announcements | |
| using different TTS backends. It handles: | |
| - Backend management (loading, switching, unloading) | |
| - Audio generation with sensible defaults | |
| - Post-processing (background music, fades, normalization) | |
| - Caching for efficiency | |
| """ | |
| import os | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| from typing import Optional, Type, Union | |
| import numpy as np | |
| from loguru import logger | |
| from .audio_processor import AudioProcessingConfig, AudioProcessor | |
| from .backends.base import BackendConfig, TTSBackend, TTSResult | |
| from .backends.chatterbox_backend import ChatterboxBackend | |
| from .cache import AudioCache, CacheConfig | |
@dataclass
class EngineConfig:
    """Configuration for the TTS Engine.

    All fields have defaults, so ``EngineConfig()`` yields a usable
    configuration; individual settings can be overridden by keyword, e.g.
    ``EngineConfig(device="cpu", enable_cache=False)``.
    """

    # Backend settings
    default_backend: str = "chatterbox"  # name of the backend selected at start-up
    device: str = "auto"  # "auto", "cuda", "mps", "cpu"

    # Default generation settings
    default_language: str = "de"  # German for phone announcements

    # Audio processing defaults
    add_background_music: bool = False  # mix in `default_music` when no track is given
    default_music: Optional[str] = None  # name/path of the default music track
    music_volume_db: float = -20.0
    fade_in_ms: int = 500
    fade_out_ms: int = 500

    # Caching
    enable_cache: bool = True
    local_cache_dir: Optional[str] = None
    hf_cache_repo: Optional[str] = None
class TTSEngine:
    """
    Main TTS Engine for generating phone announcements.

    Backend instances are created on first use and kept per engine; the
    registry of backend *classes* is shared by all engines.

    Usage:
        # Simple usage with defaults
        engine = TTSEngine()
        audio = engine.generate("Welcome to our service.")

        # With voice cloning
        audio = engine.generate(
            "Welcome to our service.",
            voice_audio="path/to/reference.wav"
        )

        # Switch backend
        engine.set_backend("chatterbox")
        audio = engine.generate("Welcome to our service.", language="en")
    """

    # Registry of available backends (class-level, shared by every engine)
    _backend_registry: dict[str, Type[TTSBackend]] = {
        "chatterbox": ChatterboxBackend,
    }

    def __init__(self, config: Optional[EngineConfig] = None):
        """Create an engine.

        Args:
            config: Engine configuration; a default ``EngineConfig()`` is
                used when omitted.
        """
        self.config = config or EngineConfig()

        # Backend instances are created lazily by _get_backend().
        self._backends: dict[str, TTSBackend] = {}
        self._current_backend_name: str = self.config.default_backend

        # Audio processor
        self._processor = AudioProcessor(
            AudioProcessingConfig(
                music_volume_db=self.config.music_volume_db,
                fade_in_ms=self.config.fade_in_ms,
                fade_out_ms=self.config.fade_out_ms,
            )
        )

        # Cache
        self._cache = AudioCache(
            CacheConfig(
                enabled=self.config.enable_cache,
                local_cache_dir=self.config.local_cache_dir,
                hf_repo_id=self.config.hf_cache_repo,
            )
        )

    @classmethod
    def register_backend(cls, name: str, backend_class: Type[TTSBackend]) -> None:
        """Register a new backend class under ``name`` (replaces any existing entry)."""
        cls._backend_registry[name] = backend_class
        logger.info(f"Registered backend: {name}")

    @classmethod
    def available_backends(cls) -> list[str]:
        """List available backend names."""
        return list(cls._backend_registry)

    @classmethod
    def _require_registered(cls, name: str) -> None:
        """Raise ``ValueError`` if ``name`` is not a registered backend."""
        if name not in cls._backend_registry:
            available = ", ".join(cls._backend_registry)
            raise ValueError(f"Unknown backend '{name}'. Available: {available}")

    @staticmethod
    def _voice_id_for(voice_audio: Optional[str]) -> str:
        """Derive the voice component of the cache key.

        Returns the file stem when ``voice_audio`` is an existing local path,
        ``"custom"`` for any other non-empty reference, and ``"default"``
        when no reference audio is given.
        """
        if not voice_audio:
            return "default"
        if os.path.exists(voice_audio):
            return Path(voice_audio).stem
        return "custom"

    def _get_backend(self, name: Optional[str] = None) -> TTSBackend:
        """Get or create a backend instance.

        Args:
            name: Backend name; defaults to the currently selected backend.

        Raises:
            ValueError: If ``name`` is not registered.
        """
        name = name or self._current_backend_name
        self._require_registered(name)

        if name not in self._backends:
            backend_config = BackendConfig(device=self.config.device)
            self._backends[name] = self._backend_registry[name](backend_config)
        return self._backends[name]

    @property
    def current_backend(self) -> TTSBackend:
        """Get the current active backend."""
        return self._get_backend()

    def set_backend(self, name: str) -> None:
        """Switch to a different backend.

        Raises:
            ValueError: If ``name`` is not registered.
        """
        self._require_registered(name)
        self._current_backend_name = name
        logger.info(f"Switched to backend: {name}")

    def load_backend(self, name: Optional[str] = None) -> None:
        """Pre-load a backend's model (no-op if it is already loaded)."""
        backend = self._get_backend(name)
        if not backend.is_loaded:
            backend.load()

    def unload_backend(self, name: Optional[str] = None) -> None:
        """Unload a backend's model to free memory.

        Raises:
            ValueError: If ``name`` is not registered.
        """
        name = name or self._current_backend_name
        self._require_registered(name)

        # Only an already-created backend can hold a loaded model; do not
        # instantiate one just to discover there is nothing to unload.
        backend = self._backends.get(name)
        if backend is not None and backend.is_loaded:
            backend.unload()

    def get_supported_languages(self, backend: Optional[str] = None) -> dict[str, str]:
        """Get supported languages for a backend (default: the current one)."""
        return self._get_backend(backend).supported_languages

    def generate(
        self,
        text: str,
        language: Optional[str] = None,
        voice_audio: Optional[str] = None,
        background_music: Optional[str] = None,
        output_path: Optional[str] = None,
        use_cache: bool = True,
        split_sentences: bool = True,
        max_chars_per_chunk: int = 250,
        **kwargs,
    ) -> Union[bytes, str, tuple[int, np.ndarray]]:
        """
        Generate a phone announcement.

        Args:
            text: Text to synthesize
            language: Language code (default: ``config.default_language``)
            voice_audio: Path/URL to reference audio for voice cloning
            background_music: Name/path of background music file
            output_path: Optional path to save output file
            use_cache: Whether to use caching (default: True)
            split_sentences: Auto-split long text into sentences (default: True)
            max_chars_per_chunk: Max chars per chunk when splitting (default: 250)
            **kwargs: Additional backend-specific parameters

        Returns:
            - On a cache hit: ``output_path`` if given (the cached bytes are
              written there), otherwise the cached bytes
            - If music or output_path is requested: whatever the audio
              processor returns (saved file path or encoded bytes)
            - Otherwise: tuple(sample_rate, audio_array) for Gradio
        """
        language = language or self.config.default_language
        backend = self.current_backend
        voice_id = self._voice_id_for(voice_audio)
        cache_active = use_cache and self._cache.config.enabled

        # NOTE(review): the cache key is (text, voice_id, backend name) only;
        # language and background music are not part of it, so a hit may
        # return audio produced with different settings. Confirm whether
        # AudioCache should key on those as well.
        if cache_active:
            cached = self._cache.get(text, voice_id, backend.name)
            if cached:
                logger.info("Using cached audio")
                if output_path:
                    Path(output_path).write_bytes(cached)
                    return output_path
                return cached

        # Generate audio (generate_raw handles sentence splitting for long texts)
        logger.info(f"Generating TTS: backend={backend.name}, lang={language}")
        result = self.generate_raw(
            text,
            language=language,
            voice_audio=voice_audio,
            split_sentences=split_sentences,
            max_chars_per_chunk=max_chars_per_chunk,
            **kwargs,
        )

        # An explicit track always wins; the configured default is used only
        # when background music is switched on in the config.
        default_track = (
            self.config.default_music if self.config.add_background_music else None
        )
        music_path = background_music or default_track or None

        if not (music_path or output_path):
            # Return raw audio for Gradio (sample_rate, audio_array)
            return (result.sample_rate, result.audio)

        # Post-process (background music and/or writing the output file)
        processed = self._processor.process(
            audio=result.audio,
            sample_rate=result.sample_rate,
            output_path=output_path,
            background_music_path=music_path,
        )

        # Only in-memory (bytes) results are cached; respect the same
        # enabled flag that guards the cache lookup above.
        if cache_active and isinstance(processed, bytes):
            duration = len(result.audio) / result.sample_rate
            self._cache.set(text, voice_id, backend.name, processed, duration)
        return processed

    def generate_raw(
        self,
        text: str,
        language: Optional[str] = None,
        voice_audio: Optional[str] = None,
        split_sentences: bool = True,
        max_chars_per_chunk: int = 250,
        **kwargs,
    ) -> TTSResult:
        """
        Generate raw audio without post-processing.

        Args:
            text: Text to synthesize
            language: Language code (default from config)
            voice_audio: Path/URL to reference audio for voice cloning
            split_sentences: Auto-split long text into sentences (default: True)
            max_chars_per_chunk: Max chars per chunk when splitting (default: 250)
            **kwargs: Additional backend-specific parameters

        Returns:
            TTSResult with audio array and sample rate
        """
        language = language or self.config.default_language
        backend = self.current_backend

        if split_sentences and len(text) > max_chars_per_chunk:
            logger.info(f"Text is {len(text)} chars, splitting into sentences")
            return backend.generate_long(
                text=text,
                language=language,
                voice_audio_path=voice_audio,
                max_chars_per_chunk=max_chars_per_chunk,
                **kwargs,
            )
        return backend.generate(
            text=text, language=language, voice_audio_path=voice_audio, **kwargs
        )

    def list_background_music(self) -> list[str]:
        """List available background music files."""
        return self._processor.list_available_music()

    def clear_cache(self) -> int:
        """Clear the local audio cache. Returns number of files deleted."""
        return self._cache.clear_local()