Spaces:
Sleeping
Sleeping
| import random | |
| import numpy as np | |
| import torch | |
| from chatterbox.src.chatterbox.tts import ChatterboxTTS | |
| import gradio as gr | |
| import spaces | |
| import os | |
| import re | |
| import torchaudio | |
| import threading | |
| import time | |
| from queue import Queue | |
| from dataclasses import dataclass | |
| from typing import Optional, Callable | |
| DEVICE = "cuda" if torch.cuda.is_available() else "cpu" | |
| print(f"🚀 Running on device: {DEVICE}") | |
| # Directory for saving audio files | |
| OUTPUT_DIR = "generated_audio" | |
| if not os.path.exists(OUTPUT_DIR): | |
| os.makedirs(OUTPUT_DIR) | |
| # Global variables for tracking sections and paragraphs | |
| SECTION_INFO = [] # Will contain (section_number, paragraph_number, text) tuples | |
| GENERATION_COUNTS = {} # Track generation count for each paragraph | |
| # --- Global Model Initialization --- | |
| MODEL = None | |
def get_or_load_model():
    """Lazily initialize the global ChatterboxTTS model and return it.

    The model is loaded at most once; later calls return the cached
    instance. Any loading error is logged and re-raised.
    """
    global MODEL
    if MODEL is not None:
        return MODEL
    print("Model not loaded, initializing...")
    try:
        MODEL = ChatterboxTTS.from_pretrained(DEVICE)
        # Move the model if it exposes .to() and reports a different device.
        if hasattr(MODEL, 'to') and str(MODEL.device) != DEVICE:
            MODEL.to(DEVICE)
        print(f"Model loaded successfully. Internal device: {getattr(MODEL, 'device', 'N/A')}")
    except Exception as e:
        print(f"Error loading model: {e}")
        raise
    return MODEL
# Eagerly load the model at import time so the first request is fast;
# a failure here is logged but does not kill the app.
try:
    get_or_load_model()
except Exception as e:
    print(f"CRITICAL: Failed to load model on startup. Application may not function. Error: {e}")
