Spaces:
Running
Running
| import os | |
| import random | |
| import re | |
| import tempfile | |
| import torch | |
| import torchaudio | |
| import numpy as np | |
| from chatterbox.tts import ChatterboxTTS | |
# Constants
MAX_CHUNK_CHARS = 250  # per-chunk character budget, sized for the model's token limit
DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


class VoiceCloningEngine:
    """
    A standalone engine for Chatterbox TTS voice cloning.

    Handles lazy model loading, reproducible seeding, sentence-aware text
    chunking, and chunk-by-chunk audio generation plus concatenation.
    """

    def __init__(self, device=DEFAULT_DEVICE):
        self.device = device
        self.model = None  # loaded lazily by load_model()
        self.sr = 24000    # default sample rate; replaced by the model's real rate on load

    def load_model(self):
        """Lazily load the Chatterbox model; idempotent after the first call.

        Returns:
            The loaded ChatterboxTTS instance.

        Raises:
            RuntimeError: if model initialization fails (original exception chained).
        """
        if self.model is None:
            print(f"Initializing Chatterbox TTS on {self.device}...")
            try:
                self.model = ChatterboxTTS.from_pretrained(self.device)
                self.sr = self.model.sr
            except Exception as e:
                print(f"Failed to load model: {e}")
                # chain the cause so the underlying failure isn't lost
                raise RuntimeError(f"Model initialization failed: {str(e)}") from e
        return self.model

    def set_seed(self, seed: int) -> int:
        """Seed torch, numpy and random for reproducibility.

        A seed of 0 means "pick one at random"; the seed actually used is
        returned so the caller can report or reuse it.
        """
        if seed == 0:
            seed = random.randint(1, 1000000)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            # seeds every CUDA device (covers the current device too)
            torch.cuda.manual_seed_all(seed)
        random.seed(seed)
        np.random.seed(seed)
        return seed

    def chunk_text(self, text):
        """Split a script into chunks of at most MAX_CHUNK_CHARS characters.

        Splits at sentence boundaries (after ., !, ?) first; a single
        sentence longer than the budget is further split on whitespace.
        Empty or whitespace-only input yields an empty list, and empty
        chunks are never returned.
        """
        if not text or not text.strip():
            return []
        sentences = re.split(r'(?<=[.!?])\s+', text.strip())
        chunks = []
        current = ""
        for sentence in sentences:
            if not sentence:
                continue
            if len(sentence) > MAX_CHUNK_CHARS:
                # Flush what we have, then pack the oversized sentence
                # word-by-word. A single word longer than the budget still
                # becomes its own chunk.
                if current.strip():
                    chunks.append(current.strip())
                current = ""
                for word in sentence.split():
                    # +1 accounts for the joining space appended below
                    if len(current) + len(word) + 1 <= MAX_CHUNK_CHARS:
                        current += word + " "
                    else:
                        if current.strip():
                            chunks.append(current.strip())
                        current = word + " "
            elif len(current) + len(sentence) + 1 <= MAX_CHUNK_CHARS:
                current += sentence + " "
            else:
                if current.strip():
                    chunks.append(current.strip())
                current = sentence + " "
        if current.strip():
            chunks.append(current.strip())
        return chunks

    def generate(self, text, ref_audio, exaggeration, cfg_weight, temperature, seed, progress_callback=None):
        """Generate cloned audio by synthesizing chunks and concatenating them.

        Args:
            text: Script to synthesize; chunked via chunk_text().
            ref_audio: Path to the reference audio used as the voice prompt.
            exaggeration, cfg_weight, temperature: Passed through to the model.
            seed: RNG seed; 0 draws a random one.
            progress_callback: Optional callable(progress_fraction, desc=...).

        Returns:
            (output_path, actual_seed): path to the saved WAV and the seed used.

        Raises:
            ValueError: if the script is empty or no reference audio is given.
            RuntimeError: if the model fails to load.
        """
        self.load_model()
        actual_seed = self.set_seed(int(seed))
        chunks = self.chunk_text(text)
        if not chunks:
            raise ValueError("The script is empty or invalid.")
        if ref_audio is None:
            raise ValueError("A reference audio file is required for voice cloning.")
        all_wavs = []
        total = len(chunks)
        for i, chunk in enumerate(chunks):
            if progress_callback:
                progress_callback((i / total), desc=f"Processing chunk {i+1}/{total}")
            # Generate the audio chunk
            wav = self.model.generate(
                chunk,
                audio_prompt_path=ref_audio,
                exaggeration=exaggeration,
                temperature=temperature,
                cfg_weight=cfg_weight
            )
            # Normalize to a 2D tensor [1, T] so time-axis concat is well-defined
            if wav.dim() == 1:
                wav = wav.unsqueeze(0)
            all_wavs.append(wav.cpu())
        # Concatenate all segments along the time axis
        final_wav = torch.cat(all_wavs, dim=-1)
        # mkstemp + close ensures the handle is shut before torchaudio opens
        # the path for writing (NamedTemporaryFile keeps it open, which
        # breaks the second open on Windows).
        fd, output_path = tempfile.mkstemp(suffix=".wav")
        os.close(fd)
        torchaudio.save(output_path, final_wav, self.sr)
        return output_path, actual_seed