codewithjarair committed on
Commit
c26c2ec
·
verified ·
1 Parent(s): 8b9660d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +126 -202
app.py CHANGED
@@ -1,110 +1,99 @@
1
  import os
2
  import random
3
- import numpy as np
 
4
  import torch
5
  import torchaudio
 
6
  import gradio as gr
7
- import re
8
- import tempfile
9
  from chatterbox.tts import ChatterboxTTS
10
 
11
- # Set device
12
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
13
-
14
- def set_seed(seed: int):
15
- """Set random seed for reproducibility."""
16
- if seed == 0:
17
- seed = random.randint(1, 1000000)
18
- torch.manual_seed(seed)
19
- torch.cuda.manual_seed(seed)
20
- torch.cuda.manual_seed_all(seed)
21
- random.seed(seed)
22
- np.random.seed(seed)
23
- return seed
24
 
25
- def split_text(text, max_chars=250):
26
  """
27
- Intelligent text chunking with sentence boundary detection.
28
- Splits text into chunks of approximately max_chars, trying to stay on sentence boundaries.
29
  """
30
- # Simple sentence boundary detection using regex
31
- # Split by periods, question marks, and exclamation marks followed by whitespace
32
- sentences = re.split(r'(?<=[.!?])\s+', text.strip())
33
- chunks = []
34
- current_chunk = ""
35
-
36
- for sentence in sentences:
37
- if len(current_chunk) + len(sentence) <= max_chars:
38
- current_chunk += (sentence + " ")
39
- else:
40
- if current_chunk:
41
- chunks.append(current_chunk.strip())
42
- # If a single sentence is longer than max_chars, we have to split it
43
- if len(sentence) > max_chars:
44
- # Further split long sentences by commas or spaces as fallback
45
- sub_parts = re.split(r'(?<=,)\s+|\s+', sentence)
46
- temp_chunk = ""
47
- for part in sub_parts:
48
- if len(temp_chunk) + len(part) <= max_chars:
49
- temp_chunk += (part + " ")
50
- else:
51
- if temp_chunk:
52
- chunks.append(temp_chunk.strip())
53
- temp_chunk = part + " "
54
- current_chunk = temp_chunk
55
- else:
56
- current_chunk = sentence + " "
57
-
58
- if current_chunk:
59
- chunks.append(current_chunk.strip())
60
-
61
- return chunks
62
 
63
- def load_model():
64
- """Load the Chatterbox TTS model."""
65
- try:
66
- print(f"Loading Chatterbox TTS model on {DEVICE}...")
67
- model = ChatterboxTTS.from_pretrained(DEVICE)
68
- return model
69
- except Exception as e:
70
- print(f"Error loading model: {e}")
71
- return None
72
 
73
- def generate_tts(model, text, ref_audio, exaggeration, cfg_weight, temperature, seed, progress=gr.Progress()):
74
- """
75
- Generate TTS audio from text, handling long scripts via chunking.
76
- """
77
- if model is None:
78
- # Try to load if not already loaded (for HF Spaces persistence)
79
- model = load_model()
80
- if model is None:
81
- return None, "Error: Model could not be loaded. Check your environment/GPU."
82
-
83
- if not text.strip():
84
- return None, "Error: Please enter some text."
85
-
86
- if ref_audio is None:
87
- return None, "Error: Please upload a reference audio file for voice cloning."
 
 
 
 
88
 
89
- # Set seed
90
- actual_seed = set_seed(int(seed))
91
-
92
- # Chunk the text
93
- chunks = split_text(text)
94
- total_chunks = len(chunks)
95
-
96
- if total_chunks == 0:
97
- return None, "Error: No valid text to process."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
 
99
- all_wavs = []
100
-
101
- try:
 
 
 
102
  for i, chunk in enumerate(chunks):
103
- progress((i / total_chunks), desc=f"Processing chunk {i+1}/{total_chunks}")
 
104
 
