Spaces:
Sleeping
Sleeping
File size: 10,580 Bytes
a86cdfa d0b9ec6 a86cdfa 0849031 a86cdfa d0b9ec6 a86cdfa d0b9ec6 a86cdfa 0849031 a86cdfa 87b184a a86cdfa 87b184a a86cdfa d63c0fa 0849031 d63c0fa a86cdfa d63c0fa a86cdfa 87b184a a86cdfa 87b184a a86cdfa 87b184a a86cdfa 87b184a a86cdfa 87b184a a86cdfa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 |
"""
Main TTS Engine for Phone Announcements.
This engine provides a unified interface for generating phone announcements
using different TTS backends. It handles:
- Backend management (loading, switching, unloading)
- Audio generation with sensible defaults
- Post-processing (background music, fades, normalization)
- Caching for efficiency
"""
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Type, Union
import numpy as np
from loguru import logger
from .audio_processor import AudioProcessingConfig, AudioProcessor
from .backends.base import BackendConfig, TTSBackend, TTSResult
from .backends.chatterbox_backend import ChatterboxBackend
from .cache import AudioCache, CacheConfig
@dataclass
class EngineConfig:
    """
    Settings bundle for the TTS engine.

    Attributes:
        default_backend: Name of the backend selected when an engine starts.
        device: Device hint handed to backends ("auto", "cuda", "mps", "cpu").
        default_language: Language code used when a call does not specify one.
        add_background_music: Mix in ``default_music`` when no music is
            requested explicitly.
        default_music: Name/path of the fallback background music, if any.
        music_volume_db: Background music level in dB.
        fade_in_ms: Fade-in length in milliseconds.
        fade_out_ms: Fade-out length in milliseconds.
        enable_cache: Turn the audio cache on or off.
        local_cache_dir: Directory for the local cache (None = cache default).
        hf_cache_repo: Hugging Face repo id used for the cache, if any.
    """

    # --- Backend selection -------------------------------------------------
    default_backend: str = "chatterbox"
    device: str = "auto"

    # --- Generation defaults -----------------------------------------------
    # German is the default because the engine targets phone announcements.
    default_language: str = "de"

    # --- Post-processing defaults ------------------------------------------
    add_background_music: bool = False
    default_music: Optional[str] = None
    music_volume_db: float = -20.0
    fade_in_ms: int = 500
    fade_out_ms: int = 500

    # --- Cache -------------------------------------------------------------
    enable_cache: bool = True
    local_cache_dir: Optional[str] = None
    hf_cache_repo: Optional[str] = None
