# Chatterbox_tts / engine.py
# Author: codewithjarair — "Update engine.py" (commit 6d30bec, verified)
import os
import random
import re
import tempfile
import torch
import torchaudio
import numpy as np
from chatterbox.tts import ChatterboxTTS
# Constants
MAX_CHUNK_CHARS = 250  # max characters per synthesis chunk (model token-limit headroom)
DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


class VoiceCloningEngine:
    """
    A standalone engine to handle Chatterbox TTS operations including
    model management, text chunking, and audio generation.
    """

    def __init__(self, device=DEFAULT_DEVICE):
        self.device = device  # torch device string ("cuda" or "cpu")
        self.model = None     # lazily created ChatterboxTTS instance
        self.sr = 24000       # default sample rate; replaced by the model's own rate on load

    def load_model(self):
        """Lazily load the Chatterbox model; keep it cached for later calls.

        Returns:
            The loaded ChatterboxTTS model instance.

        Raises:
            RuntimeError: if model initialization fails for any reason.
        """
        if self.model is None:
            print(f"Initializing Chatterbox TTS on {self.device}...")
            try:
                self.model = ChatterboxTTS.from_pretrained(self.device)
                self.sr = self.model.sr
            except Exception as e:
                print(f"Failed to load model: {e}")
                # Chain the original exception so the root cause stays visible.
                raise RuntimeError(f"Model initialization failed: {str(e)}") from e
        return self.model

    def set_seed(self, seed: int) -> int:
        """Seed all RNG sources (torch, CUDA, random, numpy) for reproducibility.

        Args:
            seed: Seed value. 0 means "pick a random seed" in [1, 1000000].

        Returns:
            The seed actually applied (useful when 0 was passed in).
        """
        if seed == 0:
            seed = random.randint(1, 1000000)
        torch.manual_seed(seed)
        # Only touch CUDA RNG state when CUDA is actually available;
        # manual_seed_all seeds every visible device (subsumes manual_seed).
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)
        random.seed(seed)
        np.random.seed(seed)
        return seed

    def chunk_text(self, text):
        """Split a long script into chunks of at most MAX_CHUNK_CHARS characters.

        Splits at sentence boundaries (., !, ?). A single sentence longer than
        the limit is further split at comma/whitespace boundaries.

        Args:
            text: Input script. None, empty, or whitespace-only input yields [].

        Returns:
            List of non-empty, stripped text chunks, in original order.
        """
        # Fix: whitespace-only input previously produced [""] which would be
        # passed to the model as an empty synthesis chunk.
        if not text or not text.strip():
            return []
        # Split by sentence boundaries while keeping the punctuation.
        sentences = re.split(r'(?<=[.!?])\s+', text.strip())
        chunks = []
        current_chunk = ""
        for sentence in sentences:
            if not sentence:
                continue  # skip empty fragments produced by the regex split
            if len(current_chunk) + len(sentence) <= MAX_CHUNK_CHARS:
                current_chunk += (sentence + " ")
            else:
                if current_chunk:
                    chunks.append(current_chunk.strip())
                # If a single sentence is too long, split it by commas or spaces.
                if len(sentence) > MAX_CHUNK_CHARS:
                    sub_parts = re.split(r'(?<=,)\s+|\s+', sentence)
                    temp = ""
                    for part in sub_parts:
                        if len(temp) + len(part) <= MAX_CHUNK_CHARS:
                            temp += (part + " ")
                        else:
                            if temp:
                                chunks.append(temp.strip())
                            temp = part + " "
                    current_chunk = temp  # remainder carries into the next iteration
                else:
                    current_chunk = sentence + " "
        if current_chunk.strip():
            chunks.append(current_chunk.strip())
        return chunks

    def generate(self, text, ref_audio, exaggeration, cfg_weight, temperature, seed, progress_callback=None):
        """Generate cloned audio by synthesizing chunks and concatenating them.

        Args:
            text: Script to synthesize (chunked via chunk_text).
            ref_audio: Path to the reference audio file for voice cloning.
            exaggeration: Chatterbox expressiveness control.
            cfg_weight: Classifier-free-guidance weight.
            temperature: Sampling temperature.
            seed: RNG seed; 0 picks a random one.
            progress_callback: Optional callable(fraction, desc=...) — Gradio-style
                progress reporter invoked once per chunk.

        Returns:
            Tuple of (path to the saved WAV file, seed actually used).

        Raises:
            ValueError: if the script is empty or no reference audio is given.
            RuntimeError: if the model fails to load.
        """
        self.load_model()
        actual_seed = self.set_seed(int(seed))
        chunks = self.chunk_text(text)
        if not chunks:
            raise ValueError("The script is empty or invalid.")
        if ref_audio is None:
            raise ValueError("A reference audio file is required for voice cloning.")
        all_wavs = []
        total = len(chunks)
        for i, chunk in enumerate(chunks):
            if progress_callback:
                progress_callback((i / total), desc=f"Processing chunk {i+1}/{total}")
            # Generate the audio chunk.
            wav = self.model.generate(
                chunk,
                audio_prompt_path=ref_audio,
                exaggeration=exaggeration,
                temperature=temperature,
                cfg_weight=cfg_weight
            )
            # Ensure the output is a 2D tensor [1, T] before concatenation.
            if wav.dim() == 1:
                wav = wav.unsqueeze(0)
            # Move to CPU so chunks from a CUDA model concatenate/save cleanly.
            all_wavs.append(wav.cpu())
        # Concatenate all segments along the time axis.
        final_wav = torch.cat(all_wavs, dim=-1)
        # Save to a temporary file; delete=False so the path stays valid for the
        # caller (caller is responsible for eventual cleanup).
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            output_path = tmp.name
        torchaudio.save(output_path, final_wav, self.sr)
        return output_path, actual_seed