# Chatterbox / app.py
# Author: KremerM
# Change: optionally add pauses before merging (commit 7a0996e)
import random
import numpy as np
import torch
from chatterbox.src.chatterbox.tts import ChatterboxTTS
import gradio as gr
import spaces
import os
import re
import torchaudio
import threading
import time
from queue import Queue
from dataclasses import dataclass
from typing import Optional, Callable
# --- Runtime configuration -------------------------------------------------
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"🚀 Running on device: {DEVICE}")
# Directory for saving audio files. exist_ok avoids the check-then-create
# race of the previous `os.path.exists` guard.
OUTPUT_DIR = "generated_audio"
os.makedirs(OUTPUT_DIR, exist_ok=True)
# Global variables for tracking sections and paragraphs
SECTION_INFO = []  # (section_number, paragraph_number, text) tuples
GENERATION_COUNTS = {}  # paragraph index -> number of generations so far
# --- Global Model Initialization ---
MODEL = None  # lazily populated by get_or_load_model()
def get_or_load_model():
    """Return the global ChatterboxTTS instance, loading it lazily on first use.

    The model is cached in the module-level ``MODEL`` variable and moved to
    ``DEVICE`` if it reports a different device after loading.
    """
    global MODEL
    if MODEL is not None:
        return MODEL
    print("Model not loaded, initializing...")
    try:
        MODEL = ChatterboxTTS.from_pretrained(DEVICE)
        # Relocate only when the object supports .to() and reports a mismatch.
        if hasattr(MODEL, 'to') and str(MODEL.device) != DEVICE:
            MODEL.to(DEVICE)
        print(f"Model loaded successfully. Internal device: {getattr(MODEL, 'device', 'N/A')}")
    except Exception as e:
        print(f"Error loading model: {e}")
        raise
    return MODEL
# Attempt to load the model at startup so the first request is fast.
try:
    get_or_load_model()
except Exception as e:
    # Do not hard-crash the UI on startup; log and let the page come up.
    print(f"CRITICAL: Failed to load model on startup. Application may not function. Error: {e}")