class TTSEngine:
    """
    Main TTS Engine for generating phone announcements.

    Usage:
        # Simple usage with defaults
        engine = TTSEngine()
        audio = engine.generate("Welcome to our service.")

        # With voice cloning
        audio = engine.generate(
            "Welcome to our service.",
            voice_audio="path/to/reference.wav"
        )

        # Switch backend
        engine.set_backend("chatterbox")
        audio = engine.generate("Welcome to our service.", language="en")
    """

    # Registry of available backends. This is a class attribute, so backends
    # added through register_backend() are visible to every engine instance.
    _backend_registry: dict[str, Type[TTSBackend]] = {
        "chatterbox": ChatterboxBackend,
    }

    def __init__(self, config: Optional[EngineConfig] = None):
        """
        Create an engine.

        Args:
            config: Engine settings; a default EngineConfig is used when omitted.
        """
        self.config = config or EngineConfig()

        # Backend instances are created lazily by _get_backend().
        self._backends: dict[str, TTSBackend] = {}
        self._current_backend_name: str = self.config.default_backend

        # Audio processor
        self._processor = AudioProcessor(
            AudioProcessingConfig(
                music_volume_db=self.config.music_volume_db,
                fade_in_ms=self.config.fade_in_ms,
                fade_out_ms=self.config.fade_out_ms,
            )
        )

        # Cache
        self._cache = AudioCache(
            CacheConfig(
                enabled=self.config.enable_cache,
                local_cache_dir=self.config.local_cache_dir,
                hf_repo_id=self.config.hf_cache_repo,
            )
        )

    @classmethod
    def register_backend(cls, name: str, backend_class: Type[TTSBackend]) -> None:
        """Register a new backend class under ``name`` (replaces an existing entry)."""
        cls._backend_registry[name] = backend_class
        logger.info(f"Registered backend: {name}")

    @classmethod
    def available_backends(cls) -> list[str]:
        """List available backend names."""
        return list(cls._backend_registry)

    def _check_backend_name(self, name: str) -> None:
        """Raise ValueError when ``name`` is not a registered backend."""
        if name not in self._backend_registry:
            available = ", ".join(self._backend_registry.keys())
            raise ValueError(f"Unknown backend '{name}'. Available: {available}")

    def _get_backend(self, name: Optional[str] = None) -> TTSBackend:
        """
        Get or create a backend instance.

        Args:
            name: Backend name; the current backend is used when omitted.

        Raises:
            ValueError: If the backend name is not registered.
        """
        name = name or self._current_backend_name
        self._check_backend_name(name)
        if name not in self._backends:
            backend_config = BackendConfig(device=self.config.device)
            self._backends[name] = self._backend_registry[name](backend_config)
        return self._backends[name]

    @property
    def current_backend(self) -> TTSBackend:
        """Get the current active backend."""
        return self._get_backend()

    def set_backend(self, name: str) -> None:
        """
        Switch to a different backend.

        Raises:
            ValueError: If the backend name is not registered.
        """
        self._check_backend_name(name)
        self._current_backend_name = name
        logger.info(f"Switched to backend: {name}")

    def load_backend(self, name: Optional[str] = None) -> None:
        """Pre-load a backend's model (no-op when it is already loaded)."""
        backend = self._get_backend(name)
        if not backend.is_loaded:
            backend.load()

    def unload_backend(self, name: Optional[str] = None) -> None:
        """
        Unload a backend's model to free memory.

        Raises:
            ValueError: If the backend name is not registered.
        """
        name = name or self._current_backend_name
        self._check_backend_name(name)
        # Look the instance up directly: a backend that was never created has
        # nothing to unload, so there is no point instantiating it here.
        backend = self._backends.get(name)
        if backend is not None and backend.is_loaded:
            backend.unload()

    def get_supported_languages(self, backend: Optional[str] = None) -> dict[str, str]:
        """Get supported languages for a backend."""
        return self._get_backend(backend).supported_languages

    @staticmethod
    def _cache_voice_key(voice_id: str, language: str, music_path: Optional[str]) -> str:
        """Build the cache "voice" key so it also distinguishes language and music."""
        music_tag = Path(music_path).stem if music_path else "nomusic"
        return f"{voice_id}__{language}__{music_tag}"

    def _synthesize(
        self,
        backend: TTSBackend,
        text: str,
        language: str,
        voice_audio: Optional[str],
        split_sentences: bool,
        max_chars_per_chunk: int,
        extra: dict,
    ) -> TTSResult:
        """Run the backend, using sentence splitting for texts above the chunk limit."""
        # Backend-specific options arrive as a dict (not **kwargs) so they can
        # never collide with this helper's own parameter names.
        if split_sentences and len(text) > max_chars_per_chunk:
            logger.info(f"Text is {len(text)} chars, splitting into sentences")
            return backend.generate_long(
                text=text,
                language=language,
                voice_audio_path=voice_audio,
                max_chars_per_chunk=max_chars_per_chunk,
                **extra,
            )
        return backend.generate(
            text=text, language=language, voice_audio_path=voice_audio, **extra
        )

    def generate(
        self,
        text: str,
        language: Optional[str] = None,
        voice_audio: Optional[str] = None,
        background_music: Optional[str] = None,
        output_path: Optional[str] = None,
        use_cache: bool = True,
        split_sentences: bool = True,
        max_chars_per_chunk: int = 250,
        **kwargs,
    ) -> Union[bytes, str, tuple[int, np.ndarray]]:
        """
        Generate a phone announcement.

        Args:
            text: Text to synthesize
            language: Language code (default: config.default_language)
            voice_audio: Path/URL to reference audio for voice cloning
            background_music: Name/path of background music file
            output_path: Optional path to save output file
            use_cache: Whether to use caching (default: True)
            split_sentences: Auto-split long text into sentences (default: True)
            max_chars_per_chunk: Max chars per chunk when splitting (default: 250)
            **kwargs: Additional backend-specific parameters

        Returns:
            - If output_path: path to saved file
            - If no output_path and no background music (explicit or configured
              default): tuple(sample_rate, audio_array) for Gradio
            - Otherwise: the processor's output (MP3 bytes)

        Raises:
            ValueError: If the current backend name is not registered.
        """
        language = language or self.config.default_language
        backend = self.current_backend

        # Decide up front whether this request goes through post-processing;
        # both the cache lookup and the return shape depend on it.
        use_music = bool(
            background_music
            or (self.config.add_background_music and self.config.default_music)
        )
        music_path = (
            (background_music or self.config.default_music) if use_music else None
        )
        needs_processing = use_music or bool(output_path)

        # Voice ID for caching:
        # - no reference audio: "default"
        # - reference audio that exists locally: its file stem
        # - anything else (e.g. a URL): "custom"
        # NOTE(review): distinct references can still share a voice ID (same
        # stem in different folders, or any two non-local references) — confirm
        # whether the cache needs a stronger voice identity.
        if not voice_audio:
            voice_id = "default"
        elif os.path.exists(voice_audio):
            voice_id = Path(voice_audio).stem
        else:
            voice_id = "custom"

        # Language and music change the produced audio, so they must be part
        # of the cache key; otherwise the same text requested with other music
        # or another language would return the first cached mix. Entries
        # stored under the previous key scheme will simply be regenerated.
        cache_key = self._cache_voice_key(voice_id, language, music_path)
        cache_active = use_cache and self._cache.config.enabled

        # Only processed requests may be served from the cache: cached entries
        # are processed bytes, while a raw request must return the
        # (sample_rate, audio_array) tuple.
        if cache_active and needs_processing:
            cached = self._cache.get(text, cache_key, backend.name)
            if cached:
                logger.info("Using cached audio")
                if output_path:
                    Path(output_path).write_bytes(cached)
                    return output_path
                return cached

        # Generate audio (use sentence splitting for long texts)
        logger.info(f"Generating TTS: backend={backend.name}, lang={language}")
        result = self._synthesize(
            backend,
            text,
            language,
            voice_audio,
            split_sentences,
            max_chars_per_chunk,
            kwargs,
        )

        if not needs_processing:
            # Return raw audio for Gradio (sample_rate, audio_array)
            return (result.sample_rate, result.audio)

        processed = self._processor.process(
            audio=result.audio,
            sample_rate=result.sample_rate,
            output_path=output_path,
            background_music_path=music_path,
        )

        # Only in-memory results (bytes) are cached; a saved file is returned
        # as its path and is not stored.
        if cache_active and isinstance(processed, bytes):
            duration = len(result.audio) / result.sample_rate
            self._cache.set(text, cache_key, backend.name, processed, duration)
        return processed

    def generate_raw(
        self,
        text: str,
        language: Optional[str] = None,
        voice_audio: Optional[str] = None,
        split_sentences: bool = True,
        max_chars_per_chunk: int = 250,
        **kwargs,
    ) -> TTSResult:
        """
        Generate raw audio without post-processing or caching.

        Args:
            text: Text to synthesize
            language: Language code (default from config)
            voice_audio: Path/URL to reference audio for voice cloning
            split_sentences: Auto-split long text into sentences (default: True)
            max_chars_per_chunk: Max chars per chunk when splitting (default: 250)
            **kwargs: Additional backend-specific parameters

        Returns:
            TTSResult with audio array and sample rate

        Raises:
            ValueError: If the current backend name is not registered.
        """
        language = language or self.config.default_language
        return self._synthesize(
            self.current_backend,
            text,
            language,
            voice_audio,
            split_sentences,
            max_chars_per_chunk,
            kwargs,
        )

    def list_background_music(self) -> list[str]:
        """List available background music files."""
        return self._processor.list_available_music()

    def clear_cache(self) -> int:
        """Clear the local audio cache. Returns number of files deleted."""
        return self._cache.clear_local()
|