Spaces:
Sleeping
Sleeping
Michael Hu
refactor: replace inline model definitions with ModelFactory and remove unused imports
ef4db28
| import torch | |
| import torchaudio as ta | |
| import tempfile | |
| import os | |
| from chatterbox.mtl_tts import ChatterboxMultilingualTTS | |
| from ..base import TTSModel | |
class ChatterboxTTSModel(TTSModel):
    """Chatterbox multilingual TTS model implementation.

    Wraps ResembleAI's ChatterboxMultilingualTTS: the underlying model is
    loaded lazily on first use, and generate_speech() writes its output to
    a temporary WAV file and returns the path.
    """

    def __init__(self):
        # Loaded lazily in initialize() so construction stays cheap.
        self._model = None
        self._initialized = False

    def name(self):
        """Return the model identifier."""
        return "ResembleAI/chatterbox"

    def description(self):
        """Return a short human-readable description of the model."""
        return "Industrial-grade TTS solution with multilingual support"

    def initialize(self):
        """Initialize the Chatterbox model.

        Loads the pretrained model onto CUDA when available, otherwise CPU.
        If a CUDA-serialized checkpoint fails to deserialize without CUDA,
        retries explicitly on CPU.

        Returns:
            bool: True on success, False on any failure (never raises).
        """
        if self._initialized:
            return True
        try:
            self._model = ChatterboxMultilingualTTS.from_pretrained(
                device="cuda" if torch.cuda.is_available() else "cpu"
            )
            self._initialized = True
            return True
        except RuntimeError as e:
            # A checkpoint saved on a CUDA device can fail to deserialize
            # when CUDA is unavailable; retry explicitly on CPU.
            if "Attempting to deserialize object on a CUDA device" in str(e):
                print("CUDA model detected but CUDA is not available. Loading model on CPU...")
                try:
                    self._model = ChatterboxMultilingualTTS.from_pretrained(device="cpu")
                except Exception as cpu_error:
                    # BUGFIX: the original let a failed CPU retry propagate,
                    # breaking the documented bool return contract.
                    print(f"Error initializing Chatterbox model: {cpu_error}")
                    return False
                self._initialized = True
                return True
            print(f"Error initializing Chatterbox model: {e}")
            return False
        except Exception as e:
            # BUGFIX: the original caught only RuntimeError, so any other
            # failure (download/network error, missing files, OSError, ...)
            # escaped instead of returning False as documented.
            print(f"Error initializing Chatterbox model: {e}")
            return False

    def generate_speech(self, text, language="English", audio_prompt=None, **kwargs):
        """
        Generate speech from text using Chatterbox multilingual TTS

        Args:
            text (str): Text to convert to speech
            language (str): Language name ('English' or 'Chinese')
            audio_prompt (str, optional): Path to reference audio file for voice cloning
            **kwargs: Additional parameters for generation

        Returns:
            str: Path to the generated audio file

        Raises:
            RuntimeError: If the model could not be initialized.
        """
        if not self._initialized:
            if not self.initialize():
                raise RuntimeError("Failed to initialize Chatterbox model")

        # Map human-readable language names to Chatterbox language codes;
        # unrecognized names fall back to English.
        language_map = {
            "English": "en",
            "Chinese": "zh"
        }
        language_id = language_map.get(language, "en")

        # Default generation parameters; caller-supplied kwargs override.
        generate_kwargs = {
            "exaggeration": 0.5,
            "temperature": 0.8,
            "cfg_weight": 0.3,
        }
        generate_kwargs.update(kwargs)

        # Generate speech, optionally cloning the voice from a reference clip.
        if audio_prompt and os.path.exists(audio_prompt):
            wav = self._model.generate(text, language_id=language_id, audio_prompt_path=audio_prompt, **generate_kwargs)
        else:
            wav = self._model.generate(text, language_id=language_id, **generate_kwargs)

        # BUGFIX: close the temp-file handle before torchaudio writes to the
        # path — on Windows an open NamedTemporaryFile cannot be reopened by
        # name. delete=False is intentional: the caller consumes the path.
        tmp_file = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
        tmp_file.close()
        ta.save(tmp_file.name, wav, self._model.sr)
        return tmp_file.name

    def supports_voice_cloning(self):
        """Chatterbox supports voice cloning via a reference audio prompt."""
        return True

    def supports_multilingual(self):
        """Chatterbox supports more than one language."""
        return True

    def get_supported_languages(self):
        """Return the language names accepted by generate_speech()."""
        return ["English", "Chinese"]