def set_seed(seed: int):
    """Seed all RNGs (random, numpy, torch CPU and CUDA) for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # CUDA generators are seeded only when we are actually running on GPU.
    if DEVICE == "cuda":
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
def count_words(text):
    """Return the number of whitespace-separated words in *text*.

    ``str.split()`` with no arguments already collapses runs of whitespace
    and never yields empty strings, so the previous ``if word.strip()``
    filter was redundant.
    """
    return len(text.split())
def trim_audio(audio_data, start_time, end_time):
    """Cut *audio_data* down to the [start_time, end_time) window.

    Args:
        audio_data: ``(sample_rate, samples)`` tuple, or ``None``.
        start_time: Window start in seconds.
        end_time: Window end in seconds; ``None`` means "to the end".

    Returns:
        ``(sample_rate, trimmed_samples)``, or ``None`` when there is no
        input or the requested window is empty.
    """
    if audio_data is None:
        return None
    sr, samples = audio_data
    # Work on a numpy array regardless of what was handed in.
    if isinstance(samples, torch.Tensor):
        samples = samples.numpy()
    total = len(samples)
    # Clamp the window to the actual sample range.
    first = max(0, int(start_time * sr))
    last = total if end_time is None else min(total, int(end_time * sr))
    if first >= last:
        return None
    return (sr, samples[first:last])
| def detect_silence_boundaries(audio_data, silence_threshold=0.01, min_silence_duration=0.1): | |
| """ | |
| Detect silence at the beginning and end of audio to suggest trim points | |
| Args: | |
| audio_data: Tuple of (sample_rate, audio_array) | |
| silence_threshold: Amplitude threshold below which audio is considered silence | |
| min_silence_duration: Minimum duration of silence to consider (seconds) | |
| Returns: | |
| (suggested_start, suggested_end) in seconds | |
| """ | |
| if audio_data is None: | |
| return 0, 0 | |
| sr, audio_array = audio_data | |
| # Convert to numpy if needed | |
| if isinstance(audio_array, torch.Tensor): | |
| audio_array = audio_array.numpy() | |
| # Get absolute values | |
| audio_abs = np.abs(audio_array) | |
| # Find first non-silent sample | |
| min_silence_samples = int(min_silence_duration * sr) | |
| # Find start of audio content | |
| start_idx = 0 | |
| for i in range(len(audio_abs) - min_silence_samples): | |
| if np.max(audio_abs[i:i + min_silence_samples]) > silence_threshold: | |
| start_idx = max(0, i - int(0.05 * sr)) # Add 50ms buffer | |
| break | |
| # Find end of audio content | |
| end_idx = len(audio_abs) | |
| for i in range(len(audio_abs) - min_silence_samples - 1, min_silence_samples, -1): | |
| if np.max(audio_abs[i - min_silence_samples:i]) > silence_threshold: | |
| end_idx = min(len(audio_abs), i + int(0.05 * sr)) # Add 50ms buffer | |
| break | |
| return start_idx / sr, end_idx / sr | |
def merge_audio_files(audio_data_list, crossfade_duration=0.1, pause_duration=0.0):
    """
    Merge multiple audio clips into one, with optional crossfading or pauses.

    Clips whose sample rate differs from the first valid clip are skipped
    with a warning. When both a crossfade and a pause are requested, the
    crossfade takes precedence (pauses only apply when no crossfade occurs).

    Args:
        audio_data_list: List of (sample_rate, audio_array) tuples or None values
        crossfade_duration: Duration of crossfade between clips in seconds
        pause_duration: Duration of silence to add between clips in seconds
    Returns:
        Merged audio data as a (sample_rate, audio_array) tuple, or None
        when no valid clips were supplied.
    """
    # Filter out None values and collect valid audio data
    valid_audio = []
    sample_rate = None  # locked to the first valid clip's rate
    print(f"Processing {len(audio_data_list)} audio clips for merging...")
    for i, audio_data in enumerate(audio_data_list):
        if audio_data is not None:
            sr, audio_array = audio_data
            if sample_rate is None:
                sample_rate = sr
            elif sample_rate != sr:
                # Mismatched clips are dropped rather than resampled.
                print(f"Warning: Sample rate mismatch. Expected {sample_rate}, got {sr}")
                continue
            # Convert to numpy if needed
            if isinstance(audio_array, torch.Tensor):
                audio_array = audio_array.numpy()
            # Convert to float32 for processing to avoid casting errors
            if audio_array.dtype != np.float32:
                audio_array = audio_array.astype(np.float32)
            duration = len(audio_array) / sr
            print(f"Audio clip {i}: {duration:.2f} seconds, {len(audio_array)} samples")
            valid_audio.append(audio_array)
    if not valid_audio:
        print("No valid audio found")
        return None
    if len(valid_audio) == 1:
        print("Only one audio clip, returning as-is")
        return (sample_rate, valid_audio[0])
    print(f"Merging {len(valid_audio)} audio clips with {crossfade_duration}s crossfade and {pause_duration}s pause")
    # Calculate crossfade and pause samples
    crossfade_samples = int(crossfade_duration * sample_rate)
    pause_samples = int(pause_duration * sample_rate)
    print(f"Crossfade samples: {crossfade_samples}, Pause samples: {pause_samples}")
    # Start with the first audio clip (copy: we mutate its tail in-place below)
    merged_audio = valid_audio[0].copy()
    print(f"Starting with clip 0: {len(merged_audio)} samples")
    # Add each subsequent clip with crossfading and/or pauses
    for i in range(1, len(valid_audio)):
        current_clip = valid_audio[i]
        print(f"Merging clip {i}: {len(current_clip)} samples")
        # If we have both crossfade and pause, crossfade takes precedence.
        # Crossfading requires both sides to be at least one window long.
        if crossfade_samples > 0 and len(merged_audio) >= crossfade_samples and len(current_clip) >= crossfade_samples:
            print(f"Applying crossfade between clips")
            # Create crossfade
            # Fade out the end of merged_audio
            fade_out = np.linspace(1.0, 0.0, crossfade_samples, dtype=np.float32)
            merged_audio[-crossfade_samples:] *= fade_out
            # Fade in the beginning of current_clip
            fade_in = np.linspace(0.0, 1.0, crossfade_samples, dtype=np.float32)
            current_clip_faded = current_clip.copy()
            current_clip_faded[:crossfade_samples] *= fade_in
            # Overlap the crossfade region (sum of faded tails)
            merged_audio[-crossfade_samples:] += current_clip_faded[:crossfade_samples]
            # Append the rest of the current clip
            merged_audio = np.concatenate([merged_audio, current_clip[crossfade_samples:]])
        elif pause_samples > 0:
            print(f"Adding {pause_duration}s pause between clips")
            # Create silence for the pause
            silence = np.zeros(pause_samples, dtype=np.float32)
            # Add pause then the current clip
            merged_audio = np.concatenate([merged_audio, silence, current_clip])
        else:
            print(f"No crossfade or pause, concatenating directly")
            # No crossfade or pause, just concatenate
            merged_audio = np.concatenate([merged_audio, current_clip])
        print(f"Merged audio now: {len(merged_audio)} samples ({len(merged_audio)/sample_rate:.2f} seconds)")
    final_duration = len(merged_audio) / sample_rate
    print(f"Final merged audio: {len(merged_audio)} samples ({final_duration:.2f} seconds)")
    return (sample_rate, merged_audio)
def save_merged_audio(merged_audio_data, section_filter=None):
    """Write merged audio to OUTPUT_DIR and return the saved file path.

    Args:
        merged_audio_data: ``(sample_rate, samples)`` tuple, or ``None``.
        section_filter: Section number used in the filename, or ``None``
            when the file covers all sections.

    Returns:
        Path of the saved WAV file, or ``None`` when there is no audio.
    """
    global OUTPUT_DIR, SECTION_INFO
    if merged_audio_data is None:
        return None
    # Name the file after what was merged.
    if section_filter is None:
        filename = "merged_all_sections.wav"
    else:
        filename = f"merged_section_{section_filter}.wav"
    filepath = os.path.join(OUTPUT_DIR, filename)
    sr, samples = merged_audio_data
    tensor = torch.tensor(samples) if isinstance(samples, np.ndarray) else samples
    # torchaudio expects a (channels, samples) layout.
    if tensor.dim() == 1:
        tensor = tensor.unsqueeze(0)
    torchaudio.save(filepath, tensor, sr)
    print(f"Saved merged audio to {filepath}")
    return filepath
def merge_all_generated_audio(crossfade_duration, pause_duration, *audio_inputs):
    """
    Merge all generated audio files into one.

    Collects every non-None clip from *audio_inputs*; when SECTION_INFO is
    populated the clips are re-ordered by (section, paragraph) before
    merging. Returns a (merged_audio, status_message) pair; the audio is
    None when nothing could be merged.
    """
    global SECTION_INFO
    # Collect all non-None audio data regardless of section info
    audio_data_list = []
    merged_count = 0
    print(f"Checking {len(audio_inputs)} audio inputs for merging...")
    for idx, audio_data in enumerate(audio_inputs):
        if audio_data is not None:
            print(f"Found audio at index {idx}")
            audio_data_list.append(audio_data)
            merged_count += 1
        else:
            print(f"No audio at index {idx}")
    print(f"Total audio files found: {merged_count}")
    if not audio_data_list:
        return None, "No audio files to merge."
    # If we have section info, try to organize by section order
    if SECTION_INFO:
        # Create a mapping of section order
        organized_audio = []
        section_audio_map = {}  # section number -> [(paragraph number, audio), ...]
        # First, map audio to their corresponding sections.
        # audio_inputs is indexed in SECTION_INFO order, so position idx of
        # SECTION_INFO corresponds to position idx of audio_inputs.
        for idx, (section_num, para_num, text) in enumerate(SECTION_INFO):
            if idx < len(audio_inputs) and audio_inputs[idx] is not None:
                if section_num not in section_audio_map:
                    section_audio_map[section_num] = []
                section_audio_map[section_num].append((para_num, audio_inputs[idx]))
        # Now organize by section number, then paragraph number
        for section_num in sorted(section_audio_map.keys()):
            # Sort paragraphs within each section
            section_paragraphs = sorted(section_audio_map[section_num], key=lambda x: x[0])
            for para_num, audio_data in section_paragraphs:
                organized_audio.append(audio_data)
                print(f"Added Section {section_num}, Paragraph {para_num} to merge list")
        if organized_audio:
            audio_data_list = organized_audio
            print(f"Organized {len(audio_data_list)} audio files by section order")
    # Merge the audio
    merged_audio = merge_audio_files(audio_data_list, crossfade_duration=crossfade_duration, pause_duration=pause_duration)
    if merged_audio is None:
        return None, "Failed to merge audio files."
    # Calculate total duration for verification
    sr, merged_array = merged_audio
    duration_seconds = len(merged_array) / sr
    # Save the merged file
    filepath = save_merged_audio(merged_audio)
    return merged_audio, f"Merged {merged_count} audio files. Total duration: {duration_seconds:.1f} seconds. Saved to {filepath}"
def merge_by_section(section_number, crossfade_duration, *audio_inputs):
    """Merge the generated audio clips belonging to a single section.

    Args:
        section_number: Section whose paragraphs should be merged.
        crossfade_duration: Crossfade length in seconds, forwarded to
            ``merge_audio_files``.
        *audio_inputs: One audio value (or ``None``) per SECTION_INFO entry.

    Returns:
        ``(merged_audio, status_message)``; the audio is ``None`` when
        nothing could be merged.
    """
    global SECTION_INFO
    if not SECTION_INFO:
        return None, "No paragraphs processed."
    # Pick out the clips whose paragraph belongs to the requested section.
    selected = []
    for idx, (section_num, para_num, text) in enumerate(SECTION_INFO):
        if section_num != section_number:
            continue
        if idx < len(audio_inputs) and audio_inputs[idx] is not None:
            selected.append(audio_inputs[idx])
    if not selected:
        return None, f"No audio files found for section {section_number}."
    merged_audio = merge_audio_files(selected, crossfade_duration=crossfade_duration)
    if merged_audio is None:
        return None, f"Failed to merge audio files for section {section_number}."
    filepath = save_merged_audio(merged_audio, section_filter=section_number)
    return merged_audio, f"Merged {len(selected)} audio files from section {section_number}. Saved to {filepath}"
def save_trimmed_audio(audio_data, index, is_trimmed=False):
    """Save a paragraph's audio as ``{section}_{paragraph}_{generation}.wav``.

    When *is_trimmed* is True the filename carries a ``_trimmed`` suffix.
    Unlike ``save_audio_file``, this does NOT bump the generation counter.

    Returns:
        Path of the saved file, or ``None`` when there is no audio.
    """
    global GENERATION_COUNTS, OUTPUT_DIR, SECTION_INFO
    if audio_data is None:
        return None
    # Resolve section/paragraph numbers, with a fallback for unknown indices.
    if index < len(SECTION_INFO):
        section_num, para_num, _ = SECTION_INFO[index]
    else:
        section_num, para_num = 1, index + 1
    gen_count = GENERATION_COUNTS.get(index, 1)
    suffix = "_trimmed" if is_trimmed else ""
    filename = f"{section_num}_{para_num}_{gen_count}{suffix}.wav"
    filepath = os.path.join(OUTPUT_DIR, filename)
    sr, samples = audio_data
    tensor = torch.tensor(samples) if isinstance(samples, np.ndarray) else samples
    # torchaudio expects a (channels, samples) layout.
    if tensor.dim() == 1:
        tensor = tensor.unsqueeze(0)
    torchaudio.save(filepath, tensor, sr)
    print(f"Saved audio to {filepath}")
    return filepath
def split_text_into_sections_and_paragraphs(text):
    """
    Split input text into sections and paragraphs.

    A section marker is any line whose FIRST character is a digit; that
    single digit becomes the section number and the rest of the line is
    ignored (e.g. "1 Introduction", "2. Chapter Two", "3 - Main Content").
    Every following non-empty line is one paragraph of that section until
    the next marker. Text before any marker defaults to section 1.
    Empty lines are skipped, and anything in square brackets [] is treated
    as a comment and removed.

    Fix: removed the redundant function-local ``import re`` — ``re`` is
    already imported at module level.

    Returns:
        - List of (section_number, paragraph_number, text) tuples
          (also stored in the SECTION_INFO global)
        - Total word count
    """
    global SECTION_INFO
    # Reset section info for a fresh parse
    SECTION_INFO = []
    if not text.strip():
        return SECTION_INFO, 0
    # Split text into lines
    lines = text.strip().split('\n')
    current_section = None
    current_para_in_section = 1
    total_word_count = 0
    for line in lines:
        line = line.rstrip()
        # Skip empty lines
        if not line.strip():
            continue
        # Remove anything in square brackets (comments)
        cleaned_line = re.sub(r'\[.*?\]', '', line).strip()
        # Skip lines that become empty after removing comments
        if not cleaned_line:
            continue
        # A leading digit marks a section; only the first character counts.
        if cleaned_line[0].isdigit():
            section_num = int(cleaned_line[0])
            # Update section and reset paragraph counter
            current_section = section_num
            current_para_in_section = 1
            print(f"Found section marker: '{cleaned_line}' -> Section {section_num}")
        else:
            # This is a text line - treat as a complete paragraph.
            # If we haven't set a section yet, default to section 1.
            if current_section is None:
                current_section = 1
            paragraph_text = cleaned_line
            SECTION_INFO.append((current_section, current_para_in_section, paragraph_text))
            total_word_count += count_words(paragraph_text)
            current_para_in_section += 1
    return SECTION_INFO, total_word_count
def save_audio_file(audio_data, index):
    """Save a paragraph's audio as ``{Section}_{Paragraph}_{Generation}.wav``.

    Bumps the generation counter for *index* before writing, so each
    regeneration of the same paragraph gets a new filename.

    Returns:
        Path of the saved file, or ``None`` when there is no audio.
    """
    global GENERATION_COUNTS, OUTPUT_DIR, SECTION_INFO
    if audio_data is None:
        return None
    # Resolve section/paragraph numbers, with a fallback for unknown indices.
    if index < len(SECTION_INFO):
        section_num, para_num, _ = SECTION_INFO[index]
    else:
        section_num, para_num = 1, index + 1
    # Increment the generation count for this paragraph.
    GENERATION_COUNTS[index] = GENERATION_COUNTS.get(index, 0) + 1
    filename = f"{section_num}_{para_num}_{GENERATION_COUNTS[index]}.wav"
    filepath = os.path.join(OUTPUT_DIR, filename)
    sr, samples = audio_data
    tensor = torch.tensor(samples) if isinstance(samples, np.ndarray) else samples
    # torchaudio expects a (channels, samples) layout.
    if tensor.dim() == 1:
        tensor = tensor.unsqueeze(0)
    torchaudio.save(filepath, tensor, sr)
    print(f"Saved audio to {filepath}")
    return filepath
def generate_single_in_batch(text, audio_prompt_path, exaggeration, temperature, seed_num, cfgw, paragraph_index=None):
    """Generate audio for one paragraph during batch processing.

    The model is fetched inside the function rather than passed in, to
    avoid pickling issues with GPU-decorated workers.

    Args:
        text: Paragraph text; truncated to 1000 characters before synthesis.
        audio_prompt_path: Optional reference-audio path for voice styling.
        exaggeration: Expressiveness control.
        temperature: Sampling temperature.
        seed_num: Seed for reproducibility; 0 picks a random seed.
        cfgw: CFG/pace weight.
        paragraph_index: When provided, the result is also saved to disk.

    Returns:
        ``(sample_rate, samples)`` tuple, or ``None`` for empty text.
    """
    model = get_or_load_model()  # Get model inside the function
    if not text.strip():
        # BUGFIX: previously returned (None, 0, 0) — a truthy 3-tuple that made
        # callers checking `if audio:` treat empty paragraphs as successes.
        return None
    if seed_num != 0:
        actual_seed = int(seed_num)
    else:
        actual_seed = random.randint(0, 2**32 - 1)
    print(f"Using seed: {actual_seed} for paragraph {paragraph_index}, exaggeration: {exaggeration}")
    set_seed(actual_seed)
    generate_kwargs = {
        "exaggeration": exaggeration,
        "temperature": temperature,
        "cfg_weight": cfgw,
    }
    if audio_prompt_path:
        generate_kwargs["audio_prompt_path"] = audio_prompt_path
    wav = model.generate(
        text[:1000],  # Truncate text to max chars
        **generate_kwargs
    )
    audio_data = (model.sr, wav.squeeze(0).numpy())
    # Save the audio file if paragraph_index is provided
    if paragraph_index is not None:
        save_audio_file(audio_data, paragraph_index)
    return audio_data
def generate_single_internal(text, audio_prompt_path, exaggeration, temperature, seed_num, cfgw, paragraph_index=None):
    """Generate audio for one paragraph (single-generation path).

    Same contract as ``generate_single_in_batch`` but truncates the text
    to 500 characters instead of 1000. The model is fetched inside the
    function to avoid pickling issues.

    Returns:
        ``(sample_rate, samples)`` tuple, or ``None`` for empty text.
    """
    model = get_or_load_model()  # Get model inside the function
    if not text.strip():
        # BUGFIX: previously returned (None, 0, 0) — a truthy 3-tuple that made
        # callers checking `if audio:` treat empty paragraphs as successes.
        return None
    if seed_num != 0:
        actual_seed = int(seed_num)
    else:
        actual_seed = random.randint(0, 2**32 - 1)
    print(f"Using seed: {actual_seed} for paragraph {paragraph_index}, exaggeration: {exaggeration}")
    set_seed(actual_seed)
    generate_kwargs = {
        "exaggeration": exaggeration,
        "temperature": temperature,
        "cfg_weight": cfgw,
    }
    if audio_prompt_path:
        generate_kwargs["audio_prompt_path"] = audio_prompt_path
    wav = model.generate(
        text[:500],  # Truncate text to max chars
        **generate_kwargs
    )
    audio_data = (model.sr, wav.squeeze(0).numpy())
    # Save the audio file if paragraph_index is provided
    if paragraph_index is not None:
        save_audio_file(audio_data, paragraph_index)
    return audio_data
def generate_tts_audio(
    text_input: str,
    audio_prompt_path_input: str = None,
    exaggeration_input: float = 0.5,
    temperature_input: float = 0.8,
    seed_num_input: int = 0,
    cfgw_input: float = 0.6
) -> tuple[int, np.ndarray]:
    """
    Generate high-quality speech audio from text using ChatterboxTTS model with optional reference audio styling.

    Synthesizes natural-sounding speech from the input text. When a
    reference audio file is provided, the generated audio follows that
    speaker's voice characteristics, prosody, and tone; otherwise the
    default voice is used.

    Args:
        text_input (str): The text to synthesize into speech (maximum 300 characters).
        audio_prompt_path_input (str, optional): File path or URL to the reference audio defining the target voice style. Defaults to None.
        exaggeration_input (float, optional): Speech expressiveness (0.25-2.0, neutral=0.5; extreme values may be unstable). Defaults to 0.5.
        temperature_input (float, optional): Generation randomness (0.05-5.0, higher=more varied). Defaults to 0.8.
        seed_num_input (int, optional): Random seed for reproducibility (0 for random). Defaults to 0.
        cfgw_input (float, optional): CFG/Pace guidance weight (0.2-1.0). Defaults to 0.6.

    Returns:
        tuple[int, np.ndarray]: Sample rate and the generated waveform.

    Raises:
        RuntimeError: If the TTS model could not be loaded.
    """
    model = get_or_load_model()
    if model is None:
        raise RuntimeError("TTS model is not loaded.")
    # Only pin the seed when the caller asked for a specific one.
    if seed_num_input != 0:
        set_seed(int(seed_num_input))
    print(f"Generating audio for text: '{text_input[:50]}...'")
    # Only pass a reference prompt when one was supplied.
    kwargs = dict(
        exaggeration=exaggeration_input,
        temperature=temperature_input,
        cfg_weight=cfgw_input,
    )
    if audio_prompt_path_input:
        kwargs["audio_prompt_path"] = audio_prompt_path_input
    wav = model.generate(text_input[:300], **kwargs)
    print("Audio generation complete.")
    return (model.sr, wav.squeeze(0).numpy())
def generate_all_sequential(ref_wav_path, temperature, seed_num, cfgw, *exaggeration_values):
    """Generate audio for every processed paragraph, in order.

    Each paragraph uses its own exaggeration slider value (falling back to
    0.35 when no value was supplied for that index).

    Returns:
        ``[status_text] + 50 audio slots`` — always 51 values, so the
        length matches the Gradio output wiring on every path.
        (BUGFIX: the "no paragraphs" path previously returned
        ``[msg] + [None]*100 + [0]*100`` — 201 values — while the success
        path returned 51, a mismatch for a Gradio event handler.)
    """
    global SECTION_INFO
    MAX_PARAGRAPHS = 50  # must match the number of audio slots in the UI
    if not SECTION_INFO:
        return ["No paragraphs to process. Please process a script first."] + [None] * MAX_PARAGRAPHS
    status_messages = []
    # One slot per possible paragraph; untouched slots stay None.
    audio_results = [None] * MAX_PARAGRAPHS
    for idx, (section_num, para_num, text) in enumerate(SECTION_INFO):
        if not text.strip():
            continue
        # Per-paragraph exaggeration, defaulting when the slider list is short.
        exaggeration = exaggeration_values[idx] if idx < len(exaggeration_values) else 0.35
        print(f"Generating Section {section_num}, Paragraph {para_num} [{idx+1}/{len(SECTION_INFO)}] with exaggeration {exaggeration}...")
        status_messages.append(f"Processing Section {section_num}, Paragraph {para_num} (exaggeration: {exaggeration})...")
        # Call the spaces.GPU decorated function directly
        audio = generate_single_in_batch(text, ref_wav_path, exaggeration, temperature, seed_num, cfgw, paragraph_index=idx)
        if audio:
            status_messages.append(f"✓ Generated audio for Section {section_num}, Paragraph {para_num}")
            audio_results[idx] = audio
        else:
            status_messages.append(f"✗ Failed to generate audio for Section {section_num}, Paragraph {para_num}")
    final_status = "\n".join(status_messages) + f"\n\nCompleted processing {len(SECTION_INFO)} paragraphs!"
    return [final_status] + audio_results
def apply_trim_and_save(audio_data, start_time, end_time, paragraph_index):
    """Trim *audio_data* to the given window and persist the result.

    Returns:
        ``(trimmed_audio, status_message)``; the audio is ``None`` on
        missing input, an empty window, or any trimming error.
    """
    if audio_data is None:
        return None, "No audio to trim"
    try:
        trimmed = trim_audio(audio_data, start_time, end_time)
        if trimmed is None:
            return None, "Invalid trim parameters"
        # Persist the trimmed take alongside the original generation.
        saved_path = save_trimmed_audio(trimmed, paragraph_index, is_trimmed=True)
        return trimmed, f"Trimmed and saved to {saved_path}"
    except Exception as e:
        return None, f"Error trimming audio: {str(e)}"
def update_paragraph_ui(script_text):
    """
    Process the input script and update paragraph text fields and buttons.

    Parses *script_text* into (section, paragraph, text) entries, resets
    the per-paragraph generation counters, and emits one gr.update per UI
    slot (text box, row, button, audio, label, exaggeration slider) so
    that slots beyond the parsed paragraph count are hidden.

    Returns the updates concatenated in the exact order the Gradio outputs
    are wired: texts, rows, buttons, audios, labels, sliders, then the
    paragraph-count and section-summary markdown updates.
    """
    global SECTION_INFO, GENERATION_COUNTS
    # Reset generation counts
    GENERATION_COUNTS = {}
    section_info, word_count = split_text_into_sections_and_paragraphs(script_text)
    # Initialize generation count for each paragraph
    for i in range(len(section_info)):
        GENERATION_COUNTS[i] = 0
    # Create updates for all potential paragraph fields
    MAX_PARAGRAPHS = 50  # fixed number of pre-built UI slots
    text_updates = []
    row_updates = []
    button_updates = []
    audio_updates = []
    label_updates = []
    exaggeration_updates = []
    for i in range(MAX_PARAGRAPHS):
        if i < len(section_info):
            # Populate and reveal the slot for this paragraph.
            section_num, para_num, text = section_info[i]
            text_updates.append(gr.update(value=text, visible=True))
            row_updates.append(gr.update(visible=True))
            button_updates.append(gr.update(visible=True))
            audio_updates.append(gr.update(visible=True))
            label_updates.append(gr.update(value=f"Section {section_num}, Paragraph {para_num} ({count_words(text)} words)", visible=True))
            exaggeration_updates.append(gr.update(value=0.35, visible=True))
        else:
            # Clear and hide unused slots.
            text_updates.append(gr.update(value="", visible=False))
            row_updates.append(gr.update(visible=False))
            button_updates.append(gr.update(visible=False))
            audio_updates.append(gr.update(visible=False))
            label_updates.append(gr.update(value="", visible=False))
            exaggeration_updates.append(gr.update(value=0.35, visible=False))
    # Update paragraph count and word count
    count_update = gr.update(value=f"Total Paragraphs: {len(section_info)} | Total Words: {word_count}")
    # Generate a summary of sections and paragraph counts.
    # Paragraph numbers restart at 1 per section, so the max para_num seen
    # for a section is its paragraph count.
    section_counts = {}
    section_word_counts = {}
    for section_num, para_num, text in section_info:
        section_counts[section_num] = max(section_counts.get(section_num, 0), para_num)
        section_word_counts[section_num] = section_word_counts.get(section_num, 0) + count_words(text)
    section_summary = "Sections: " + ", ".join([f"Section {s}: {c} paragraphs ({section_word_counts[s]} words)" for s, c in sorted(section_counts.items())])
    summary_update = gr.update(value=section_summary)
    # Return all updates (order must match the Gradio outputs wiring)
    return (text_updates + row_updates + button_updates + audio_updates + label_updates +
            exaggeration_updates +
            [count_update, summary_update])
def apply_global_exaggeration(global_value):
    """Push one exaggeration value onto every visible paragraph slider.

    Sliders beyond the processed paragraph count (hidden slots) receive an
    empty ``gr.update()`` so they are left untouched.
    """
    global SECTION_INFO
    visible_count = len(SECTION_INFO)
    return [
        gr.update(value=global_value) if i < visible_count else gr.update()
        for i in range(50)  # MAX_PARAGRAPHS
    ]
| with gr.Blocks() as demo: | |
| # Storage for dynamic components | |
| paragraph_texts = [] | |
| paragraph_rows = [] | |
| paragraph_labels = [] | |
| generate_buttons = [] | |
| audio_outputs = [] | |
| exaggeration_sliders = [] | |
| gr.Markdown( | |
| """ | |
| # Chatterbox TTS Demo with Script Processing | |
| Generate high-quality speech from text with reference audio styling and batch processing capabilities. | |
| """ | |
| ) | |
| with gr.Tab("Single Generation"): | |
| with gr.Row(): | |
| with gr.Column(): | |
| text = gr.Textbox( | |
| value="Now let's make my mum's favourite. So three mars bars into the pan. Then we add the tuna and just stir for a bit, just let the chocolate and fish infuse. A sprinkle of olive oil and some tomato ketchup. Now smell that. Oh boy this is going to be incredible.", | |
| label="Text to synthesize (max chars 300)", | |
| max_lines=5 | |
| ) | |
| ref_wav = gr.Audio( | |
| sources=["upload", "microphone"], | |
| type="filepath", | |
| label="Reference Audio File (Optional)", | |
| value="https://storage.googleapis.com/chatterbox-demo-samples/prompts/female_shadowheart4.flac" | |
| ) | |
| exaggeration = gr.Slider( | |
| 0.25, 2, step=.05, label="Exaggeration (Neutral = 0.5, extreme values can be unstable)", value=.5 | |
| ) | |
| cfg_weight = gr.Slider( | |
| 0.2, 1, step=.05, label="CFG/Pace", value=0.6 | |
| ) | |
| with gr.Accordion("More options", open=False): | |
| seed_num = gr.Number(value=131789919, label="Random seed (0 for random)") | |
| temp = gr.Slider(0.05, 5, step=.05, label="Temperature", value=.8) | |
| run_btn = gr.Button("Generate", variant="primary") | |
| with gr.Column(): | |
| audio_output = gr.Audio(label="Output Audio") | |
| run_btn.click( | |
| fn=generate_tts_audio, | |
| inputs=[ | |
| text, | |
| ref_wav, | |
| exaggeration, | |
| temp, | |
| seed_num, | |
| cfg_weight, | |
| ], | |
| outputs=[audio_output], | |
| ) | |
| with gr.Tab("Script Processing"): | |
| with gr.Row(): | |
| with gr.Column(): | |
| script_input = gr.Textbox( | |
| label="Script Input (section headers start with a digit, e.g., '1 Introduction', '2. Chapter Two')", | |
| placeholder="Enter your script here. Use lines starting with digits to mark sections (e.g., '1 Introduction', '2. Chapter Two', '3 - Main Content'). The first character determines the section number. Separate paragraphs with blank lines.", | |
| lines=10, | |
| value="1 Introduction\nThis is the first paragraph of section 1. It contains some interesting content.\n\nThis is the second paragraph of section 1. It's separated by a blank line.\n\n2. Chapter Two\nThis is the first paragraph of section 2. Notice the '2.' above marks the section.\n\nThis is the second paragraph of section 2.\n\n3 - Final Section\nThis is the first paragraph of section 3." | |
| ) | |
| process_btn = gr.Button("Process Script", variant="primary") | |
| paragraph_count = gr.Markdown("Total Paragraphs: 0 | Total Words: 0") | |
| section_summary = gr.Markdown("Sections: None") | |
| output_dir_info = gr.Markdown(f"Audio files will be saved to: {os.path.abspath(OUTPUT_DIR)}") | |
| # Generation parameters# Generation parameters | |
with gr.Row():
    with gr.Column():
        gr.Markdown("## Global Generation Settings")
        # Reference voice shared by every paragraph in this tab; default is a hosted demo prompt.
        script_ref_wav = gr.Audio(sources=["upload", "microphone"], type="filepath", label="Reference Audio File", value="https://storage.googleapis.com/chatterbox-demo-samples/prompts/female_shadowheart4.flac")
        script_cfg_weight = gr.Slider(0.0, 1, step=.05, label="CFG/Pace", value=0.6)
        with gr.Accordion("More options", open=False):
            # Fixed non-zero default seed so regeneration is reproducible by default.
            script_seed_num = gr.Number(value=131789919, label="Random seed (0 for random)")
            script_temp = gr.Slider(0.05, 5, step=.05, label="Temperature", value=.8)
        # NEW: Global Exaggeration Control
        gr.Markdown("### Global Exaggeration Control")
        # One slider that can be broadcast to every per-paragraph exaggeration slider.
        global_exaggeration = gr.Slider(
            0.25, 2,
            step=0.05,
            label="Set All Exaggeration Values",
            value=0.35
        )
        apply_global_exaggeration_btn = gr.Button("Apply to All Paragraphs", variant="secondary")
# Batch generation controls for the whole script.
generate_all_btn = gr.Button("Generate All Paragraphs", variant="primary", size="lg")
generation_status = gr.Textbox(label="Generation Status", lines=5, interactive=False)
# Paragraphs section
gr.Markdown("## Paragraphs")
# Gradio cannot add components after launch, so pre-create a fixed pool of
# hidden rows; update_paragraph_ui later toggles visibility and fills text.
MAX_PARAGRAPHS = 50

def make_generate_handler(idx):
    """Return a click handler bound to paragraph index *idx*.

    Binding the index through a factory argument (early binding) avoids the
    classic late-binding closure bug when wiring handlers in a loop.
    Defined once, before the loop, instead of being redefined per iteration.
    """
    def generate_handler(text, ref_wav, exag, temp, seed, cfg):
        # Call the spaces.GPU function directly without passing model
        return generate_single_internal(text, ref_wav, exag, temp, seed, cfg, idx)
    return generate_handler

for i in range(MAX_PARAGRAPHS):
    with gr.Row(visible=False) as row:
        paragraph_rows.append(row)
        with gr.Column(scale=4):
            paragraph_label = gr.Markdown("Section 1, Paragraph 1", visible=False)
            paragraph_labels.append(paragraph_label)
            text_input = gr.Textbox(
                lines=3,
                max_lines=5,
                visible=False
            )
            paragraph_texts.append(text_input)
        with gr.Column(scale=2):
            exaggeration_slider = gr.Slider(
                0.25, 2,
                step=0.05,
                label="Exaggeration",
                value=0.35,
                visible=False
            )
            exaggeration_sliders.append(exaggeration_slider)
            # Plain strings: these labels have no interpolation (was f"..." with no placeholders).
            generate_btn = gr.Button("Generate", visible=False)
            generate_buttons.append(generate_btn)
        with gr.Column(scale=3):
            audio_output = gr.Audio(label="Generated Audio", type="numpy", autoplay=False, visible=False, show_download_button=True, interactive=True)
            audio_outputs.append(audio_output)
    # Wire this row's button to its own per-index handler; global settings
    # (reference audio, temperature, seed, CFG) come from the shared controls.
    generate_btn.click(
        fn=make_generate_handler(i),
        inputs=[
            text_input,
            script_ref_wav,
            exaggeration_slider,
            script_temp,
            script_seed_num,
            script_cfg_weight,
        ],
        outputs=[audio_output],
    )
# Audio Merging section
gr.Markdown("## Audio Merging")
with gr.Row():
    with gr.Column():
        gr.Markdown("### Merge Options")
        # Crossfade and pause are mutually exclusive: pauses are only inserted
        # when crossfade is 0 (see the pause slider's label).
        crossfade_duration = gr.Slider(
            0.0, 1.0,
            step=0.05,
            value=0.0,
            label="Crossfade Duration (seconds)"
        )
        pause_duration = gr.Slider(0, 2, step=0.1, label="Pause Between Segments (seconds). Crossfade must be 0", value=0.3)
        with gr.Row():
            merge_all_btn = gr.Button("Merge All Audio", variant="primary")
            # precision=0 makes Gradio return a whole number for the section picker.
            section_number_input = gr.Number(
                value=1,
                label="Section Number",
                precision=0
            )
            merge_section_btn = gr.Button("Merge Section", variant="secondary")
        merge_status = gr.Textbox(
            label="Merge Status",
            lines=2,
            interactive=False
        )
    with gr.Column():
        merged_audio_output = gr.Audio(
            label="Merged Audio",
            type="numpy",
            show_download_button=True
        )
# Setup process button to update paragraphs
# update_paragraph_ui returns updates for every pre-created component, so the
# outputs list must enumerate all pools in this exact order.
process_btn.click(
    fn=update_paragraph_ui,
    inputs=[script_input],
    outputs=(paragraph_texts + paragraph_rows + generate_buttons + audio_outputs +
             paragraph_labels + exaggeration_sliders +
             [paragraph_count, section_summary]),
)
# Setup generate all button - now with all audio outputs
# Per-paragraph exaggeration sliders are appended after the shared settings;
# generate_all_sequential receives them as trailing positional args.
generate_all_btn.click(
    fn=generate_all_sequential,
    inputs=[
        script_ref_wav,
        script_temp,
        script_seed_num,
        script_cfg_weight,
    ] + exaggeration_sliders,
    outputs=[generation_status] + audio_outputs,
)
# Setup merge all button
merge_all_btn.click(
    fn=merge_all_generated_audio,
    inputs=[crossfade_duration, pause_duration] + audio_outputs,
    outputs=[merged_audio_output, merge_status]
)
# Setup merge section button
def merge_section_handler(section_value, fade_secs, *clips):
    """Coerce the picked section number to int and merge that section's clips."""
    section_idx = int(section_value)
    return merge_by_section(section_idx, fade_secs, *clips)

merge_section_btn.click(
    fn=merge_section_handler,
    inputs=[section_number_input, crossfade_duration] + audio_outputs,
    outputs=[merged_audio_output, merge_status]
)
# Setup global exaggeration apply button
# Broadcasts the single global slider value to every per-paragraph slider.
apply_global_exaggeration_btn.click(
    fn=apply_global_exaggeration,
    inputs=[global_exaggeration],
    outputs=exaggeration_sliders
)
# mcp_server=True also exposes the app's functions over the MCP protocol.
demo.launch(mcp_server=True)