tts_gallery / src /models /tts /chatterbox_model.py
Michael Hu
refactor: replace inline model definitions with ModelFactory and remove unused imports
ef4db28
import torch
import torchaudio as ta
import tempfile
import os
from chatterbox.mtl_tts import ChatterboxMultilingualTTS
from ..base import TTSModel
class ChatterboxTTSModel(TTSModel):
    """Chatterbox multilingual TTS model implementation.

    Wraps ``ChatterboxMultilingualTTS`` behind the project's ``TTSModel``
    interface. The underlying model is loaded lazily: either by an explicit
    call to :meth:`initialize` or on the first :meth:`generate_speech` call.
    """

    # Display name -> language code handed to the model as ``language_id``.
    # Single source of truth for both generate_speech() and
    # get_supported_languages(), so the two cannot drift apart.
    _LANGUAGE_CODES = {
        "English": "en",
        "Chinese": "zh",
    }

    # Code used when the requested language name is not in _LANGUAGE_CODES.
    _DEFAULT_LANGUAGE_CODE = "en"

    def __init__(self):
        self._model = None  # set by initialize()
        self._initialized = False  # True once a model has been loaded

    @property
    def name(self):
        """Identifier of the wrapped model."""
        return "ResembleAI/chatterbox"

    @property
    def description(self):
        """Short human-readable description of the model."""
        return "Industrial-grade TTS solution with multilingual support"

    def initialize(self):
        """Load the Chatterbox model if it has not been loaded yet.

        CUDA is used when ``torch.cuda.is_available()`` reports it, otherwise
        the CPU. If loading fails with the CUDA-deserialization
        ``RuntimeError``, the load is retried once on the CPU.

        Returns:
            bool: True when the model is ready, False when loading failed
            with a ``RuntimeError``. Other exception types are not caught
            and propagate to the caller.
        """
        if self._initialized:
            return True

        device = "cuda" if torch.cuda.is_available() else "cpu"
        try:
            self._model = ChatterboxMultilingualTTS.from_pretrained(device=device)
        except RuntimeError as e:
            if "Attempting to deserialize object on a CUDA device" not in str(e):
                print(f"Error initializing Chatterbox model: {e}")
                return False

            print("CUDA model detected but CUDA is not available. Loading model on CPU...")
            try:
                self._model = ChatterboxMultilingualTTS.from_pretrained(device="cpu")
            except RuntimeError as cpu_error:
                # Report a failed CPU retry through the boolean return value,
                # the same way every other load failure is reported.
                print(f"Error initializing Chatterbox model: {cpu_error}")
                return False

        self._initialized = True
        return True

    def generate_speech(self, text, language="English", audio_prompt=None, **kwargs):
        """Generate speech from text using Chatterbox multilingual TTS.

        Args:
            text (str): Text to convert to speech.
            language (str): Language name ('English' or 'Chinese'). Any other
                value falls back to the English language code.
            audio_prompt (str, optional): Path to a reference audio file for
                voice cloning. It is ignored when empty or when the path does
                not exist, in which case the default voice is used.
            **kwargs: Additional parameters forwarded to the model's
                ``generate`` call. They override the defaults for
                ``exaggeration``, ``temperature`` and ``cfg_weight``.

        Returns:
            str: Path to the generated ``.wav`` file. The file is created in
            the system temporary directory and is not deleted by this class.

        Raises:
            RuntimeError: If the model could not be initialized.
        """
        if not self._initialized and not self.initialize():
            raise RuntimeError("Failed to initialize Chatterbox model")

        language_id = self._LANGUAGE_CODES.get(language, self._DEFAULT_LANGUAGE_CODE)

        # Defaults first so caller-supplied kwargs take precedence.
        generate_kwargs = {
            "exaggeration": 0.5,
            "temperature": 0.8,
            "cfg_weight": 0.3,
            **kwargs,
        }

        # Only request voice cloning when the reference file really exists;
        # the explicit audio_prompt argument wins over any kwargs entry.
        if audio_prompt and os.path.exists(audio_prompt):
            generate_kwargs["audio_prompt_path"] = audio_prompt

        wav = self._model.generate(text, language_id=language_id, **generate_kwargs)

        # Reserve a unique path and close our handle *before* saving: writing
        # to a temp file by name while it is still open fails on Windows.
        fd, output_path = tempfile.mkstemp(suffix=".wav")
        os.close(fd)
        try:
            ta.save(output_path, wav, self._model.sr)
        except Exception:
            # Do not leave an empty/partial temp file behind on failure.
            try:
                os.remove(output_path)
            except OSError:
                pass
            raise
        return output_path

    def supports_voice_cloning(self):
        """Return True: a reference audio prompt can be supplied."""
        return True

    def supports_multilingual(self):
        """Return True: more than one language is available."""
        return True

    def get_supported_languages(self):
        """Return the accepted language names as a new list."""
        return list(self._LANGUAGE_CODES)