# tts_runner/engines/chatterbox.py
from pathlib import Path
from typing import List, Optional
import spacy
import torchaudio as ta
import torch
from ..base import BaseTTS
class ChatterboxTTSProcessor(BaseTTS):
"""Text-to-Speech processor using ChatterboxTTS."""
def __init__(self, stream_audio=False):
super().__init__("Chatterbox", stream_audio=stream_audio)
print("Initializing Chatterbox...")
        # Lazy import: the heavy chatterbox dependency is only pulled in when
        # this engine is actually instantiated.
        from chatterbox.tts import ChatterboxTTS
        print("Loading model...")
        self.model = ChatterboxTTS.from_pretrained(device=self.device)
        # Load the spaCy sentence tokenizer, downloading the small English
        # model on first use if it is not already installed.
        self.nlp = None
        try:
            self.nlp = spacy.load("en_core_web_sm")
        except OSError:
            from spacy.cli import download
            download("en_core_web_sm")
            self.nlp = spacy.load("en_core_web_sm")
print("Model loaded successfully")
def tokenize_sentences(self, text):
"""Split text into sentences using spaCy.
Args:
text: Input text to tokenize
Returns:
List of sentence strings
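
        Example (illustrative; `tts` is an instance of this class and exact
        sentence boundaries depend on the loaded spaCy model):
            >>> tts.tokenize_sentences("Hello world. How are you?")
            ['Hello world.', 'How are you?']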
"""
doc = self.nlp(text)
return [sent.text.strip() for sent in doc.sents if sent.text.strip()]
def norm_and_token_count(self, text):
"""Get normalized text and token count.
Args:
text: Input text to normalize and count tokens
Returns:
Tuple of (normalized_text, token_count)
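
        Example (illustrative; the token count depends on the model's
        tokenizer, so no exact value is shown):
            >>> normalized, count = tts.norm_and_token_count("hello there")
            >>> isinstance(count, int)
            True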
"""
from chatterbox.tts import punc_norm
with torch.inference_mode():
normalized = punc_norm(text)
tokens = self.model.tokenizer.text_to_tokens(normalized)
token_count = tokens.shape[1]
            # Drop references to the token tensor so any GPU memory it holds
            # can be reclaimed right away.
            if hasattr(tokens, "cpu"):
                tokens = tokens.cpu()
            del tokens
return normalized, token_count
def split_sentences(self, text, max_tokens=200):
"""Split text into chunks based on token count.
Args:
text: Input text to split
max_tokens: Maximum tokens per chunk
Returns:
List of text chunks
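
        Example (illustrative; `long_text` is any non-empty string, and chunk
        boundaries depend on the model's tokenizer; a single word longer
        than max_tokens cannot be split further):
            >>> chunks = tts.split_sentences(long_text, max_tokens=200)
            >>> len(chunks) >= 1
            True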
"""
sentences = self.tokenize_sentences(text)
chunks = []
current = ""
for sentence in sentences:
# Check if sentence alone exceeds max tokens
_, sentence_tokens = self.norm_and_token_count(sentence)
if sentence_tokens > max_tokens:
# If current chunk has content, save it first
if current:
chunks.append(current.strip())
current = ""
# Split long sentence by words if it's too long
words = sentence.split()
temp_chunk = ""
for word in words:
test_chunk = (temp_chunk + " " + word).strip() if temp_chunk else word
_, test_tokens = self.norm_and_token_count(test_chunk)
if test_tokens <= max_tokens:
temp_chunk = test_chunk
else:
if temp_chunk:
chunks.append(temp_chunk.strip())
temp_chunk = word
if temp_chunk:
chunks.append(temp_chunk.strip())
current = ""
continue
# Try adding sentence to current chunk
candidate = (current + " " + sentence).strip() if current else sentence.strip()
_, token_count = self.norm_and_token_count(candidate)
if token_count <= max_tokens:
current = candidate
else:
# Current chunk is full, save it and start new one
if current:
chunks.append(current.strip())
current = sentence.strip()
# Don't forget the last chunk
if current:
chunks.append(current.strip())
return chunks
    def generate_chunk_audio_file(self, sentence: str, chunk_index: int, voice: str, speed: float) -> Path:
        """Generate audio for one text chunk and save it as a numbered WAV file."""
        # Note: generate() is given no explicit speaking-rate control here;
        # the caller's speed value is forwarded as the sampling temperature.
        wav = self.model.generate(
            sentence,
            audio_prompt_path=voice,
            temperature=speed
        )
        # Save the synthesized audio to a numbered chunk file
        chunk_file = self.temp_output_dir / f"chunk_{chunk_index:04d}.wav"
        ta.save(str(chunk_file), wav, self.model.sr)
del wav
if self.stream_audio:
self.queue_audio_for_streaming(str(chunk_file))
return chunk_file
    def generate_audio_files(self, text: str, voice: str, speed: float, chunk_id: Optional[int] = None) -> List[Path]:
        """Split text into token-bounded chunks and synthesize one WAV file per chunk."""
        chunks = self.split_sentences(text)
        audio_files = []
        total_chunks = len(chunks)
        print(f"Processing {total_chunks} text chunks...")
        with torch.inference_mode():
            for i, chunk in enumerate(chunks):
                if self.save_audio_file:
                    # An explicit chunk_id (including 0) overrides the
                    # per-chunk index used for the output filename.
                    index = chunk_id if chunk_id is not None else i
                    chunk_file = self.generate_chunk_audio_file(chunk, index, voice, speed)
                    audio_files.append(chunk_file)
                    print(f"Chunk {i + 1}/{total_chunks} processed -> {chunk_file.name} -> {chunk}")
        return audio_files
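

if __name__ == "__main__":
    # Minimal usage sketch (assumptions: the reference voice clip path and
    # the demo sentence below are placeholders, and BaseTTS configures
    # self.device, self.save_audio_file, and self.temp_output_dir).
    tts = ChatterboxTTSProcessor(stream_audio=False)
    files = tts.generate_audio_files(
        "Hello from Chatterbox. This is a short demo sentence.",
        voice="reference_voice.wav",  # placeholder: short WAV of the target speaker
        speed=0.8,                    # forwarded to the model as sampling temperature
    )
    print(f"Wrote {len(files)} chunk file(s)")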