from pathlib import Path
from typing import List
import spacy
import torchaudio as ta
import torch
from ..base import BaseTTS

class ChatterboxTTSProcessor(BaseTTS):
	"""Text-to-Speech processor using ChatterboxTTS."""
	
	def __init__(self, stream_audio=False):
		super().__init__("Chatterbox", stream_audio=stream_audio)
		print("Initializing Chatterbox...")
		from chatterbox.tts import ChatterboxTTS
		print("Loading model...")
		self.model = ChatterboxTTS.from_pretrained(device=self.device)

		self.nlp = None
		try:
			self.nlp = spacy.load("en_core_web_sm")
		except OSError:
			from spacy.cli import download
			download("en_core_web_sm")
			self.nlp = spacy.load("en_core_web_sm")
		print("Model loaded successfully")

	def tokenize_sentences(self, text):
		"""Split text into sentences using spaCy.
		
		Args:
			text: Input text to tokenize
			
		Returns:
			List of sentence strings
		"""
		doc = self.nlp(text)
		return [sent.text.strip() for sent in doc.sents if sent.text.strip()]

	def norm_and_token_count(self, text):
		"""Get normalized text and token count.
		
		Args:
			text: Input text to normalize and count tokens
			
		Returns:
			Tuple of (normalized_text, token_count)
		"""
		from chatterbox.tts import punc_norm
		with torch.inference_mode():
			normalized = punc_norm(text)
			tokens = self.model.tokenizer.text_to_tokens(normalized)
			token_count = tokens.shape[1]

			# Clear tokens from GPU memory immediately
			if hasattr(tokens, 'cpu'):
				tokens = tokens.cpu()
			del tokens
			return normalized, token_count

	def split_sentences(self, text, max_tokens=200):
		"""Split text into chunks based on token count.
		
		Args:
			text: Input text to split
			max_tokens: Maximum tokens per chunk
			
		Returns:
			List of text chunks
		"""
		sentences = self.tokenize_sentences(text)
		chunks = []
		current = ""

		for sentence in sentences:
			# Check if sentence alone exceeds max tokens
			_, sentence_tokens = self.norm_and_token_count(sentence)
			if sentence_tokens > max_tokens:
				# If current chunk has content, save it first
				if current:
					chunks.append(current.strip())
					current = ""
				
				# Split long sentence by words if it's too long
				words = sentence.split()
				temp_chunk = ""
				
				for word in words:
					test_chunk = (temp_chunk + " " + word).strip() if temp_chunk else word
					_, test_tokens = self.norm_and_token_count(test_chunk)
					
					if test_tokens <= max_tokens:
						temp_chunk = test_chunk
					else:
						if temp_chunk:
							chunks.append(temp_chunk.strip())
						temp_chunk = word
				
				if temp_chunk:
					chunks.append(temp_chunk.strip())
				current = ""
				continue
			
			# Try adding sentence to current chunk
			candidate = (current + " " + sentence).strip() if current else sentence.strip()
			_, token_count = self.norm_and_token_count(candidate)
			
			if token_count <= max_tokens:
				current = candidate
			else:
				# Current chunk is full, save it and start new one
				if current:
					chunks.append(current.strip())
				current = sentence.strip()
		
		# Don't forget the last chunk
		if current:
			chunks.append(current.strip())
			
		return chunks

	def generate_chunk_audio_file(self, sentence: str, chunk_index: int, voice: str, speed: float) -> Path:
		"""Generate audio for one text chunk and save it to a numbered WAV file.
		
		Args:
			sentence: Text chunk to synthesize
			chunk_index: Index used to number the output file
			voice: Path to the audio prompt used for voice cloning
			speed: Forwarded to the model as the sampling temperature
			
		Returns:
			Path to the saved WAV file
		"""
		wav = self.model.generate(
			sentence,
			audio_prompt_path=voice,
			temperature=speed
		)
		
		# Save sentence to numbered file
		chunk_file = self.temp_output_dir / f"chunk_{chunk_index:04d}.wav"
		ta.save(str(chunk_file), wav, self.model.sr)
		del wav
		
		if self.stream_audio:
			self.queue_audio_for_streaming(str(chunk_file))
		return chunk_file

	def generate_audio_files(self, text: str, voice: str, speed: float, chunk_id: int = None) -> List[Path]:
		"""Split text into token-limited chunks and generate one audio file per chunk.
		
		Args:
			text: Input text to synthesize
			voice: Path to the audio prompt used for voice cloning
			speed: Forwarded to the model as the sampling temperature
			chunk_id: Optional fixed index for output file names; defaults to each chunk's position
			
		Returns:
			List of paths to the generated audio files
		"""
		sentences = self.split_sentences(text)
		audio_files = []
		total_sentences = len(sentences)
		
		print(f"Processing {total_sentences} text sentences...")
		with torch.inference_mode():
			for i, sentence in enumerate(sentences):
				chunk_file = None
				if self.save_audio_file:
					# Treat chunk_id=0 as a valid index; only fall back to i when it is None
					chunk_file = self.generate_chunk_audio_file(sentence, chunk_id if chunk_id is not None else i, voice, speed)
					audio_files.append(chunk_file)
				label = chunk_file.name if chunk_file else "not saved"
				print(f"Sentence {i + 1}/{total_sentences} processed -> {label} -> {sentence}")
				
		return audio_files
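
# The greedy sentence-packing inside split_sentences can be illustrated in
# isolation. The sketch below is a simplified, hypothetical version that swaps
# the Chatterbox tokenizer for a plain whitespace word count (and skips the
# long-sentence word-splitting fallback), so its chunk boundaries only
# approximate the real token-based ones.

def _pack_sentences_sketch(sentences, count_tokens, max_tokens=200):
	"""Greedily pack sentences into chunks of at most max_tokens tokens."""
	chunks, current = [], ""
	for sentence in sentences:
		# Try extending the current chunk with the next sentence
		candidate = f"{current} {sentence}".strip() if current else sentence.strip()
		if count_tokens(candidate) <= max_tokens:
			current = candidate
		else:
			# Chunk is full: flush it and start a new one with this sentence
			if current:
				chunks.append(current)
			current = sentence.strip()
	if current:
		chunks.append(current)
	return chunks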