import os
import random
import re
import tempfile

import numpy as np
import torch

# Maximum characters per text chunk, chosen to stay under the Chatterbox
# model's per-call token limit.
MAX_CHUNK_CHARS = 250
DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


class VoiceCloningEngine:
    """
    A standalone engine to handle Chatterbox TTS operations including
    model management, text chunking, and audio generation.

    Heavy dependencies (``chatterbox``, ``torchaudio``) are imported lazily
    inside the methods that need them, so the pure-text utilities
    (``chunk_text``, ``set_seed``) work even when the audio stack is not
    installed.
    """

    def __init__(self, device: str = DEFAULT_DEVICE):
        self.device = device
        self.model = None  # loaded lazily by load_model()
        self.sr = 24000  # default sample rate; replaced by the model's actual rate on load

    def load_model(self):
        """Lazily load the Chatterbox model (idempotent).

        Returns:
            The loaded ChatterboxTTS model instance.

        Raises:
            RuntimeError: if model initialization fails for any reason.
        """
        if self.model is None:
            # Deferred import: keeps the module importable without chatterbox.
            from chatterbox.tts import ChatterboxTTS

            print(f"Initializing Chatterbox TTS on {self.device}...")
            try:
                self.model = ChatterboxTTS.from_pretrained(self.device)
                self.sr = self.model.sr
            except Exception as e:
                print(f"Failed to load model: {e}")
                # Chain the original exception for a complete traceback.
                raise RuntimeError(f"Model initialization failed: {str(e)}") from e
        return self.model

    def set_seed(self, seed: int) -> int:
        """Seed all RNGs (torch, CUDA, random, numpy) for reproducibility.

        A seed of 0 means "pick a random seed"; the actually-used seed is
        returned so the caller can report or reuse it.
        """
        if seed == 0:
            seed = random.randint(1, 1000000)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            # torch documents these as silent no-ops without CUDA; the guard
            # just makes the intent explicit.
            torch.cuda.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)
        random.seed(seed)
        np.random.seed(seed)
        return seed

    def chunk_text(self, text: str) -> list:
        """
        Split long scripts into chunks at sentence boundaries.
        Optimized for the Chatterbox model's token limit.

        Returns a list of stripped chunk strings, each at most
        MAX_CHUNK_CHARS characters (except for a single unsplittable word
        longer than the limit, which is kept intact).
        """
        if not text:
            return []
        # Split by sentence boundaries while keeping the punctuation.
        sentences = re.split(r'(?<=[.!?])\s+', text.strip())
        chunks = []
        current_chunk = ""
        for sentence in sentences:
            if len(current_chunk) + len(sentence) <= MAX_CHUNK_CHARS:
                current_chunk += (sentence + " ")
            else:
                if current_chunk:
                    chunks.append(current_chunk.strip())
                # If a single sentence is too long, split it by commas or spaces.
                if len(sentence) > MAX_CHUNK_CHARS:
                    sub_parts = re.split(r'(?<=,)\s+|\s+', sentence)
                    temp = ""
                    for part in sub_parts:
                        if len(temp) + len(part) <= MAX_CHUNK_CHARS:
                            temp += (part + " ")
                        else:
                            if temp:
                                chunks.append(temp.strip())
                            temp = part + " "
                    # Carry the tail of the oversized sentence forward so the
                    # next sentence can pack into the same chunk.
                    current_chunk = temp
                else:
                    current_chunk = sentence + " "
        if current_chunk:
            chunks.append(current_chunk.strip())
        return chunks

    def generate(self, text, ref_audio, exaggeration, cfg_weight,
                 temperature, seed, progress_callback=None):
        """
        Generate cloned audio by processing chunks and concatenating them.

        Args:
            text: Full script to synthesize; chunked via chunk_text().
            ref_audio: Path to the reference audio for voice cloning.
            exaggeration, cfg_weight, temperature: Generation parameters
                forwarded to the Chatterbox model.
            seed: RNG seed; 0 selects a random seed.
            progress_callback: Optional callable(fraction, desc=...) invoked
                before each chunk is processed.

        Returns:
            (output_path, actual_seed): path of the saved WAV file and the
            seed actually used.

        Raises:
            ValueError: if the script is empty or no reference audio is given.
            RuntimeError: if the model fails to load.
        """
        # Validate inputs BEFORE the expensive model load so bad calls
        # fail fast.
        chunks = self.chunk_text(text)
        if not chunks:
            raise ValueError("The script is empty or invalid.")
        if ref_audio is None:
            raise ValueError("A reference audio file is required for voice cloning.")

        # Deferred import: only needed for saving audio output.
        import torchaudio

        self.load_model()
        actual_seed = self.set_seed(int(seed))

        all_wavs = []
        total = len(chunks)
        for i, chunk in enumerate(chunks):
            if progress_callback:
                progress_callback((i / total), desc=f"Processing chunk {i+1}/{total}")
            # Generate the audio chunk.
            wav = self.model.generate(
                chunk,
                audio_prompt_path=ref_audio,
                exaggeration=exaggeration,
                temperature=temperature,
                cfg_weight=cfg_weight
            )
            # Ensure the output is a 2D tensor [1, T].
            if wav.dim() == 1:
                wav = wav.unsqueeze(0)
            all_wavs.append(wav.cpu())

        # Concatenate all segments along the time axis.
        final_wav = torch.cat(all_wavs, dim=-1)

        # Reserve a temp file name, then save AFTER the handle is closed —
        # writing to a still-open NamedTemporaryFile path fails on Windows.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            output_path = tmp.name
        torchaudio.save(output_path, final_wav, self.sr)

        return output_path, actual_seed