105
- # Generate audio for this chunk
106
- # Chatterbox.generate expects: text, audio_prompt_path, exaggeration, temperature, cfg_weight, etc.
107
- wav = model.generate(
108
  chunk,
109
  audio_prompt_path=ref_audio,
110
  exaggeration=exaggeration,
@@ -112,125 +101,60 @@ def generate_tts(model, text, ref_audio, exaggeration, cfg_weight, temperature,
112
  cfg_weight=cfg_weight
113
  )
114
 
115
- # wav is usually a torch tensor [1, T] or [T]
116
  if wav.dim() == 1:
117
  wav = wav.unsqueeze(0)
118
-
119
  all_wavs.append(wav.cpu())
120
-
121
- # Concatenate all audio chunks along the time dimension (last dim)
122
- if not all_wavs:
123
- return None, "Error: No audio was generated."
124
-
125
  final_wav = torch.cat(all_wavs, dim=-1)
126
 
127
- # Save to a temporary file
128
- with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
129
- output_path = tmp_file.name
130
- torchaudio.save(output_path, final_wav, model.sr)
131
-
132
- return output_path, f"Successfully generated audio with seed {actual_seed}. Total chunks: {total_chunks}."
133
-
 
 
 
 
 
 
134
  except Exception as e:
135
- import traceback
136
- traceback.print_exc()
137
- return None, f"Error during generation: {str(e)}"
138
 
139
- # Define the Gradio Interface
140
- def create_ui():
141
- # Model is loaded once and stored in state
142
- model_state = gr.State(None)
143
 
144
- with gr.Blocks(theme=gr.themes.Soft(), title="Chatterbox Voice Clone TTS") as demo:
145
- gr.Markdown("# 🗣️ Voice Cloning TTS Chatterbox")
146
- gr.Markdown("""
147
- Clone any voice using a short reference audio clip. This application is optimized for long scripts
148
- through intelligent sentence-based chunking and sequential processing.
149
- """)
150
-
151
- with gr.Row():
152
- with gr.Column(scale=1):
153
- text_input = gr.Textbox(
154
- label="Script",
155
- placeholder="Enter your long script here. The app will automatically handle chunking...",
156
- lines=10,
157
- value="Welcome to the Chatterbox voice cloning application. This tool allows you to generate high-quality speech from long scripts by automatically splitting them into manageable segments. Simply upload a reference audio clip of the voice you want to clone, and adjust the parameters to your liking."
158
- )
159
- ref_audio = gr.Audio(
160
- label="Reference Audio (Voice to Clone)",
161
- type="filepath",
162
- sources=["upload", "microphone"]
163
- )
164
-
165
- with gr.Row():
166
- exaggeration = gr.Slider(
167
- 0.1, 1.0, value=0.5, step=0.05,
168
- label="Exaggeration",
169
- info="Default 0.5. Extreme values (>0.8) may be unstable."
170
- )
171
- cfg_weight = gr.Slider(
172
- 0.0, 1.0, value=0.5, step=0.05,
173
- label="CFG/Pace",
174
- info="Control the pace and guidance scale."
175
- )
176
-
177
- with gr.Accordion("Advanced Options", open=False):
178
- seed = gr.Number(
179
- label="Seed",
180
- value=0,
181
- precision=0,
182
- info="Set to 0 for random seed each time."
183
- )
184
- temperature = gr.Slider(
185
- 0.1, 2.0, value=1.0, step=0.05,
186
- label="Temperature",
187
- info="Higher values increase randomness and expressiveness."
188
- )
189
-
190
- generate_btn = gr.Button("Generate Audio", variant="primary")
191
-
192
- with gr.Column(scale=1):
193
- audio_output = gr.Audio(label="Generated Speech", type="filepath")
194
- status_msg = gr.Textbox(label="Status", interactive=False)
195
-
196
- gr.Markdown("### 📖 Documentation")
197
- gr.Markdown("""
198
- ### Features
199
- - **Voice Cloning**: Provide a clear 5-10 second reference clip.
200
- - **Intelligent Chunking**: Scripts are split at sentence boundaries (approx. 250 chars) to ensure smooth transitions and avoid memory issues.
201
- - **Sequential Processing**: Audio chunks are generated one-by-one and concatenated for long-form content.
202
- - **Parameter Control**:
203
- - **Exaggeration**: Intensity of cloned voice traits.
204
- - **CFG/Pace**: Balance between text adherence and reference voice speed.
205
- - **Temperature**: Randomness of the output.
206
-
207
- ### Tips
208
- - Use a high-quality, noise-free reference audio for best results.
209
- - For dramatic speech, try higher **Exaggeration** and lower **CFG**.
210
- - If the output sounds unnatural, try a different **Seed** or adjust **Temperature**.
211
- """)
212
 
213
- # Event handling
214
- generate_btn.click(
215
- fn=generate_tts,
216
- inputs=[
217
- model_state,
218
- text_input,
219
- ref_audio,
220
- exaggeration,
221
- cfg_weight,
222
- temperature,
223
- seed
224
- ],
225
- outputs=[audio_output, status_msg]
226
- )
227
-
228
- # Load model on startup
229
- demo.load(fn=load_model, outputs=model_state)
230
-
231
- return demo
232
 
233
  if __name__ == "__main__":
234
- ui = create_ui()
235
- # Use server_name="0.0.0.0" for deployment compatibility
236
- ui.launch(server_name="0.0.0.0")
 
1
  import os
2
  import random
3
+ import re
4
+ import tempfile
5
  import torch
6
  import torchaudio
7
+ import numpy as np
8
  import gradio as gr
 
 
9
  from chatterbox.tts import ChatterboxTTS
10
 
11
+ # Constants
12
+ MAX_CHUNK_CHARS = 250
13
+ DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
 
 
 
 
 
 
 
 
 
14
 
15
+ class VoiceCloningEngine:
16
  """
17
+ A dedicated engine to handle Chatterbox TTS operations including
18
+ model management, text chunking, and audio generation.
19
  """
20
+ def __init__(self, device=DEFAULT_DEVICE):
21
+ self.device = device
22
+ self.model = None
23
+ self.sr = 24000 # Default Chatterbox SR
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
+ def load_model(self):
26
+ """Loads the model into memory if not already present."""
27
+ if self.model is None:
28
+ print(f"Loading Chatterbox TTS on {self.device}...")
29
+ self.model = ChatterboxTTS.from_pretrained(self.device)
30
+ self.sr = self.model.sr
31
+ return self.model
 
 
32
 
33
+ def set_seed(self, seed: int):
34
+ """Sets deterministic seeds for reproducibility."""
35
+ if seed == 0:
36
+ seed = random.randint(1, 1000000)
37
+ torch.manual_seed(seed)
38
+ torch.cuda.manual_seed(seed)
39
+ torch.cuda.manual_seed_all(seed)
40
+ random.seed(seed)
41
+ np.random.seed(seed)
42
+ return seed
43
+
44
+ def chunk_text(self, text):
45
+ """
46
+ Splits text into chunks at sentence boundaries for long script handling.
47
+ """
48
+ # Split by punctuation followed by space
49
+ sentences = re.split(r'(?<=[.!?])\s+', text.strip())
50
+ chunks = []
51
+ current_chunk = ""
52
 
53
+ for sentence in sentences:
54
+ if len(current_chunk) + len(sentence) <= MAX_CHUNK_CHARS:
55
+ current_chunk += (sentence + " ")
56
+ else:
57
+ if current_chunk:
58
+ chunks.append(current_chunk.strip())
59
+
60
+ # Handle single sentences longer than MAX_CHUNK_CHARS
61
+ if len(sentence) > MAX_CHUNK_CHARS:
62
+ sub_parts = re.split(r'(?<=,)\s+|\s+', sentence)
63
+ temp = ""
64
+ for part in sub_parts:
65
+ if len(temp) + len(part) <= MAX_CHUNK_CHARS:
66
+ temp += (part + " ")
67
+ else:
68
+ if temp: chunks.append(temp.strip())
69
+ temp = part + " "
70
+ current_chunk = temp
71
+ else:
72
+ current_chunk = sentence + " "
73
+
74
+ if current_chunk:
75
+ chunks.append(current_chunk.strip())
76
+ return chunks
77
+
78
+ def generate(self, text, ref_audio, exaggeration, cfg_weight, temperature, seed, progress=None):
79
+ """
80
+ Processes the full script by chunking and concatenating results.
81
+ """
82
+ self.load_model()
83
+ actual_seed = self.set_seed(int(seed))
84
+ chunks = self.chunk_text(text)
85
 
86
+ if not chunks:
87
+ raise ValueError("No valid text provided.")
88
+ if ref_audio is None:
89
+ raise ValueError("Reference audio is required for cloning.")
90
+
91
+ all_wavs = []
92
  for i, chunk in enumerate(chunks):
93
+ if progress:
94
+ progress((i / len(chunks)), desc=f"Processing chunk {i+1}/{len(chunks)}")
95
 
96
+ wav = self.model.generate(
 
 
97
  chunk,
98
  audio_prompt_path=ref_audio,
99
  exaggeration=exaggeration,
 
101
  cfg_weight=cfg_weight
102
  )
103
 
 
104
  if wav.dim() == 1:
105
  wav = wav.unsqueeze(0)
 
106
  all_wavs.append(wav.cpu())
107
+
 
 
 
 
108
  final_wav = torch.cat(all_wavs, dim=-1)
109
 
110
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
111
+ output_path = tmp.name
112
+ torchaudio.save(output_path, final_wav, self.sr)
113
+
114
+ return output_path, actual_seed
115
+
116
+ # Initialize the engine
117
+ engine = VoiceCloningEngine()
118
+
119
+ def process_tts(text, ref_audio, exaggeration, cfg_weight, temperature, seed, progress=gr.Progress()):
120
+ try:
121
+ path, used_seed = engine.generate(text, ref_audio, exaggeration, cfg_weight, temperature, seed, progress)
122
+ return path, f"Success! Seed used: {used_seed}"
123
  except Exception as e:
124
+ return None, f"Error: {str(e)}"
 
 
125
 
126
+ # UI Construction
127
+ with gr.Blocks(theme=gr.themes.Soft(), title="Chatterbox Voice Clone") as demo:
128
+ gr.Markdown("# 🗣️ Voice Cloning TTS Engine")
129
+ gr.Markdown("Optimized for long scripts with intelligent chunking and smooth concatenation.")
130
 
131
+ with gr.Row():
132
+ with gr.Column(scale=1):
133
+ text_input = gr.Textbox(label="Script", lines=8, placeholder="Enter long text here...")
134
+ ref_audio = gr.Audio(label="Reference Voice", type="filepath")
135
+
136
+ with gr.Row():
137
+ exag = gr.Slider(0.1, 1.0, value=0.5, label="Exaggeration", info="Warning: >0.8 can be unstable")
138
+ cfg = gr.Slider(0.0, 1.0, value=0.5, label="CFG/Pace")
139
+
140
+ with gr.Accordion("Advanced Options", open=False):
141
+ seed_val = gr.Number(label="Seed", value=0, precision=0, info="0 for random")
142
+ temp_val = gr.Slider(0.1, 2.0, value=1.0, label="Temperature")
143
+
144
+ btn = gr.Button("Generate", variant="primary")
145
+
146
+ with gr.Column(scale=1):
147
+ audio_out = gr.Audio(label="Generated Audio", type="filepath")
148
+ status = gr.Textbox(label="Status", interactive=False)
149
+
150
+ gr.Markdown("### 📖 Quick Guide")
151
+ gr.Markdown("""
152
+ - **Chunking**: Sentences are automatically split at ~250 chars.
153
+ - **Secrets**: Use HF Secrets for API keys if needed.
154
+ - **Pacing**: Lower CFG for slower, more deliberate speech.
155
+ """)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
 
157
+ btn.click(process_tts, [text_input, ref_audio, exag, cfg, temp_val, seed_val], [audio_out, status])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
158
 
159
  if __name__ == "__main__":
160
+ demo.launch(server_name="0.0.0.0")