Spaces:
Sleeping
Sleeping
| """Chatterbox TTS provider implementation.""" | |
| import logging | |
| import numpy as np | |
| import soundfile as sf | |
| import io | |
| from typing import Iterator, Optional, TYPE_CHECKING | |
| if TYPE_CHECKING: | |
| from ...domain.models.speech_synthesis_request import SpeechSynthesisRequest | |
| from ..base.tts_provider_base import TTSProviderBase | |
| from ...domain.exceptions import SpeechSynthesisException | |
| logger = logging.getLogger(__name__) | |
# Optional-dependency guard for the Chatterbox engine.
# CHATTERBOX_AVAILABLE ends up True only when torch, torchaudio and
# chatterbox.tts all import cleanly; every failure path leaves it False.
try:
    import torch
    import torchaudio as ta
    from chatterbox.tts import ChatterboxTTS
except ImportError as e:
    # One of the optional packages is not installed.
    CHATTERBOX_AVAILABLE = False
    logger.warning(f"Chatterbox TTS engine is not available: {e}")
except Exception as e:
    # The packages exist but raised something other than ImportError while importing.
    CHATTERBOX_AVAILABLE = False
    logger.error(f"Chatterbox import failed with unexpected error: {str(e)}")
else:
    CHATTERBOX_AVAILABLE = True
    logger.info("Chatterbox TTS engine is available")