def set_seed(seed: int):
    """Seed random, numpy and torch (including all CUDA devices when on GPU)
    so repeated generations are reproducible."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if DEVICE == "cuda":
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
def count_words(text):
    """Return the number of whitespace-separated words in *text*.

    ``str.split()`` with no separator already collapses whitespace runs and
    never yields empty tokens, so the previous ``word.strip()`` filter was
    redundant and has been dropped.
    """
    return len(text.split())
def trim_audio(audio_data, start_time, end_time):
    """Return *audio_data* restricted to the [start_time, end_time] window.

    Args:
        audio_data: (sample_rate, audio_array) tuple, or None.
        start_time: Window start in seconds.
        end_time: Window end in seconds; None keeps everything to the end.

    Returns:
        (sample_rate, trimmed_array), or None when the input is None or the
        requested window is empty/inverted.
    """
    if audio_data is None:
        return None
    sr, samples = audio_data
    # Torch tensors are normalized to numpy before slicing.
    if isinstance(samples, torch.Tensor):
        samples = samples.numpy()
    total = len(samples)
    first = max(0, int(start_time * sr))
    last = total if end_time is None else min(total, int(end_time * sr))
    if first >= last:
        return None
    return (sr, samples[first:last])
def detect_silence_boundaries(audio_data, silence_threshold=0.01, min_silence_duration=0.1):
    """Suggest trim points by locating leading/trailing silence.

    Args:
        audio_data: (sample_rate, audio_array) tuple, or None.
        silence_threshold: Peak amplitude below which a window counts as silence.
        min_silence_duration: Length (seconds) of the window used for the test.

    Returns:
        (suggested_start, suggested_end) in seconds; (0, 0) for None input.
        Fully silent audio yields the whole clip.
    """
    if audio_data is None:
        return 0, 0
    sr, samples = audio_data
    if isinstance(samples, torch.Tensor):
        samples = samples.numpy()
    magnitude = np.abs(samples)
    window = int(min_silence_duration * sr)
    margin = int(0.05 * sr)  # keep ~50ms of padding around detected content
    n = len(magnitude)
    # Scan forward for the first window that contains signal.
    start_idx = 0
    for i in range(n - window):
        if np.max(magnitude[i:i + window]) > silence_threshold:
            start_idx = max(0, i - margin)
            break
    # Scan backward for the last window that contains signal.
    end_idx = n
    for i in range(n - window - 1, window, -1):
        if np.max(magnitude[i - window:i]) > silence_threshold:
            end_idx = min(n, i + margin)
            break
    return start_idx / sr, end_idx / sr
def merge_audio_files(audio_data_list, crossfade_duration=0.1, pause_duration=0.0):
    """
    Merge multiple audio files into one with optional crossfading and pauses.

    If both a crossfade and a pause are requested, the crossfade wins and no
    pause is inserted (the pause only applies when the crossfade cannot run).
    Clips whose sample rate differs from the first valid clip are skipped
    with a warning rather than resampled.
    Args:
        audio_data_list: List of (sample_rate, audio_array) tuples or None values
        crossfade_duration: Duration of crossfade between clips in seconds
        pause_duration: Duration of silence to add between clips in seconds
    Returns:
        Merged audio data as (sample_rate, audio_array) tuple, or None when
        no valid clips were supplied.
    """
    # Filter out None values and collect valid audio data
    valid_audio = []
    sample_rate = None
    print(f"Processing {len(audio_data_list)} audio clips for merging...")
    for i, audio_data in enumerate(audio_data_list):
        if audio_data is not None:
            sr, audio_array = audio_data
            # The first valid clip fixes the output sample rate.
            if sample_rate is None:
                sample_rate = sr
            elif sample_rate != sr:
                print(f"Warning: Sample rate mismatch. Expected {sample_rate}, got {sr}")
                continue
            # Convert to numpy if needed
            if isinstance(audio_array, torch.Tensor):
                audio_array = audio_array.numpy()
            # Convert to float32 for processing to avoid casting errors
            if audio_array.dtype != np.float32:
                audio_array = audio_array.astype(np.float32)
            duration = len(audio_array) / sr
            print(f"Audio clip {i}: {duration:.2f} seconds, {len(audio_array)} samples")
            valid_audio.append(audio_array)
    if not valid_audio:
        print("No valid audio found")
        return None
    if len(valid_audio) == 1:
        print("Only one audio clip, returning as-is")
        return (sample_rate, valid_audio[0])
    print(f"Merging {len(valid_audio)} audio clips with {crossfade_duration}s crossfade and {pause_duration}s pause")
    # Calculate crossfade and pause samples
    crossfade_samples = int(crossfade_duration * sample_rate)
    pause_samples = int(pause_duration * sample_rate)
    print(f"Crossfade samples: {crossfade_samples}, Pause samples: {pause_samples}")
    # Start with a copy of the first clip: the crossfade mutates it in place.
    merged_audio = valid_audio[0].copy()
    print(f"Starting with clip 0: {len(merged_audio)} samples")
    # Add each subsequent clip with crossfading and/or pauses
    for i in range(1, len(valid_audio)):
        current_clip = valid_audio[i]
        print(f"Merging clip {i}: {len(current_clip)} samples")
        # Crossfade takes precedence over the pause, and also requires both
        # sides to be at least one fade-length long.
        if crossfade_samples > 0 and len(merged_audio) >= crossfade_samples and len(current_clip) >= crossfade_samples:
            print(f"Applying crossfade between clips")
            # Fade out the end of merged_audio
            fade_out = np.linspace(1.0, 0.0, crossfade_samples, dtype=np.float32)
            merged_audio[-crossfade_samples:] *= fade_out
            # Fade in the beginning of current_clip (on a copy, so the source
            # clip in valid_audio is left untouched)
            fade_in = np.linspace(0.0, 1.0, crossfade_samples, dtype=np.float32)
            current_clip_faded = current_clip.copy()
            current_clip_faded[:crossfade_samples] *= fade_in
            # Overlap the crossfade region
            merged_audio[-crossfade_samples:] += current_clip_faded[:crossfade_samples]
            # Append the rest of the current clip
            merged_audio = np.concatenate([merged_audio, current_clip[crossfade_samples:]])
        elif pause_samples > 0:
            print(f"Adding {pause_duration}s pause between clips")
            # Create silence for the pause
            silence = np.zeros(pause_samples, dtype=np.float32)
            # Add pause then the current clip
            merged_audio = np.concatenate([merged_audio, silence, current_clip])
        else:
            print(f"No crossfade or pause, concatenating directly")
            # No crossfade or pause, just concatenate
            merged_audio = np.concatenate([merged_audio, current_clip])
        print(f"Merged audio now: {len(merged_audio)} samples ({len(merged_audio)/sample_rate:.2f} seconds)")
    final_duration = len(merged_audio) / sample_rate
    print(f"Final merged audio: {len(merged_audio)} samples ({final_duration:.2f} seconds)")
    return (sample_rate, merged_audio)
def save_merged_audio(merged_audio_data, section_filter=None):
    """Write merged audio into OUTPUT_DIR and return the saved file path.

    Args:
        merged_audio_data: (sample_rate, audio_array) tuple, or None.
        section_filter: Section number used in the filename, or None when the
            merge covered every section.

    Returns:
        Path of the saved wav file, or None when there is nothing to save.
    """
    global OUTPUT_DIR, SECTION_INFO
    if merged_audio_data is None:
        return None
    # Name reflects the merge scope.
    name = "merged_all_sections.wav" if section_filter is None else f"merged_section_{section_filter}.wav"
    filepath = os.path.join(OUTPUT_DIR, name)
    sr, audio_array = merged_audio_data
    tensor = torch.tensor(audio_array) if isinstance(audio_array, np.ndarray) else audio_array
    if tensor.dim() == 1:
        tensor = tensor.unsqueeze(0)  # torchaudio expects (channels, samples)
    torchaudio.save(filepath, tensor, sr)
    print(f"Saved merged audio to {filepath}")
    return filepath
def merge_all_generated_audio(crossfade_duration, pause_duration, *audio_inputs):
    """
    Merge all generated audio files into one.

    When SECTION_INFO is populated, clips are merged in section/paragraph
    order; otherwise they are merged in the order they were supplied.
    Returns (merged_audio, status_message), merged_audio being a
    (sample_rate, array) tuple or None on failure.
    """
    global SECTION_INFO
    # Collect all non-None audio data regardless of section info
    audio_data_list = []
    merged_count = 0
    print(f"Checking {len(audio_inputs)} audio inputs for merging...")
    for idx, audio_data in enumerate(audio_inputs):
        if audio_data is not None:
            print(f"Found audio at index {idx}")
            audio_data_list.append(audio_data)
            merged_count += 1
        else:
            print(f"No audio at index {idx}")
    print(f"Total audio files found: {merged_count}")
    if not audio_data_list:
        return None, "No audio files to merge."
    # If we have section info, try to organize by section order
    if SECTION_INFO:
        # Create a mapping of section order
        organized_audio = []
        section_audio_map = {}
        # First, map audio to their corresponding sections
        # (audio_inputs is parallel to SECTION_INFO by paragraph index)
        for idx, (section_num, para_num, text) in enumerate(SECTION_INFO):
            if idx < len(audio_inputs) and audio_inputs[idx] is not None:
                if section_num not in section_audio_map:
                    section_audio_map[section_num] = []
                section_audio_map[section_num].append((para_num, audio_inputs[idx]))
        # Now organize by section number, then paragraph number
        for section_num in sorted(section_audio_map.keys()):
            # Sort paragraphs within each section
            section_paragraphs = sorted(section_audio_map[section_num], key=lambda x: x[0])
            for para_num, audio_data in section_paragraphs:
                organized_audio.append(audio_data)
                print(f"Added Section {section_num}, Paragraph {para_num} to merge list")
        if organized_audio:
            audio_data_list = organized_audio
            print(f"Organized {len(audio_data_list)} audio files by section order")
    # Merge the audio
    merged_audio = merge_audio_files(audio_data_list, crossfade_duration=crossfade_duration, pause_duration=pause_duration)
    if merged_audio is None:
        return None, "Failed to merge audio files."
    # Calculate total duration for verification
    sr, merged_array = merged_audio
    duration_seconds = len(merged_array) / sr
    # Save the merged file
    filepath = save_merged_audio(merged_audio)
    return merged_audio, f"Merged {merged_count} audio files. Total duration: {duration_seconds:.1f} seconds. Saved to {filepath}"
def merge_by_section(section_number, crossfade_duration, *audio_inputs, pause_duration=0.0):
    """Merge the generated audio clips belonging to one section.

    Args:
        section_number: Section whose paragraphs should be merged.
        crossfade_duration: Crossfade length in seconds between clips.
        *audio_inputs: One (sample_rate, array) tuple or None per paragraph
            slot, parallel to SECTION_INFO by index.
        pause_duration: Silence (seconds) inserted between clips when no
            crossfade applies. Keyword-only so existing positional callers
            are unaffected; default 0.0 preserves the old behavior. Added
            for consistency with merge_all_generated_audio.

    Returns:
        (merged_audio, status_message); merged_audio may be None on failure.
    """
    global SECTION_INFO
    if not SECTION_INFO:
        return None, "No paragraphs processed."
    # Collect clips whose recorded section matches, in paragraph order.
    audio_data_list = []
    merged_count = 0
    for idx, (section_num, para_num, text) in enumerate(SECTION_INFO):
        if section_num == section_number and idx < len(audio_inputs) and audio_inputs[idx] is not None:
            audio_data_list.append(audio_inputs[idx])
            merged_count += 1
    if not audio_data_list:
        return None, f"No audio files found for section {section_number}."
    # Merge the audio
    merged_audio = merge_audio_files(audio_data_list, crossfade_duration=crossfade_duration, pause_duration=pause_duration)
    if merged_audio is None:
        return None, f"Failed to merge audio files for section {section_number}."
    # Save the merged file
    filepath = save_merged_audio(merged_audio, section_filter=section_number)
    return merged_audio, f"Merged {merged_count} audio files from section {section_number}. Saved to {filepath}"
def save_trimmed_audio(audio_data, index, is_trimmed=False):
    """Persist a paragraph's audio, optionally tagging the file as trimmed.

    Filename pattern: {section}_{paragraph}_{generation}[_trimmed].wav.
    Unlike save_audio_file, the generation counter is read but not bumped.
    """
    global GENERATION_COUNTS, OUTPUT_DIR, SECTION_INFO
    if audio_data is None:
        return None
    # Resolve section/paragraph labels, falling back when index is unknown.
    if index < len(SECTION_INFO):
        section_num, para_num, _ = SECTION_INFO[index]
    else:
        section_num, para_num = 1, index + 1
    gen_count = GENERATION_COUNTS.get(index, 1)
    suffix = "_trimmed" if is_trimmed else ""
    filepath = os.path.join(OUTPUT_DIR, f"{section_num}_{para_num}_{gen_count}{suffix}.wav")
    sr, audio_array = audio_data
    tensor = torch.tensor(audio_array) if isinstance(audio_array, np.ndarray) else audio_array
    if tensor.dim() == 1:
        tensor = tensor.unsqueeze(0)  # torchaudio expects (channels, samples)
    torchaudio.save(filepath, tensor, sr)
    print(f"Saved audio to {filepath}")
    return filepath
def split_text_into_sections_and_paragraphs(text):
    """
    Split input text into sections and paragraphs.

    A section marker is any line whose first character is a digit; that one
    digit becomes the section number and the rest of the line (e.g.
    "1 Introduction", "2. Chapter Two", "3 - Main Content") is ignored.
    Each subsequent non-empty line is one paragraph of the current section
    until the next marker. Lines before any marker default to section 1.
    Empty lines are skipped, and anything inside square brackets is treated
    as a comment and filtered out.

    Fix: removed the redundant function-local ``import re`` — the module
    already imports ``re`` at the top of the file.

    Returns:
        - List of (section_number, paragraph_number, text) tuples (also
          stored in the global SECTION_INFO)
        - Total word count
    """
    global SECTION_INFO
    # Reset section info for the new script
    SECTION_INFO = []
    if not text.strip():
        return SECTION_INFO, 0
    lines = text.strip().split('\n')
    current_section = None
    current_para_in_section = 1
    total_word_count = 0
    for line in lines:
        line = line.rstrip()
        # Skip empty lines
        if not line.strip():
            continue
        # Remove anything in square brackets (comments)
        cleaned_line = re.sub(r'\[.*?\]', '', line).strip()
        # Skip lines that become empty after removing comments
        if not cleaned_line:
            continue
        if cleaned_line[0].isdigit():
            # Section marker: only the first character is the section number.
            section_num = int(cleaned_line[0])
            current_section = section_num
            current_para_in_section = 1
            print(f"Found section marker: '{cleaned_line}' -> Section {section_num}")
        else:
            # Text line: one complete paragraph in the current section.
            if current_section is None:
                current_section = 1
            paragraph_text = cleaned_line
            SECTION_INFO.append((current_section, current_para_in_section, paragraph_text))
            total_word_count += count_words(paragraph_text)
            current_para_in_section += 1
    return SECTION_INFO, total_word_count
def save_audio_file(audio_data, index):
    """Save freshly generated paragraph audio as {Section}_{Paragraph}_{Generation}.wav.

    Each call increments the generation counter for the paragraph, so
    regenerations never overwrite earlier takes.
    """
    global GENERATION_COUNTS, OUTPUT_DIR, SECTION_INFO
    if audio_data is None:
        return None
    # Resolve section/paragraph labels, falling back when index is unknown.
    if index < len(SECTION_INFO):
        section_num, para_num, _ = SECTION_INFO[index]
    else:
        section_num, para_num = 1, index + 1
    # Bump and record this paragraph's generation number.
    gen_count = GENERATION_COUNTS.get(index, 0) + 1
    GENERATION_COUNTS[index] = gen_count
    filepath = os.path.join(OUTPUT_DIR, f"{section_num}_{para_num}_{gen_count}.wav")
    sr, audio_array = audio_data
    tensor = torch.tensor(audio_array) if isinstance(audio_array, np.ndarray) else audio_array
    if tensor.dim() == 1:
        tensor = tensor.unsqueeze(0)  # torchaudio expects (channels, samples)
    torchaudio.save(filepath, tensor, sr)
    print(f"Saved audio to {filepath}")
    return filepath
def generate_single_in_batch(text, audio_prompt_path, exaggeration, temperature, seed_num, cfgw, paragraph_index=None):
    """Generate audio for one paragraph during sequential batch generation.

    The model is resolved inside the function (rather than passed in) so the
    callable stays picklable for the spaces GPU worker.

    Args:
        text: Paragraph text; truncated to 1000 characters before synthesis.
        audio_prompt_path: Optional reference-audio path for voice styling.
        exaggeration / temperature / cfgw: Generation controls.
        seed_num: Seed; 0 picks a fresh random seed per call.
        paragraph_index: When given, the result is also written to disk.

    Returns:
        (sample_rate, waveform ndarray), or None when text is blank.

    Fix: the blank-text early exit used to return the truthy 3-tuple
    ``(None, 0, 0)``, which defeated callers' ``if audio:`` checks; it now
    returns None like every other no-result path.
    """
    model = get_or_load_model()  # Get model inside the function
    if not text.strip():
        return None
    if seed_num != 0:
        actual_seed = int(seed_num)
    else:
        actual_seed = random.randint(0, 2**32 - 1)
    print(f"Using seed: {actual_seed} for paragraph {paragraph_index}, exaggeration: {exaggeration}")
    set_seed(actual_seed)
    generate_kwargs = {
        "exaggeration": exaggeration,
        "temperature": temperature,
        "cfg_weight": cfgw,
    }
    if audio_prompt_path:
        generate_kwargs["audio_prompt_path"] = audio_prompt_path
    wav = model.generate(
        text[:1000],  # Truncate text to max chars
        **generate_kwargs
    )
    audio_data = (model.sr, wav.squeeze(0).numpy())
    # Save the audio file if paragraph_index is provided
    if paragraph_index is not None:
        save_audio_file(audio_data, paragraph_index)
    return audio_data
@spaces.GPU
def generate_single_internal(text, audio_prompt_path, exaggeration, temperature, seed_num, cfgw, paragraph_index=None):
    """Generate audio for one paragraph from the per-paragraph UI buttons.

    Runs on a spaces GPU worker; the model is resolved inside the function
    (not passed in) to avoid pickling it across the worker boundary. Text is
    truncated to 500 characters here (vs 1000 in the batch path).

    Returns:
        (sample_rate, waveform ndarray), or None when text is blank.

    Fix: the blank-text early exit used to return the truthy 3-tuple
    ``(None, 0, 0)``, which defeated callers' ``if audio:`` checks; it now
    returns None like every other no-result path.
    """
    model = get_or_load_model()  # Get model inside the function
    if not text.strip():
        return None
    if seed_num != 0:
        actual_seed = int(seed_num)
    else:
        actual_seed = random.randint(0, 2**32 - 1)
    print(f"Using seed: {actual_seed} for paragraph {paragraph_index}, exaggeration: {exaggeration}")
    set_seed(actual_seed)
    generate_kwargs = {
        "exaggeration": exaggeration,
        "temperature": temperature,
        "cfg_weight": cfgw,
    }
    if audio_prompt_path:
        generate_kwargs["audio_prompt_path"] = audio_prompt_path
    wav = model.generate(
        text[:500],  # Truncate text to max chars
        **generate_kwargs
    )
    audio_data = (model.sr, wav.squeeze(0).numpy())
    # Save the audio file if paragraph_index is provided
    if paragraph_index is not None:
        save_audio_file(audio_data, paragraph_index)
    return audio_data
@spaces.GPU
def generate_tts_audio(
    text_input: str,
    audio_prompt_path_input: str = None,
    exaggeration_input: float = 0.5,
    temperature_input: float = 0.8,
    seed_num_input: int = 0,
    cfgw_input: float = 0.6
) -> tuple[int, np.ndarray]:
    """
    Generate high-quality speech audio from text using ChatterboxTTS model with optional reference audio styling.
    This tool synthesizes natural-sounding speech from input text. When a reference audio file
    is provided, it captures the speaker's voice characteristics and speaking style. The generated audio
    maintains the prosody, tone, and vocal qualities of the reference speaker, or uses default voice if no reference is provided.
    Args:
        text_input (str): The text to synthesize into speech (maximum 300 characters)
        audio_prompt_path_input (str, optional): File path or URL to the reference audio file that defines the target voice style. Defaults to None.
        exaggeration_input (float, optional): Controls speech expressiveness (0.25-2.0, neutral=0.5, extreme values may be unstable). Defaults to 0.5.
        temperature_input (float, optional): Controls randomness in generation (0.05-5.0, higher=more varied). Defaults to 0.8.
        seed_num_input (int, optional): Random seed for reproducible results (0 for random generation). Defaults to 0.
        cfgw_input (float, optional): CFG/Pace weight controlling generation guidance (0.2-1.0). Defaults to 0.6.
    Returns:
        tuple[int, np.ndarray]: A tuple containing the sample rate (int) and the generated audio waveform (numpy.ndarray)
    """
    # Docstring kept intact: gradio's MCP server derives the tool description
    # from it.
    model = get_or_load_model()
    if model is None:
        raise RuntimeError("TTS model is not loaded.")
    # Seed only when the caller asked for reproducibility.
    if seed_num_input != 0:
        set_seed(int(seed_num_input))
    print(f"Generating audio for text: '{text_input[:50]}...'")
    synth_kwargs = {
        "exaggeration": exaggeration_input,
        "temperature": temperature_input,
        "cfg_weight": cfgw_input,
    }
    if audio_prompt_path_input:
        synth_kwargs["audio_prompt_path"] = audio_prompt_path_input
    wav = model.generate(text_input[:300], **synth_kwargs)  # Truncate text to max chars
    print("Audio generation complete.")
    return (model.sr, wav.squeeze(0).numpy())
@spaces.GPU(duration=600)
def generate_all_sequential(ref_wav_path, temperature, seed_num, cfgw, *exaggeration_values):
    """Generate audio for every processed paragraph, in order.

    Returns a list ``[status_text, audio_0, ..., audio_49]`` matching the
    ``[generation_status] + audio_outputs`` components wired in the UI
    (1 + 50 = 51 values).

    Fix: the no-paragraphs early return previously produced
    1 + 100 + 100 = 201 values, which did not match the 51 output
    components; it now pads with exactly MAX_PARAGRAPHS Nones.
    """
    global SECTION_INFO
    MAX_PARAGRAPHS = 50  # must match the UI's paragraph slot count
    if not SECTION_INFO:
        return ["No paragraphs to process. Please process a script first."] + [None] * MAX_PARAGRAPHS
    status_messages = []
    audio_results = [None] * MAX_PARAGRAPHS
    for idx, (section_num, para_num, text) in enumerate(SECTION_INFO):
        if not text.strip():
            continue  # nothing to synthesize for blank paragraphs
        # Per-paragraph exaggeration, defaulting when the slider list is short.
        exaggeration = exaggeration_values[idx] if idx < len(exaggeration_values) else 0.35
        print(f"Generating Section {section_num}, Paragraph {para_num} [{idx+1}/{len(SECTION_INFO)}] with exaggeration {exaggeration}...")
        status_messages.append(f"Processing Section {section_num}, Paragraph {para_num} (exaggeration: {exaggeration})...")
        audio = generate_single_in_batch(text, ref_wav_path, exaggeration, temperature, seed_num, cfgw, paragraph_index=idx)
        if audio is not None:
            status_messages.append(f"✓ Generated audio for Section {section_num}, Paragraph {para_num}")
            audio_results[idx] = audio
        else:
            status_messages.append(f"✗ Failed to generate audio for Section {section_num}, Paragraph {para_num}")
    final_status = "\n".join(status_messages) + f"\n\nCompleted processing {len(SECTION_INFO)} paragraphs!"
    # Return status plus all audio results (one per UI slot)
    return [final_status] + audio_results
def apply_trim_and_save(audio_data, start_time, end_time, paragraph_index):
    """Trim a paragraph's audio to [start_time, end_time] and save the result.

    Returns (trimmed_audio, status_message); trimmed_audio is None on failure.
    """
    if audio_data is None:
        return None, "No audio to trim"
    try:
        trimmed = trim_audio(audio_data, start_time, end_time)
        if trimmed is None:
            return None, "Invalid trim parameters"
        # Persist the trimmed take alongside the original.
        saved_path = save_trimmed_audio(trimmed, paragraph_index, is_trimmed=True)
        return trimmed, f"Trimmed and saved to {saved_path}"
    except Exception as e:
        return None, f"Error trimming audio: {str(e)}"
def update_paragraph_ui(script_text):
    """Re-parse the script and refresh every per-paragraph widget.

    Returns, in order: text updates, row updates, button updates, audio
    updates, label updates, exaggeration updates (50 of each), then the
    paragraph-count markdown and the per-section summary markdown. This
    ordering must match the outputs list wired to process_btn.
    """
    global SECTION_INFO, GENERATION_COUNTS
    section_info, word_count = split_text_into_sections_and_paragraphs(script_text)
    # Fresh generation counters: one zeroed entry per paragraph.
    GENERATION_COUNTS = {i: 0 for i in range(len(section_info))}
    MAX_PARAGRAPHS = 50
    text_updates = []
    row_updates = []
    button_updates = []
    audio_updates = []
    label_updates = []
    exaggeration_updates = []
    for i in range(MAX_PARAGRAPHS):
        visible = i < len(section_info)
        if visible:
            section_num, para_num, text = section_info[i]
            label = f"Section {section_num}, Paragraph {para_num} ({count_words(text)} words)"
        else:
            text = ""
            label = ""
        text_updates.append(gr.update(value=text, visible=visible))
        row_updates.append(gr.update(visible=visible))
        button_updates.append(gr.update(visible=visible))
        audio_updates.append(gr.update(visible=visible))
        label_updates.append(gr.update(value=label, visible=visible))
        exaggeration_updates.append(gr.update(value=0.35, visible=visible))
    count_update = gr.update(value=f"Total Paragraphs: {len(section_info)} | Total Words: {word_count}")
    # Summarize paragraph and word counts per section.
    section_counts = {}
    section_word_counts = {}
    for section_num, para_num, text in section_info:
        section_counts[section_num] = max(section_counts.get(section_num, 0), para_num)
        section_word_counts[section_num] = section_word_counts.get(section_num, 0) + count_words(text)
    section_summary = "Sections: " + ", ".join([f"Section {s}: {c} paragraphs ({section_word_counts[s]} words)" for s, c in sorted(section_counts.items())])
    summary_update = gr.update(value=section_summary)
    return (text_updates + row_updates + button_updates + audio_updates + label_updates +
            exaggeration_updates +
            [count_update, summary_update])
def apply_global_exaggeration(global_value):
    """Push one exaggeration value onto every visible paragraph slider;
    hidden sliders receive a no-op update."""
    global SECTION_INFO
    visible_count = len(SECTION_INFO)
    return [
        gr.update(value=global_value) if i < visible_count else gr.update()
        for i in range(50)  # MAX_PARAGRAPHS
    ]
# ---------------------------------------------------------------------------
# Gradio UI: one tab for single-shot generation, one for script processing
# (per-paragraph generation, trimming and merging).
# ---------------------------------------------------------------------------
with gr.Blocks() as demo:
    # Storage for dynamic components (one entry per paragraph slot).
    paragraph_texts = []
    paragraph_rows = []
    paragraph_labels = []
    generate_buttons = []
    audio_outputs = []
    exaggeration_sliders = []
    gr.Markdown(
        """
