File size: 4,241 Bytes
68a99fc aa3ab7e 68a99fc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 |
from pathlib import Path
from typing import List, Optional

import spacy
import torch
import torchaudio as ta

from ..base import BaseTTS
class ChatterboxTTSProcessor(BaseTTS):
    """Text-to-Speech processor backed by the ChatterboxTTS model.

    Input text is segmented into sentences with spaCy, regrouped into
    chunks that fit the model's token budget, and each chunk is
    synthesized to a numbered WAV file (optionally queued for streaming
    playback as it is written).
    """

    def __init__(self, stream_audio=False):
        """Load the ChatterboxTTS model and the spaCy sentence splitter.

        Args:
            stream_audio: When True, each finished chunk file is queued
                for streaming playback as soon as it is written.
        """
        super().__init__("Chatterbox", stream_audio=stream_audio)
        print("Initializing Chatterbox...")
        # Lazy import: the chatterbox package is only required when this
        # backend is actually constructed.
        from chatterbox.tts import ChatterboxTTS

        print("Loading Model...")  # fixed typo: was "Loading Modal..."
        self.model = ChatterboxTTS.from_pretrained(device=self.device)

        # spaCy English model; downloaded on first run if not installed.
        self.nlp = None
        try:
            self.nlp = spacy.load("en_core_web_sm")
        except OSError:
            from spacy.cli import download
            download("en_core_web_sm")
            self.nlp = spacy.load("en_core_web_sm")
        print("Model loaded successfully")

    def tokenize_sentences(self, text):
        """Split text into sentences using spaCy.

        Args:
            text: Input text to tokenize.

        Returns:
            List of non-empty, stripped sentence strings.
        """
        doc = self.nlp(text)
        return [sent.text.strip() for sent in doc.sents if sent.text.strip()]

    def norm_and_token_count(self, text):
        """Get normalized text and its model token count.

        Args:
            text: Input text to normalize and count tokens for.

        Returns:
            Tuple of (normalized_text, token_count).
        """
        from chatterbox.tts import punc_norm

        with torch.inference_mode():
            normalized = punc_norm(text)
            tokens = self.model.tokenizer.text_to_tokens(normalized)
            token_count = tokens.shape[1]
            # Move off the GPU and drop the reference immediately: this
            # method is called many times per text while chunking, so
            # lingering token tensors would accumulate device memory.
            if hasattr(tokens, 'cpu'):
                tokens = tokens.cpu()
            del tokens
            return normalized, token_count

    def split_sentences(self, text, max_tokens=200):
        """Split text into chunks whose token count stays under a limit.

        Args:
            text: Input text to split.
            max_tokens: Maximum tokens allowed per chunk.

        Returns:
            List of text chunks, each within the token budget.
        """
        sentences = self.tokenize_sentences(text)
        chunks = []
        current = ""
        for sentence in sentences:
            # A single sentence over the budget must be split word-by-word.
            _, sentence_tokens = self.norm_and_token_count(sentence)
            if sentence_tokens > max_tokens:
                # Flush the partially built chunk first.
                if current:
                    chunks.append(current.strip())
                    current = ""
                temp_chunk = ""
                for word in sentence.split():
                    test_chunk = (temp_chunk + " " + word).strip() if temp_chunk else word
                    _, test_tokens = self.norm_and_token_count(test_chunk)
                    if test_tokens <= max_tokens:
                        temp_chunk = test_chunk
                    else:
                        if temp_chunk:
                            chunks.append(temp_chunk.strip())
                        temp_chunk = word
                if temp_chunk:
                    chunks.append(temp_chunk.strip())
                continue
            # Greedily pack whole sentences into the current chunk.
            candidate = (current + " " + sentence).strip() if current else sentence.strip()
            _, token_count = self.norm_and_token_count(candidate)
            if token_count <= max_tokens:
                current = candidate
            else:
                # Current chunk is full: save it and start a new one.
                if current:
                    chunks.append(current.strip())
                current = sentence.strip()
        # Emit the trailing chunk.
        if current:
            chunks.append(current.strip())
        return chunks

    def generate_chunk_audio_file(self, sentence: str, chunk_index: int, voice: str, speed: float) -> Path:
        """Synthesize one text chunk and write it to a numbered WAV file.

        Args:
            sentence: Text chunk to synthesize.
            chunk_index: Index used to number the output file.
            voice: Path to the audio prompt (voice reference) file.
            speed: Forwarded to the model as sampling `temperature`.
                NOTE(review): confirm this speed->temperature mapping is
                intentional; temperature normally controls randomness,
                not speaking rate.

        Returns:
            Path to the written WAV chunk file.
        """
        wav = self.model.generate(
            sentence,
            audio_prompt_path=voice,
            temperature=speed,
        )
        chunk_file = self.temp_output_dir / f"chunk_{chunk_index:04d}.wav"
        ta.save(str(chunk_file), wav, self.model.sr)
        del wav  # release the waveform tensor promptly
        if self.stream_audio:
            self.queue_audio_for_streaming(str(chunk_file))
        return chunk_file

    def generate_audio_files(self, text: str, voice: str, speed: float, chunk_id: Optional[int] = None):
        """Generate one WAV file per text chunk of `text`.

        Args:
            text: Full text to synthesize.
            voice: Path to the audio prompt (voice reference) file.
            speed: Sampling temperature forwarded to the model.
            chunk_id: Optional fixed file index. When given, every chunk
                is written under this index. NOTE(review): with more than
                one sentence this overwrites the same file each iteration;
                verify callers only pass `chunk_id` for single-chunk texts.

        Returns:
            List of Path objects for the written chunk files.
        """
        sentences = self.split_sentences(text)
        audio_files = []
        total_sentences = len(sentences)
        print(f"Processing {total_sentences} text sentences...")
        with torch.inference_mode():
            for i, sentence in enumerate(sentences):
                if self.save_audio_file:
                    # BUG FIX: compare against None so an explicit
                    # chunk_id of 0 is honored instead of silently
                    # falling back to the loop index.
                    index = chunk_id if chunk_id is not None else i
                    chunk_file = self.generate_chunk_audio_file(sentence, index, voice, speed)
                    audio_files.append(chunk_file)
                    print(f"Sentence {i + 1}/{total_sentences} processed -> {chunk_file.name} -> {sentence}")
        return audio_files