class ChatterboxTTSProvider(TTSProviderBase):
    """Chatterbox TTS provider implementation.

    Wraps a lazily loaded ``ChatterboxTTS`` model and exposes one-shot and
    chunked ("streaming") synthesis. All audio is returned as WAV-encoded
    bytes together with the model's sample rate.
    """

    def __init__(self, lang_code: str = 'en'):
        """Initialize the Chatterbox TTS provider.

        Args:
            lang_code: Language code stored on the instance as ``lang_code``.
        """
        super().__init__(
            provider_name="Chatterbox",
            supported_languages=['en', 'zh']  # Chatterbox supports English and Chinese
        )
        self.lang_code = lang_code
        self.model = None
        # `torch` is only bound when the guarded module-level import succeeded,
        # so check the availability flag first; otherwise constructing the
        # provider on a machine without torch would raise NameError.
        self.device = "cuda" if CHATTERBOX_AVAILABLE and torch.cuda.is_available() else "cpu"

    def _ensure_model(self):
        """Load the model on first use.

        Returns:
            bool: True when a model instance is held after the call.
        """
        if self.model is None and CHATTERBOX_AVAILABLE:
            try:
                logger.info("Loading Chatterbox model on device: %s", self.device)
                self.model = ChatterboxTTS.from_pretrained(device=self.device)
                logger.info("Chatterbox model successfully loaded")
            except Exception as e:
                # Best effort: a failed load is reported as "not available"
                # rather than propagated to the caller.
                logger.error("Failed to initialize Chatterbox model: %s", e)
                self.model = None
        return self.model is not None

    def is_available(self) -> bool:
        """Check if Chatterbox TTS is available.

        Note that this triggers the (lazy) model load when the engine was
        imported successfully and no model is held yet.
        """
        return CHATTERBOX_AVAILABLE and self._ensure_model()

    def get_available_voices(self) -> list[str]:
        """Get available voices for Chatterbox."""
        # 'default' is the base model voice; 'custom' selects voice cloning
        # from an audio prompt supplied with the request.
        return ['default', 'custom']

    @staticmethod
    def _to_numpy(wav) -> np.ndarray:
        """Convert model output to a NumPy array laid out as (frames[, channels]).

        Tensors are detached before conversion so that outputs which still
        track gradients can be converted, and moved to the CPU.
        """
        if hasattr(wav, 'detach'):
            wav = wav.detach()
        if hasattr(wav, 'cpu'):
            wav = wav.cpu()
        if hasattr(wav, 'numpy'):
            wav = wav.numpy()
        audio = np.asarray(wav)
        # NOTE(review): Chatterbox's documented usage saves the result with
        # torchaudio, which takes (channels, samples) - so the output is
        # presumably shaped (1, N). soundfile expects (frames, channels) and
        # the streaming code slices along axis 0, so collapse singleton
        # dimensions and put the sample axis first. Confirm against the
        # installed chatterbox version.
        if audio.ndim > 1:
            audio = np.squeeze(audio)
        if audio.ndim == 0:
            audio = audio.reshape(1)
        if audio.ndim == 2 and audio.shape[0] < audio.shape[1]:
            audio = audio.T
        return audio

    @staticmethod
    def _normalize_peak(audio_array: np.ndarray) -> np.ndarray:
        """Scale the signal down so its absolute peak is at most 1.0."""
        if audio_array.size == 0:
            # np.max raises on empty input; there is nothing to scale.
            return audio_array
        peak = np.max(np.abs(audio_array))
        if peak > 1.0:
            audio_array = audio_array / peak
        return audio_array

    def _synthesize(self, text: str, audio_prompt_path: Optional[str] = None) -> tuple[np.ndarray, int]:
        """Run the model once and return ``(samples, sample_rate)``.

        Args:
            text: Text to synthesize.
            audio_prompt_path: Optional audio file used for voice cloning;
                when None the model is called with the text only.
        """
        if audio_prompt_path is not None:
            wav = self.model.generate(text, audio_prompt_path=audio_prompt_path)
        else:
            wav = self.model.generate(text)
        return self._to_numpy(wav), self.model.sr

    @staticmethod
    def _prompt_path_for(request: 'SpeechSynthesisRequest') -> Optional[str]:
        """Return the request's audio prompt path when the 'custom' voice is selected."""
        settings = request.voice_settings
        if settings.voice_id == 'custom':
            return getattr(settings, 'audio_prompt_path', None)
        return None

    def _generate_audio(self, request: 'SpeechSynthesisRequest') -> tuple[bytes, int]:
        """Generate audio using Chatterbox TTS.

        Args:
            request: Synthesis request; reads ``text_content.text`` and
                ``voice_settings`` (``voice_id`` and, for the 'custom'
                voice, ``audio_prompt_path`` when present).

        Returns:
            tuple: (WAV-encoded audio bytes, sample_rate)

        Raises:
            SpeechSynthesisException: If the engine is not available.
        """
        if not self.is_available():
            raise SpeechSynthesisException("Chatterbox TTS engine is not available")
        try:
            wav, sample_rate = self._synthesize(
                request.text_content.text,
                self._prompt_path_for(request),
            )
            return self._numpy_to_bytes(wav, sample_rate), sample_rate
        except Exception as e:
            # _handle_provider_error comes from the base class (not visible
            # here); presumably it re-raises as a domain exception.
            self._handle_provider_error(e, "audio generation")

    def _generate_audio_stream(self, request: 'SpeechSynthesisRequest') -> Iterator[tuple[bytes, int, bool]]:
        """Generate audio stream using Chatterbox TTS.

        The full utterance is synthesized first and then split into
        one-second chunks; each chunk is encoded as its own WAV payload.

        Yields:
            tuple: (WAV-encoded chunk bytes, sample_rate, is_final)

        Raises:
            SpeechSynthesisException: If the engine is not available.
        """
        if not self.is_available():
            raise SpeechSynthesisException("Chatterbox TTS engine is not available")
        try:
            wav, sample_rate = self._synthesize(
                request.text_content.text,
                self._prompt_path_for(request),
            )
            # Normalize the whole signal once so every chunk shares the same
            # gain; normalizing chunk by chunk would change loudness between
            # consecutive chunks whenever the signal exceeds full scale.
            wav = self._normalize_peak(wav.astype(np.float32, copy=False))
            chunk_size = int(sample_rate * 1.0)  # 1 second chunks
            total_samples = len(wav)
            for start_idx in range(0, total_samples, chunk_size):
                end_idx = min(start_idx + chunk_size, total_samples)
                audio_bytes = self._numpy_to_bytes(wav[start_idx:end_idx], sample_rate)
                is_final = end_idx >= total_samples
                yield audio_bytes, sample_rate, is_final
        except Exception as e:
            self._handle_provider_error(e, "streaming audio generation")

    def _numpy_to_bytes(self, audio_array: np.ndarray, sample_rate: int) -> bytes:
        """Convert numpy audio array to WAV-encoded bytes.

        The array is cast to float32 and scaled down when its absolute peak
        exceeds 1.0 before being written to an in-memory WAV container.

        Raises:
            SpeechSynthesisException: If conversion or encoding fails.
        """
        try:
            if audio_array.dtype != np.float32:
                audio_array = audio_array.astype(np.float32)
            audio_array = self._normalize_peak(audio_array)
            buffer = io.BytesIO()
            sf.write(buffer, audio_array, sample_rate, format='WAV')
            return buffer.getvalue()
        except Exception as e:
            raise SpeechSynthesisException(f"Failed to convert audio to bytes: {str(e)}") from e

    def generate_with_voice_prompt(self, text: str, audio_prompt_path: str) -> tuple[bytes, int]:
        """
        Generate audio with a custom voice prompt.

        Args:
            text: Text to synthesize
            audio_prompt_path: Path to audio file for voice cloning

        Returns:
            tuple: (audio_bytes, sample_rate)

        Raises:
            SpeechSynthesisException: If the engine is not available.
        """
        if not self.is_available():
            raise SpeechSynthesisException("Chatterbox TTS engine is not available")
        try:
            wav, sample_rate = self._synthesize(text, audio_prompt_path)
            return self._numpy_to_bytes(wav, sample_rate), sample_rate
        except Exception as e:
            self._handle_provider_error(e, "voice prompt audio generation")