# Chatterbox TTS Demo with Script Processing
Generate high-quality speech from text with reference audio styling and batch processing capabilities.
"""
    )
    with gr.Tab("Single Generation"):
        with gr.Row():
            with gr.Column():
                text = gr.Textbox(
                    value="Now let's make my mum's favourite. So three mars bars into the pan. Then we add the tuna and just stir for a bit, just let the chocolate and fish infuse. A sprinkle of olive oil and some tomato ketchup. Now smell that. Oh boy this is going to be incredible.",
                    label="Text to synthesize (max chars 300)",
                    max_lines=5
                )
                ref_wav = gr.Audio(
                    sources=["upload", "microphone"],
                    type="filepath",
                    label="Reference Audio File (Optional)",
                    value="https://storage.googleapis.com/chatterbox-demo-samples/prompts/female_shadowheart4.flac"
                )
                exaggeration = gr.Slider(
                    0.25, 2, step=.05, label="Exaggeration (Neutral = 0.5, extreme values can be unstable)", value=.5
                )
                cfg_weight = gr.Slider(
                    0.2, 1, step=.05, label="CFG/Pace", value=0.6
                )
                with gr.Accordion("More options", open=False):
                    seed_num = gr.Number(value=131789919, label="Random seed (0 for random)")
                    temp = gr.Slider(0.05, 5, step=.05, label="Temperature", value=.8)
                run_btn = gr.Button("Generate", variant="primary")
            with gr.Column():
                audio_output = gr.Audio(label="Output Audio")
        # Wire the single-shot generation button.
        run_btn.click(
            fn=generate_tts_audio,
            inputs=[
                text,
                ref_wav,
                exaggeration,
                temp,
                seed_num,
                cfg_weight,
            ],
            outputs=[audio_output],
        )
    with gr.Tab("Script Processing"):
        with gr.Row():
            with gr.Column():
                script_input = gr.Textbox(
                    label="Script Input (section headers start with a digit, e.g., '1 Introduction', '2. Chapter Two')",
                    placeholder="Enter your script here. Use lines starting with digits to mark sections (e.g., '1 Introduction', '2. Chapter Two', '3 - Main Content'). The first character determines the section number. Separate paragraphs with blank lines.",
                    lines=10,
                    value="1 Introduction\nThis is the first paragraph of section 1. It contains some interesting content.\n\nThis is the second paragraph of section 1. It's separated by a blank line.\n\n2. Chapter Two\nThis is the first paragraph of section 2. Notice the '2.' above marks the section.\n\nThis is the second paragraph of section 2.\n\n3 - Final Section\nThis is the first paragraph of section 3."
                )
                process_btn = gr.Button("Process Script", variant="primary")
                paragraph_count = gr.Markdown("Total Paragraphs: 0 | Total Words: 0")
                section_summary = gr.Markdown("Sections: None")
                output_dir_info = gr.Markdown(f"Audio files will be saved to: {os.path.abspath(OUTPUT_DIR)}")
        # Generation parameters shared by all paragraphs.
        with gr.Row():
            with gr.Column():
                gr.Markdown("## Global Generation Settings")
                script_ref_wav = gr.Audio(sources=["upload", "microphone"], type="filepath", label="Reference Audio File", value="https://storage.googleapis.com/chatterbox-demo-samples/prompts/female_shadowheart4.flac")
                script_cfg_weight = gr.Slider(0.0, 1, step=.05, label="CFG/Pace", value=0.6)
                with gr.Accordion("More options", open=False):
                    script_seed_num = gr.Number(value=131789919, label="Random seed (0 for random)")
                    script_temp = gr.Slider(0.05, 5, step=.05, label="Temperature", value=.8)
                # Global exaggeration control, pushed to all paragraph sliders.
                gr.Markdown("### Global Exaggeration Control")
                global_exaggeration = gr.Slider(
                    0.25, 2,
                    step=0.05,
                    label="Set All Exaggeration Values",
                    value=0.35
                )
                apply_global_exaggeration_btn = gr.Button("Apply to All Paragraphs", variant="secondary")
                generate_all_btn = gr.Button("Generate All Paragraphs", variant="primary", size="lg")
                generation_status = gr.Textbox(label="Generation Status", lines=5, interactive=False)
        # Paragraphs section: a fixed pool of 50 hidden slots toggled by
        # update_paragraph_ui.
        gr.Markdown("## Paragraphs")
        MAX_PARAGRAPHS = 50
        for i in range(MAX_PARAGRAPHS):
            with gr.Row(visible=False) as row:
                paragraph_rows.append(row)
                with gr.Column(scale=4):
                    paragraph_label = gr.Markdown("Section 1, Paragraph 1", visible=False)
                    paragraph_labels.append(paragraph_label)
                    text_input = gr.Textbox(
                        lines=3,
                        max_lines=5,
                        visible=False
                    )
                    paragraph_texts.append(text_input)
                with gr.Column(scale=2):
                    exaggeration_slider = gr.Slider(
                        0.25, 2,
                        step=0.05,
                        label="Exaggeration",
                        value=0.35,
                        visible=False
                    )
                    exaggeration_sliders.append(exaggeration_slider)
                    generate_btn = gr.Button(f"Generate", visible=False)
                    generate_buttons.append(generate_btn)
                with gr.Column(scale=3):
                    audio_output = gr.Audio(label=f"Generated Audio", type="numpy", autoplay=False, visible=False, show_download_button=True, interactive=True)
                    audio_outputs.append(audio_output)
            # Setup individual paragraph generation.
            def make_generate_handler(idx):
                # Factory binds idx now, avoiding the late-binding-closure pitfall.
                def generate_handler(text, ref_wav, exag, temp, seed, cfg):
                    # Call the spaces.GPU function directly without passing model
                    audio = generate_single_internal(text, ref_wav, exag, temp, seed, cfg, idx)
                    return audio
                return generate_handler
            generate_btn.click(
                fn=make_generate_handler(i),
                inputs=[
                    text_input,
                    script_ref_wav,
                    exaggeration_slider,
                    script_temp,
                    script_seed_num,
                    script_cfg_weight,
                ],
                outputs=[audio_output],
            )
        # Audio Merging section
        gr.Markdown("## Audio Merging")
        with gr.Row():
            with gr.Column():
                gr.Markdown("### Merge Options")
                crossfade_duration = gr.Slider(
                    0.0, 1.0,
                    step=0.05,
                    value=0.0,
                    label="Crossfade Duration (seconds)"
                )
                pause_duration = gr.Slider(0, 2, step=0.1, label="Pause Between Segments (seconds). Crossfade must be 0", value=0.3)
                with gr.Row():
                    merge_all_btn = gr.Button("Merge All Audio", variant="primary")
                    section_number_input = gr.Number(
                        value=1,
                        label="Section Number",
                        precision=0
                    )
                    merge_section_btn = gr.Button("Merge Section", variant="secondary")
                merge_status = gr.Textbox(
                    label="Merge Status",
                    lines=2,
                    interactive=False
                )
            with gr.Column():
                merged_audio_output = gr.Audio(
                    label="Merged Audio",
                    type="numpy",
                    show_download_button=True
                )
        # Setup process button to update paragraphs; output order must match
        # the tuple returned by update_paragraph_ui.
        process_btn.click(
            fn=update_paragraph_ui,
            inputs=[script_input],
            outputs=(paragraph_texts + paragraph_rows + generate_buttons + audio_outputs +
                     paragraph_labels + exaggeration_sliders +
                     [paragraph_count, section_summary]),
        )
        # Setup generate-all button: status plus one audio slot per paragraph.
        generate_all_btn.click(
            fn=generate_all_sequential,
            inputs=[
                script_ref_wav,
                script_temp,
                script_seed_num,
                script_cfg_weight,
            ] + exaggeration_sliders,
            outputs=[generation_status] + audio_outputs,
        )
        # Setup merge all button
        merge_all_btn.click(
            fn=merge_all_generated_audio,
            inputs=[crossfade_duration, pause_duration] + audio_outputs,
            outputs=[merged_audio_output, merge_status]
        )
        # Setup merge section button
        def merge_section_handler(section_num, crossfade_duration, *audio_inputs):
            # Gradio Numbers arrive as floats; coerce to int for the lookup.
            return merge_by_section(int(section_num), crossfade_duration, *audio_inputs)
        merge_section_btn.click(
            fn=merge_section_handler,
            inputs=[section_number_input, crossfade_duration] + audio_outputs,
            outputs=[merged_audio_output, merge_status]
        )
        # Setup global exaggeration apply button
        apply_global_exaggeration_btn.click(
            fn=apply_global_exaggeration,
            inputs=[global_exaggeration],
            outputs=exaggeration_sliders
        )
# mcp_server=True also exposes generate_tts_audio as an MCP tool.
demo.launch(mcp_server=True)