| |
| |
| """ |
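ChatTTS Gradio Application - Adapted to use ChatterboxTTS.

Adapted for the Hugging Face Spaces runtime environment.
Based on the Chatterbox-TTS implementation: https://huggingface.co/spaces/ResembleAI/Chatterbox

Note: batch audio generation in this file is performed through IndexTTS
(indextts.infer.IndexTTS, checkpoints downloaded from IndexTeam/IndexTTS-1.5).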
| """ |
|
|
| import os |
| import random |
| import argparse |
| import torch |
| import numpy as np |
| import gradio as gr |
| from scipy.io.wavfile import write |
| import logging |
| from pathlib import Path |
| import sys |
| import re |
| import subprocess |
| from textblob import TextBlob |
| import pandas as pd |
| import base64 |
| import threading |
| import io |
| import zipfile |
| import shutil |
| import math |
| from queue import Queue |
| from pydub import AudioSegment |
| import tempfile |
| import json |
| import os |
| import sys |
| import threading |
| import time |
| import subprocess |
|
|
| from huggingface_hub import snapshot_download |
| import warnings |
| warnings.filterwarnings("ignore", category=FutureWarning) |
| warnings.filterwarnings("ignore", category=UserWarning) |
|
|
| import argparse |
| |
# Command-line options are parsed at import time because the model directory
# is needed by the module-level IndexTTS construction later in this file.
parser = argparse.ArgumentParser(description="IndexTTS WebUI")
parser.add_argument("--verbose", action="store_true", default=False, help="Enable verbose mode")
parser.add_argument("--port", type=int, default=7860, help="Port to run the web UI on")
parser.add_argument("--host", type=str, default="127.0.0.1", help="Host to run the web UI on")
parser.add_argument("--model_dir", type=str, default="checkpoints", help="Model checkpoints directory")
cmd_args = parser.parse_args()
model_dir = "checkpoints"

# Make the script directory and its bundled "indextts" folder importable.
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(current_dir)
sys.path.append(os.path.join(current_dir, "indextts"))

MODE = 'local'
# Download into the directory selected with --model_dir. The checks below look
# there, so downloading to a fixed "checkpoints" folder would make every
# non-default --model_dir fail them. With the default value the target is
# unchanged.
snapshot_download("IndexTeam/IndexTTS-1.5", local_dir=cmd_args.model_dir)

if not os.path.exists(cmd_args.model_dir):
    print(f"Model directory {cmd_args.model_dir} does not exist. Please download the model first.")
    sys.exit(1)

# Abort early when any checkpoint file the TTS engine is constructed from is missing.
for file in [
    "bigvgan_generator.pth",
    "bpe.model",
    "gpt.pth",
    "config.yaml",
]:
    file_path = os.path.join(cmd_args.model_dir, file)
    if not os.path.exists(file_path):
        print(f"Required file {file_path} does not exist. Please download it.")
        sys.exit(1)
def analyze_sentiment(text):
    """Map the TextBlob sentiment of *text* to an IndexTTS parameter dict.

    Polarity and subjectivity are read from ``TextBlob(text).sentiment`` and
    turned into ``temperature``, ``top_p``, ``top_k``, ``repetition_penalty``,
    ``length_penalty`` and ``num_beams``, each clamped to a fixed range.
    If anything raises, the failure is printed and a neutral default set is
    returned instead.
    """
    def _bounded(low, value, high):
        # Clamp *value* into [low, high].
        return max(low, min(high, value))

    try:
        sentiment = TextBlob(text).sentiment
        polarity = sentiment.polarity
        subjectivity = sentiment.subjectivity
        squashed = math.tanh(polarity)

        return {
            "temperature": _bounded(0.3, 1.0 + squashed * 0.9, 2),
            "top_p": _bounded(0.7, 0.7 + subjectivity * 0.25, 0.95),
            "top_k": _bounded(20, int(20 + subjectivity * 30), 50),
            "repetition_penalty": _bounded(
                5.0, 10.0 + (1 - subjectivity) * 5.0 - polarity * 2.0, 15.0
            ),
            "length_penalty": _bounded(-1.0, squashed * 1.0, 1.0),
            "num_beams": _bounded(2, int(2 + (polarity + 1) * 1.5), 5),
        }
    except Exception as e:
        print(f"Sentiment analysis failed: {str(e)}")
        fallback = dict(
            temperature=1.0,
            top_p=0.8,
            top_k=30,
            repetition_penalty=10.0,
            length_penalty=0.0,
            num_beams=3,
        )
        return fallback
|
|
|
|
| |
| import gradio as gr |
| import pandas as pd |
|
|
| from indextts.infer import IndexTTS |
| from tools.i18n.i18n import I18nAuto |
|
|
| |
# I18nAuto instance configured for English; it is not referenced elsewhere in
# the visible code.
i18n = I18nAuto(language="en")
# Module-level IndexTTS engine, built at import time from the checkpoint
# directory given on the command line. generate_all_lines_audio() calls
# tts.infer() on this object.
tts = IndexTTS(model_dir=cmd_args.model_dir, cfg_path=os.path.join(cmd_args.model_dir, "config.yaml"))
|
|
| |
| |
| |
|
|
| |
|
|
def ensure_wav(file_path):
    """Return a path to a WAV version of *file_path*.

    A path that already ends in ``.wav`` (case-insensitive) is returned
    unchanged. Any other file is converted with ``ffmpeg`` into a sibling file
    that has the same name and a ``.wav`` extension, overwriting an existing
    file of that name (``-y``).

    Raises:
        subprocess.CalledProcessError: if ffmpeg exits with a non-zero status.
        FileNotFoundError: if the ffmpeg executable cannot be found.
    """
    if file_path.lower().endswith(".wav"):
        return file_path

    # os.path.splitext only looks at the last path component, so a dot in a
    # directory name ("/tmp/a.b/clip") or a dot-file (".clip") is not mistaken
    # for an extension separator.
    stem, _extension = os.path.splitext(file_path)
    wav_path = stem + ".wav"
    subprocess.run(["ffmpeg", "-y", "-i", file_path, wav_path], check=True)
    return wav_path
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
| |
# Root logging at INFO, plus the module logger used by the torch.load patch
# and the model loader below.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# The most recently uploaded .rpy file as handed over by the file input
# (assigned in process_file, re-parsed when a generation starts).
last_uploaded_file = None

# Set by stop_all_generation(); checked before and between lines by the
# generation loop.
cancel_generation = threading.Event()

# Set while one role is being generated; further requests are queued instead.
is_generating = threading.Event()
generation_queue = []  # pending role_config dicts, consumed from the front by process_queue()
queued_roles = set()  # indices of roles that are currently running or waiting

# One slot per role panel; main() stores each panel's click handler here.
make_bulk_functions = [None] * 40

# Temporary directory created at import time; reference audio clips are copied into it.
audio_cache_dir = tempfile.mkdtemp(prefix='audio_cache_')
|
|
| |
def clear_audio_cache_dir():
    """Reset the audio cache directory to an empty folder.

    An existing directory is removed together with its contents and then
    created again. Any failure is printed and swallowed (best effort).
    """
    global audio_cache_dir
    try:
        cache_exists = os.path.exists(audio_cache_dir)
        if cache_exists:
            shutil.rmtree(audio_cache_dir)
            print(f"\033[92m✅ Cleared cache folder: {audio_cache_dir}\033[0m")
        os.makedirs(audio_cache_dir)
        print(f"\033[92m✅ Created new cache folder: {audio_cache_dir}\033[0m")
    except Exception as error:
        print(f"\033[91m❌ Failed to clear cache folder: {str(error)}\033[0m")
|
|
| |
# Keep a handle on the real loader so the wrapper below can delegate to it.
original_torch_load = torch.load


def patched_torch_load(f, map_location=None, **kwargs):
    """Call the original ``torch.load`` with a CPU default for ``map_location``.

    When the caller passes no ``map_location`` the checkpoint is mapped to
    ``'cpu'``; an explicit value is forwarded untouched. The effective value is
    logged on every call.
    """
    effective_location = 'cpu' if map_location is None else map_location
    logger.info(f"🔧 Loading with map_location={effective_location}")
    return original_torch_load(f, map_location=effective_location, **kwargs)


# Install the wrapper on the imported module object and on the sys.modules entry.
setattr(torch, "load", patched_torch_load)
if 'torch' in sys.modules:
    setattr(sys.modules['torch'], "load", patched_torch_load)

logger.info("✅ Successfully applied torch.load device mapping patch")
|
|
| |
# Choose the compute device once, at import time.
if torch.cuda.is_available():
    DEVICE = "cuda"
    logger.info("🚀 Running with CUDA GPU")
else:
    DEVICE = "cpu"
    logger.info("🚀 Running with CPU")

print(f"🚀 Running on device: {DEVICE}")

# Slot for the Chatterbox model. It starts as True rather than None, so
# get_or_load_model() sees "already loaded" and never executes its Chatterbox
# loading branch.
MODEL = True
|
|
def get_or_load_model():
    """Return ``(model, status_message)``, loading ChatterboxTTS on first use.

    When the global ``MODEL`` is not ``None`` it is returned as-is. Otherwise
    ChatterboxTTS is imported (two import paths are tried), created on CPU and,
    if ``DEVICE`` is ``"cuda"``, its ``t3`` / ``s3gen`` / ``ve`` parts are moved
    to the GPU; a failure while moving switches ``DEVICE`` back to ``"cpu"``.
    On a loading error ``(None, message)`` is returned.
    """
    global MODEL, DEVICE

    if MODEL is not None:
        return MODEL, "Model already loaded"

    print("Model not loaded, initializing...")
    try:
        try:
            from chatterbox.src.chatterbox.tts import ChatterboxTTS
            logger.info("✅ Using official chatterbox.src import path")
        except ImportError:
            from chatterbox import ChatterboxTTS
            logger.info("✅ Using chatterbox direct import path")

        # Always instantiate on CPU first; relocation happens component by component.
        MODEL = ChatterboxTTS.from_pretrained("cpu")

        if DEVICE == "cuda":
            logger.info(f"Moving model components to {DEVICE}...")
            try:
                for component in ("t3", "s3gen", "ve"):
                    if hasattr(MODEL, component):
                        setattr(MODEL, component, getattr(MODEL, component).to(DEVICE))

                MODEL.device = DEVICE
                logger.info(f"✅ All model components moved to {DEVICE}")
            except Exception as move_error:
                logger.warning(f"⚠️ Unable to move some components to {DEVICE}: {move_error}")
                logger.info("🔄 Falling back to CPU mode for stability")
                DEVICE = "cpu"
                MODEL.device = "cpu"

        logger.info(f"✅ Model loaded on {DEVICE}")
        return MODEL, "Model loaded successfully"
    except Exception as load_error:
        logger.error(f"❌ Failed to load model: {load_error}")
        return None, f"Model loading failed: {str(load_error)}"
|
|
def parse_role_mappings(mapping_text):
    """Turn the role-mapping text into a dict keyed by role code.

    Each line of the form
    ``Role Code: <code> → Display Name: <name> [voice:<voice>] [seed_id:<id>]``
    yields ``{code: {'display_name': name, 'voice_file': '<voice>.wav' or None,
    'seed_id': id or None}}``. Lines that do not match are ignored; empty or
    missing text gives an empty dict.
    """
    if not mapping_text:
        return {}

    line_format = re.compile(r'Role Code:\s*(\w+)\s*→\s*Display Name:\s*([^→]+?)(?:\s+voice:\s*([\w_-]+))?(?:\s+seed_id:\s*(\w+))?\s*$')

    mappings = {}
    for raw_line in mapping_text.strip().splitlines():
        found = line_format.match(raw_line.strip())
        if found is None:
            continue

        code, shown_name, voice, seed = found.groups()
        voice = voice.strip() if voice else None
        mappings[code.strip()] = {
            'display_name': shown_name.strip(),
            'voice_file': voice + ".wav" if voice else None,
            'seed_id': seed.strip() if seed else None,
        }

    return mappings
|
|
def filter_lines(lines, filter_text):
    """Return the dialogue lines that satisfy every condition in *filter_text*.

    *filter_text* is a ``;``-separated list of conditions:

    * ``min_len=N``  - keep lines with at least N ASCII letters/digits
      (all other characters are not counted);
    * ``keyword=K``  - keep lines containing K, compared case-insensitively.
      Everything after the first ``=`` is the keyword, so it may itself
      contain ``=``.

    Invalid or unknown conditions are reported once with a printed warning and
    otherwise ignored. An empty *filter_text* returns *lines* unchanged;
    otherwise a new list is returned.
    """
    if not filter_text:
        return lines

    # Parse the conditions once up front rather than once per dialogue line.
    min_lengths = []
    keywords = []
    for condition in filter_text.strip().split(';'):
        condition = condition.strip()
        if not condition:
            continue
        _name, _sep, value = condition.partition('=')
        if condition.startswith('min_len='):
            try:
                min_lengths.append(int(value))
            except ValueError:
                print(f"Warning: invalid min_len condition: {condition}")
        elif condition.startswith('keyword='):
            keywords.append(value.lower())
        else:
            print(f"Warning: unknown filter condition: {condition}")

    def keep(line):
        if min_lengths:
            counted = len(re.sub(r'[^A-Za-z0-9]', '', line))
            if any(counted < minimum for minimum in min_lengths):
                return False
        lowered = line.lower()
        return all(keyword in lowered for keyword in keywords)

    return [line for line in lines if keep(line)]
|
|
def parse_roles_from_rpy_file(file_bytes, filter_text=""):
    """Extract cleaned dialogue lines per role from the bytes of a .rpy file.

    Only comment lines of the form ``# [role] "text"`` are considered; a line
    without a role is filed under ``'noname'``. Text containing parentheses is
    dropped. ``[...]`` becomes a comma; ``{...}`` tags, backslashes and
    ``*...*`` spans are removed; whitespace is collapsed. Results shorter than
    two characters are discarded. Each role's list is then passed through
    ``filter_lines``.

    Returns:
        ``(role_dict, lines)`` - the per-role line lists and all decoded file
        lines.

    Raises:
        ValueError: wrapping any error that occurs while decoding or parsing.
    """
    # Applied in this order to the quoted text of every dialogue comment.
    cleanup_steps = (
        (r'\[.*?\]', ','),
        (r'\{.*?\}', ''),
        (r'\\', ''),
        (r'\*.*?\*', ''),
        (r'\s+', ' '),
    )
    try:
        lines = file_bytes.decode("utf-8").splitlines()
        dialogue_comment = re.compile(r'^#\s*(?:(\w+)\s*)?\"(.*?)\"')

        collected = {}
        for raw_line in lines:
            hit = dialogue_comment.match(raw_line.strip())
            if not hit:
                continue

            speaker = hit.group(1) or 'noname'
            spoken = hit.group(2)
            if '(' in spoken or ')' in spoken:
                continue

            for pattern, replacement in cleanup_steps:
                spoken = re.sub(pattern, replacement, spoken)
            spoken = spoken.strip()
            if len(spoken) < 2:
                continue

            collected.setdefault(speaker, []).append(spoken)

        for speaker in collected:
            collected[speaker] = filter_lines(collected[speaker], filter_text)

        return collected, lines
    except Exception as e:
        raise ValueError(f"Failed to parse .rpy file: {str(e)}")
|
|
def save_wav(audio_data, sample_rate, output_path):
    """Write *audio_data* to *output_path* as a 16-bit WAV file.

    Arrays that are already ``int16`` are written unchanged. Any other dtype
    is scaled by 32767, clipped to the int16 range and cast to ``int16``.

    Raises:
        RuntimeError: wrapping any error raised during conversion or writing.
    """
    try:
        if audio_data.dtype == np.int16:
            pcm = audio_data
        else:
            scaled = audio_data * 32767
            pcm = np.clip(scaled, -32768, 32767).astype(np.int16)
        write(output_path, sample_rate, pcm)
    except Exception as e:
        raise RuntimeError(f"Failed to save audio file: {str(e)}")
|
|
| |
def pre_analyze_lines(texts):
    """Return one generation-parameter dict per entry of *texts*.

    A blank entry gets a fresh copy of the neutral defaults. Every other entry
    is analysed together with up to two neighbouring entries on each side: the
    window is joined with spaces and passed to ``analyze_sentiment``.
    """
    context_window = 2
    total = len(texts)
    params_list = []

    for position, text in enumerate(texts):
        if text.strip():
            window_start = max(0, position - context_window)
            window_end = min(total, position + context_window + 1)
            surrounding = " ".join(texts[window_start:window_end])
            params_list.append(analyze_sentiment(surrounding))
        else:
            params_list.append({
                "temperature": 1.0,
                "top_p": 0.8,
                "top_k": 30,
                "repetition_penalty": 10.0,
                "length_penalty": 0.0,
                "num_beams": 3,
            })

    return params_list
|
|
def get_pt_file(seed_id, csv_path=os.path.join(os.path.dirname(__file__), "evaluation_results.csv")):
    """Look up a speaker embedding by seed id in the evaluation CSV.

    A non-empty *seed_id* without the ``seed_`` prefix gets it prepended. The
    matching row's ``emb_data`` column is base64-decoded and loaded with
    ``torch.load``.

    Returns:
        ``(embedding, message)`` on success, or ``(None, message)`` when the id
        is not present or anything fails.
    """
    try:
        if seed_id and not seed_id.startswith("seed_"):
            seed_id = f"seed_{seed_id}"

        table = pd.read_csv(csv_path, encoding="utf-8")
        matching_rows = table[table["seed_id"] == seed_id]
        if matching_rows.empty:
            return None, f"seed_id not found: {seed_id}"

        encoded = matching_rows.iloc[0]["emb_data"]
        # NOTE(review): torch.load unpickles its input; this is only safe while
        # the CSV contents come from a trusted source.
        spk_emb = torch.load(io.BytesIO(base64.b64decode(encoded)))

        return spk_emb, f"Successfully loaded spk_emb data for seed_id: {seed_id}"
    except Exception as e:
        return None, f"Failed to load spk_emb data: {str(e)}"
|
|
def zip_outputs_folder():
    """Pack the ``Outputs`` folder into ``tmp/outputs.zip``.

    Returns:
        ``(status_message, download_button)``. The button is visible and points
        at the archive on success, and hidden when the folder is missing or an
        error occurs. Archive members are stored under an ``Outputs/`` prefix.
    """
    try:
        source_dir = "Outputs"
        archive_path = os.path.join("tmp", "outputs.zip")
        os.makedirs("tmp", exist_ok=True)

        if not (os.path.exists(source_dir) and os.path.isdir(source_dir)):
            return "Outputs folder does not exist or is empty", gr.DownloadButton(visible=False)

        with zipfile.ZipFile(archive_path, 'w', zipfile.ZIP_DEFLATED) as archive:
            for folder, _, names in os.walk(source_dir):
                for name in names:
                    on_disk = os.path.join(folder, name)
                    relative = os.path.relpath(on_disk, source_dir)
                    archive.write(on_disk, os.path.join("Outputs", relative))

        ready_button = gr.DownloadButton(
            value=archive_path,
            label="Download Outputs.zip",
            visible=True
        )
        return "outputs.zip created successfully", ready_button
    except Exception as e:
        return f"Zip failed: {str(e)}", gr.DownloadButton(visible=False)
|
|
def compress_wav_to_mp3():
    """Replace every ``.wav`` under ``Outputs`` with an 80 kbps ``.mp3``.

    The folder is walked recursively. Each WAV is exported next to itself as
    MP3 and the WAV is then deleted. A file that fails to convert is reported
    and left in place.

    Returns:
        A human-readable status string.
    """
    try:
        outputs_dir = "Outputs"
        if not (os.path.exists(outputs_dir) and os.path.isdir(outputs_dir)):
            return "Outputs folder does not exist or is empty"

        converted_files = 0
        for folder, _, names in os.walk(outputs_dir):
            for name in names:
                if not name.lower().endswith(".wav"):
                    continue

                wav_path = os.path.join(folder, name)
                mp3_path = os.path.join(folder, name[:-4] + ".mp3")
                try:
                    AudioSegment.from_wav(wav_path).export(mp3_path, format="mp3", bitrate="80k")
                    # Delete the source only after the export succeeded.
                    os.remove(wav_path)
                    print(f"\033[92m✅ Converted and replaced: {wav_path} -> {mp3_path}\033[0m")
                    converted_files += 1
                except Exception as e:
                    print(f"\033[91m❌ Failed to convert {wav_path}: {str(e)}\033[0m")

        if converted_files:
            return f"Successfully converted and replaced {converted_files} .wav files to .mp3"
        return "No .wav files found to convert"
    except Exception as e:
        print(f"\033[91m❌ Failed to compress WAV to MP3: {str(e)}\033[0m")
        return f"Failed to compress WAV to MP3: {str(e)}"
|
|
|
|
def generate_all_lines_audio(role, texts, lines, start_index, role_map, spk_emb_seed_id=None, ref_wav=None, role_index=0, button=None, status_output=None, use_sentiment=True):
    """Generate one audio file per dialogue line of a role with IndexTTS.

    Args:
        role: role code as it appears in the .rpy dialogue comments.
        texts: cleaned dialogue lines of the role.
        lines: all lines of the uploaded .rpy file; used to find the
            ``translate <lang> <identifier>:`` header that names the output file.
        start_index: 1-based index into *texts* at which generation starts.
        role_map: mapping produced by ``parse_role_mappings``.
        spk_emb_seed_id: optional seed id. It only takes part in the
            "nothing provided" check and is not passed to the TTS call.
        ref_wav: path of the reference audio handed to ``tts.infer``.
        role_index: index of the role panel, used for queue bookkeeping.
        button: returned unchanged as the third result when generation is cancelled.
        status_output: unused.
        use_sentiment: when true, per-line sentiment parameters override the
            default sampling settings.

    Returns:
        ``(None, status_message, button_update)``.

    Side effects:
        Writes ``Outputs/<identifier>.wav`` files. In every case (success,
        cancel, error) it removes *role_index* from ``queued_roles``, clears
        ``is_generating`` and calls ``process_queue()``.

    Each text is matched against the first dialogue comment in *lines* with the
    same cleaned content, so repeated identical lines share one identifier and
    therefore one output file.
    """
    # Bound before the try block so the except handler can always build its message.
    display_name = role
    try:
        display_name = role_map.get(role, {'display_name': role})['display_name']

        if cancel_generation.is_set():
            return None, f"{display_name}: Generation cancelled", button

        if not spk_emb_seed_id and not ref_wav:
            print(f"\033[93m⚠️ Role {display_name} did not provide spk_emb_seed_id or reference audio, skipping generation\033[0m")
            return None, f"{display_name}: No spk_emb_seed_id or reference audio provided, skipping generation", gr.Button(interactive=True)

        current_model = get_or_load_model()
        if current_model[0] is None:
            raise RuntimeError(f"ChatterboxTTS model not loaded: {current_model[1]}")

        params_list = pre_analyze_lines(texts[start_index-1:])
        saved_paths = []
        os.makedirs("Outputs", exist_ok=True)

        role_pattern = re.compile(r'^#\s*(?:' + re.escape(role) + r'\s*)?\"(.*?)\"')
        translate_pattern = re.compile(r'^translate\s+\w+\s+([\w_-]+)\s*:')

        print(f"\033[94m🚀 Start generating audio for role {display_name}, from line {start_index}, total {len(texts[start_index-1:])} lines\033[0m")

        def clean_dialogue(raw_content):
            # Same normalisation the texts went through in
            # parse_roles_from_rpy_file, including the whitespace collapse.
            # Without the collapse, a line whose tag removal leaves a double
            # space never equals its cleaned text and is skipped.
            raw_content = re.sub(r'\[.*?\]', ',', raw_content)
            raw_content = re.sub(r'\{.*?\}', '', raw_content)
            raw_content = re.sub(r'\\', '', raw_content)
            raw_content = re.sub(r'\*.*?\*', '', raw_content)
            return re.sub(r'\s+', ' ', raw_content).strip()

        # Index the file once: cleaned dialogue text -> first line number that
        # carries it. This replaces a full rescan of the file per dialogue line.
        first_line_index = {}
        for i, file_line in enumerate(lines):
            role_match = role_pattern.match(file_line.strip())
            if role_match:
                first_line_index.setdefault(clean_dialogue(role_match.group(1)), i)

        def generate_single_line(idx, line, params):
            if cancel_generation.is_set():
                return None
            if not line.strip():
                return None

            text_index = first_line_index.get(line)
            identifier = None
            if text_index is not None:
                # The translate header is searched in the four lines above the
                # dialogue comment.
                for j in range(1, 5):
                    if text_index - j >= 0:
                        line_to_check = lines[text_index - j].strip()
                        translate_match = translate_pattern.match(line_to_check)
                        if translate_match:
                            identifier = translate_match.group(1)
                            break
            if identifier is None:
                print(f"\033[93mSkip: No translate id found for {line} [Role: {display_name}, Index: {idx+1}]\033[0m")
                return None

            processed_line = line

            filename = f"{identifier}.wav"
            output_path = os.path.join("Outputs", filename)

            # Default sampling settings; replaced per line when sentiment is enabled.
            kwargs = {
                "do_sample": True,
                "top_p": 0.8,
                "top_k": 30,
                "temperature": 1.0,
                "length_penalty": 0.0,
                "num_beams": 3,
                "repetition_penalty": 10.0,
                "max_mel_tokens": 600
            }

            if use_sentiment:
                kwargs.update(params)

            output = tts.infer(ref_wav, processed_line, output_path, verbose=cmd_args.verbose,
                               max_text_tokens_per_sentence=120, **kwargs)

            print(f"\033[91m-------\nTemperature: {kwargs['temperature']:.2f}, Top_p: {kwargs['top_p']:.2f}, Top_k: {kwargs['top_k']}, "
                  f"RP: {kwargs['repetition_penalty']:.2f}, Length_penalty: {kwargs['length_penalty']:.2f}, "
                  f"Num_beams: {kwargs['num_beams']} {line}\n[Role: {display_name}: {idx+1}/{len(texts)}]{line}\n----------\033[0m")
            return output_path

        for idx, (line, params) in enumerate(zip(texts[start_index-1:], params_list), start=start_index-1):
            if cancel_generation.is_set():
                print(f"\033[93m⚠️ {display_name}: Generation cancelled, stopped at line {idx+1}\033[0m")
                return None, f"{display_name}: Generation cancelled", button
            result = generate_single_line(idx, line, params)
            if result:
                saved_paths.append(result)

        print(f"\033[94m✅ {display_name}: Generation finished, {len(saved_paths)} audio lines generated\033[0m")
        return None, f"{display_name}: From line {start_index}, {len(saved_paths)} audio lines generated", gr.Button(interactive=True)
    except Exception as e:
        print(f"\033[91m❌ {display_name}: Batch generation failed: {str(e)}\033[0m")
        return None, f"{display_name}: Batch generation failed: {str(e)}", gr.Button(interactive=True)
    finally:
        # Release the "busy" state and hand over to the next queued role.
        queued_roles.discard(role_index)
        is_generating.clear()
        process_queue()
|
|
def process_queue():
    """Start generation for the next usable role in ``generation_queue``.

    Entries are taken from the front of the queue while nothing is generating
    and no cancel is pending. Entries whose role slot is empty, that have
    neither seed id nor reference audio, or whose file cannot be re-parsed are
    skipped. The first entry that passes is generated synchronously through
    ``generate_all_lines_audio`` and the loop then stops; that function calls
    ``process_queue()`` again from its ``finally`` block, which moves on to the
    following entry.

    NOTE(review): status and button changes are made by assigning to the
    components' ``.value`` attributes; whether that reaches an already rendered
    page is not evident from this file - confirm.
    """
    global generation_queue, is_generating, queued_roles
    while generation_queue and not is_generating.is_set() and not cancel_generation.is_set():
        role_config = generation_queue.pop(0)
        role_index = role_config['role_index']
        role_data = role_texts[role_index]
        if not role_data:
            # Slot was emptied since the request was queued. Note that the index
            # is not removed from queued_roles on this path.
            continue
        role = role_data.get("role", f"Role{role_index+1}")
        display_name = role_data.get("display_name", role)
        txt = role_config['txt']
        start_idx = role_config['start_idx']
        spk_emb_seed_id = role_config['spk_emb_seed_id']
        ref_wav = role_config['ref_wav']

        # Nothing to condition the voice on: report and move to the next entry.
        if not spk_emb_seed_id and not ref_wav:
            print(f"\033[93m⚠️ Role {display_name} did not provide spk_emb_seed_id or reference audio, skipping generation\033[0m")
            role_components[role_index][4].value = f"{display_name}: No spk_emb_seed_id or reference audio provided, skipping generation"
            role_components[role_index][-1].value = gr.Button(interactive=True)
            queued_roles.discard(role_index)
            continue

        print(f"\033[94m🚀 Starting generation for role {display_name} from queue\033[0m")

        # Component layout per role: [text, start line, seed id, reference audio,
        # status, heading, button].
        role_components[role_index][4].value = f"{display_name}: Generation started"
        role_components[role_index][0].value = txt
        role_components[role_index][1].value = start_idx
        role_components[role_index][2].value = spk_emb_seed_id
        role_components[role_index][3].value = ref_wav
        is_generating.set()

        # Re-read the uploaded file (unfiltered) to obtain its raw lines.
        file_lines = []
        if last_uploaded_file:
            try:
                _, file_lines = parse_roles_from_rpy_file(last_uploaded_file, "")
            except Exception as e:
                role_components[role_index][4].value = f"File parsing failed: {str(e)}"
                role_components[role_index][-1].value = gr.Button(interactive=True)
                queued_roles.discard(role_index)
                is_generating.clear()
                continue
        # Reads the textbox component's .value attribute, not a value submitted
        # with an event.
        role_map = parse_role_mappings(role_mapping_input.value)
        result = generate_all_lines_audio(
            role, role_data.get("lines", []), file_lines, int(start_idx), role_map,
            spk_emb_seed_id, ref_wav, role_index=role_index,
            button=role_components[role_index][-1], status_output=role_components[role_index][4]
        )

        role_components[role_index][4].value = result[1]
        role_components[role_index][-1].value = result[2]
        # One generation per call; the remaining entries are handled by the
        # process_queue() call made from generate_all_lines_audio's finally block.
        break
|
|
# One dict per role panel (40 slots); filled by process_file() with the role
# code, display name, dialogue lines, reference audio path and seed id. An
# empty dict marks an unused slot.
role_texts = [{} for _ in range(40)]
# Per-role component lists, rebuilt by main() as
# [text, start line, seed id, reference audio, status, heading, button].
role_components = []
# The role-mapping textbox; assigned by main().
role_mapping_input = None
|
|
def process_file(file, mapping_text, filter_text):
    """Parse the uploaded .rpy file and build the values for all 40 role panels.

    The mapping text is saved to ``role_mapping.txt`` next to this script.
    Roles with fewer than 100 dialogue lines are dropped; the rest are sorted
    by display name (case-insensitive) and assigned to panels in that order.
    For each role a reference clip is looked up as ``voice/<voice_file>`` and
    then ``voice/<display name>.wav``; a clip that exists is copied into the
    audio cache directory.

    Returns:
        A flat list with eight values per role - text, start line, seed id,
        reference audio path, status, heading, visibility flag, button value -
        followed by one overall status message (321 items in total).
    """
    global last_uploaded_file
    last_uploaded_file = file

    def blank_outputs(message):
        # Placeholder values for all panels plus the overall status message.
        # The heading needs the slot number, so the eight-value group cannot
        # simply be repeated with list multiplication.
        values = []
        for slot in range(40):
            values.extend(["", 1, "", None, "", f"### Role {slot+1}", False, True])
        values.append(message)
        return values

    if file is None:
        return blank_outputs("Please upload a file")

    try:
        mapping_file_path = os.path.join(os.path.dirname(__file__), "role_mapping.txt")
        with open(mapping_file_path, "w", encoding="utf-8") as f:
            f.write(mapping_text.strip() if mapping_text else "")

        role_map = parse_role_mappings(mapping_text)
        role_dict, lines = parse_roles_from_rpy_file(file, filter_text)

        filtered_role_dict = {}
        for role, lines_ in role_dict.items():
            if len(lines_) >= 100:
                filtered_role_dict[role] = lines_
            else:
                print(f"Role {role} dialogue count ({len(lines_)}) is less than 100, filtered out")

        role_items = [(role, lines_) for role, lines_ in filtered_role_dict.items()]
        role_items.sort(key=lambda x: role_map.get(x[0], {'display_name': x[0]})['display_name'].lower())
    except Exception as e:
        return blank_outputs(f"File parsing failed: {str(e)}")

    return_values = []
    for i in range(40):
        if i < len(role_items):
            role, lines_ = role_items[i]
            role_info = role_map.get(role, {'display_name': role, 'voice_file': None, 'seed_id': None})
            display_name = role_info['display_name']
            seed_id = role_info['seed_id']
            voice_file = role_info['voice_file']

            ref_wav_original = None
            if voice_file:
                ref_wav_original = os.path.join('voice', voice_file)

            # Fall back to a clip named after the display name.
            if not ref_wav_original or not os.path.exists(ref_wav_original):
                ref_wav_original = os.path.join('voice', display_name + '.wav')
            if ref_wav_original and os.path.exists(ref_wav_original):
                cached_path = os.path.join(audio_cache_dir, f"role_{i}_{display_name}.wav")
                shutil.copy(ref_wav_original, cached_path)
                ref_wav_path = cached_path
                print(f"\033[92m✅ Loaded and cached reference audio for {display_name}: {ref_wav_original} -> {cached_path}\033[0m")
            else:
                ref_wav_path = None

            role_texts[i] = {"role": role, "display_name": display_name, "lines": lines_, "ref_wav": ref_wav_path, "seed_id": seed_id}
            joined = "\n".join(lines_)
            return_values.extend([joined, 1, seed_id, ref_wav_path, f"{display_name}: {len(lines_)} lines in total", f"### {display_name}", True, True])
        else:
            role_texts[i] = {}
            return_values.extend(["", 1, "", None, "", f"### Role {i+1}", False, True])

    return_values.append("File processed successfully, please click batch generate for each role")
    return return_values
|
|
def stop_all_generation():
    """Cancel running and queued generation and re-enable every role button.

    Sets the cancel flag, empties the queue and the queued-role set and clears
    the busy flag.

    Returns:
        A flat list of 81 values in the order the Stop button's outputs are
        declared: the overall status message, 40 per-role status strings
        (empty for unused slots) and 40 enabled-button updates. The click
        handler assigns one returned value per output component, so the values
        must not be nested inside sub-lists.
    """
    global generation_queue, is_generating, queued_roles
    cancel_generation.set()
    generation_queue.clear()
    queued_roles.clear()
    is_generating.clear()
    print(f"\033[93m🛑 Stopped all generation tasks\033[0m")
    button_states = [gr.Button(interactive=True) for _ in range(40)]
    status_outputs = [f"{role_texts[i].get('display_name', f'Role {i+1}')} : Generation cancelled" if role_texts[i] else "" for i in range(40)]

    # Iterate over whatever panels exist rather than indexing 40 slots blindly.
    for components in role_components:
        if components[-1]:
            components[-1].value = gr.Button(interactive=True)
    return ["Stopping all generation tasks...", *status_outputs, *button_states]
|
|
def get_pt_file_for_download(seed_id):
    """Write the embedding for *seed_id* to ``tmp/`` and offer it for download.

    Returns:
        ``(download_button, message)``. The button is visible and points at
        ``tmp/<seed_id>_restored_emb.pt`` when the embedding was found, and is
        hidden otherwise. The message comes from ``get_pt_file``.
    """
    spk_emb, message = get_pt_file(seed_id)
    if spk_emb is None:
        return gr.DownloadButton(value=None, label="Download .pt File", visible=False), message

    os.makedirs("tmp", exist_ok=True)
    output_path = os.path.join("tmp", f"{seed_id}_restored_emb.pt")
    torch.save(spk_emb, output_path)
    download_button = gr.DownloadButton(value=output_path, label=f"Download .pt File [{seed_id}]", visible=True)
    return download_button, message
|
|
def get_download_zip_state():
    """Return ``(download_button, message)`` reflecting whether ``tmp/outputs.zip`` exists."""
    zip_path = os.path.join("tmp", "outputs.zip")
    if not os.path.exists(zip_path):
        return gr.DownloadButton(value=None, label="Download Outputs.zip", visible=False), "outputs.zip not found"
    return gr.DownloadButton(value=zip_path, label="Download Outputs.zip", visible=True), "Ready to download outputs.zip"
|
|
def main():
    """Build the Gradio UI (40 role panels, four per row) and launch the server.

    The role-mapping textbox is pre-filled from ``role_mapping.txt`` next to
    this script, or with a built-in example when that file does not exist.
    ``--host`` (default ``0.0.0.0``) and ``--port`` (default ``7860``) are read
    from the command line for the launch.

    NOTE(review): the nesting of the layout containers was reconstructed from a
    source whose indentation had been lost; confirm which widgets belong inside
    each ``gr.Row``.
    """
    global role_components, make_bulk_functions, audio_cache_dir, role_mapping_input
    MAX_ROLES = 40
    mapping_file_path = os.path.join(os.path.dirname(__file__), "role_mapping.txt")
    try:
        with open(mapping_file_path, "r", encoding="utf-8") as f:
            role_mapping_placeholder = f.read().strip()
    except FileNotFoundError:
        role_mapping_placeholder = (
            "Role Code: jud → Display Name: Judge voice:jud seed_id:1403\n"
            "Role Code: jury → Display Name: Members of The Jury voice:jury seed_id:1404\n"
            "Role Code: noname → Display Name: Anonymous Role"
        )

    with gr.Blocks() as demo:
        gr.Markdown("https://huggingface.co/spaces/ResembleAI/Chatterbox")

        model_status = gr.Textbox(label="Model loading status", interactive=False, value="Loading model...")

        with gr.Row():
            default_rpy_path = os.path.join(os.path.dirname(__file__), "dialogue.rpy")
            file_input = gr.File(
                label="Upload .rpy file",
                type="binary",
                value=default_rpy_path if os.path.exists(default_rpy_path) else None
            )
            with gr.Column():
                role_mapping_input = gr.Textbox(
                    label="Role Mapping (Format: Role Code: xxx → Display Name: xxx voice:xxx seed_id:xxx)",
                    lines=5,
                    placeholder=role_mapping_placeholder,
                    value=role_mapping_placeholder
                )
                filter_input = gr.Textbox(
                    label="Dialogue Filter (Format: min_len=X;keyword=Y, leave empty to show all)",
                    lines=2,
                    placeholder="Example: min_len=20;keyword=hello\n(Minimum length 20 characters, contains 'hello' dialogues)",
                    value="min_len=3"
                )
        with gr.Row():
            process_button = gr.Button("Process File")
            stop_button = gr.Button("Stop All Generation")
            zip_button = gr.Button("Zip Outputs Folder")
            compress_button = gr.Button("Compress WAV to MP3")
            all_status = gr.Textbox(label="Overall Status", interactive=False, visible=False)
            download_zip_button = gr.DownloadButton(label="Download Outputs.zip", visible=False)
            zip_status = gr.Textbox(label="Zip Status", interactive=False)

        with gr.Row():
            seed_id_input = gr.Textbox(
                label="Enter seed_id to get .pt file",
                placeholder="e.g.: 1403 or seed_1403",
                visible=False
            )
            pt_download_button = gr.DownloadButton(label="Download .pt File", visible=False)
            pt_status = gr.Textbox(label="Generate .pt file status", interactive=False)

        def load_model_on_start():
            model, status = get_or_load_model()
            return status

        demo.load(
            fn=load_model_on_start,
            inputs=[],
            outputs=[model_status]
        )

        zip_button.click(
            fn=zip_outputs_folder,
            inputs=[],
            outputs=[zip_status, download_zip_button]
        )

        compress_button.click(
            fn=compress_wav_to_mp3,
            inputs=[],
            outputs=[zip_status]
        )

        seed_id_input.change(
            fn=get_pt_file_for_download,
            inputs=[seed_id_input],
            outputs=[pt_download_button, pt_status]
        )

        role_components = []
        # One State per panel; a change of its value toggles the panel's Group.
        visibility_states = [gr.State(value=False) for _ in range(MAX_ROLES)]

        ROLES_PER_ROW = 4
        for row_idx in range((MAX_ROLES + ROLES_PER_ROW - 1) // ROLES_PER_ROW):
            with gr.Row():
                for i in range(row_idx * ROLES_PER_ROW, min((row_idx + 1) * ROLES_PER_ROW, MAX_ROLES)):
                    with gr.Group(visible=False, elem_classes="compact-group") as group:
                        role_display = gr.Markdown(f"### Role {i+1}", elem_classes="compact-header")
                        text_input = gr.Textbox(
                            label="Text",
                            lines=2,
                            max_lines=4,
                            elem_classes="compact-textbox",
                            container=False
                        )
                        start_index = gr.Number(
                            value=1,
                            label="Start from line",
                            minimum=1,
                            step=1,
                            elem_classes="compact-number"
                        )
                        spk_emb_seed_id = gr.Textbox(
                            label="Enter spk_emb seed_id",
                            placeholder="e.g.: 1403 or seed_1403",
                            elem_classes="compact-textbox",
                            visible=False
                        )
                        ref_wav = gr.Audio(
                            type="filepath",
                            label="Reference audio file (optional, recommended >6s)",
                            sources=["upload", "microphone"],
                            elem_classes="compact-audio"
                        )
                        audio_output = gr.Audio(
                            label="Output audio",
                            elem_classes="compact-audio",
                            visible=False
                        )
                        status = gr.Textbox(
                            label="Status",
                            interactive=False,
                            elem_classes="compact-textbox",
                            container=False
                        )
                        bulk_btn = gr.Button(
                            f"Generate Role {i+1}",
                            size="sm",
                            elem_classes="compact-button",
                            interactive=True
                        )

                    role_components.append([text_input, start_index, spk_emb_seed_id, ref_wav, status, role_display, bulk_btn])

                    def make_bulk(i, bulk_btn):
                        # Factory so each handler captures its own panel index.
                        def inner(txt, start_idx, spk_emb_seed_id, ref_wav, button):
                            global generation_queue, is_generating, queued_roles
                            role_data = role_texts[i]
                            if not role_data:
                                return None, f"Role {i+1} has no data", gr.Button(interactive=True)
                            role = role_data.get("role", f"Role{i+1}")
                            display_name = role_data.get("display_name", role)

                            # Remember a clip supplied through the UI; clips
                            # already inside the cache directory are reused as-is.
                            ref_wav_path = ref_wav
                            if ref_wav and os.path.exists(ref_wav):
                                if not ref_wav.startswith(audio_cache_dir):
                                    cached_path = os.path.join(audio_cache_dir, f"role_{i}_{display_name}_ui.wav")
                                    shutil.copy(ref_wav, cached_path)
                                    role_texts[i]["ref_wav"] = cached_path
                                    print(f"\033[92m✅ Cached UI provided reference audio for {display_name}: {ref_wav} -> {cached_path}\033[0m")
                                else:
                                    role_texts[i]["ref_wav"] = ref_wav
                                    print(f"\033[92m✅ Using already cached reference audio for {display_name}: {ref_wav}\033[0m")
                            else:
                                ref_wav_path = role_texts[i].get("ref_wav")

                            if not spk_emb_seed_id and not ref_wav_path:
                                print(f"\033[93m⚠️ Role {display_name} did not provide spk_emb_seed_id or reference audio, skipping generation\033[0m")
                                return None, f"{display_name}: No spk_emb_seed_id or reference audio provided, skipping generation", gr.Button(interactive=True)

                            role_config = {
                                'role_index': i,
                                'txt': txt,
                                'start_idx': int(start_idx),
                                'spk_emb_seed_id': spk_emb_seed_id,
                                'ref_wav': ref_wav_path
                            }

                            if is_generating.is_set():
                                # Another role is running: queue this one, once.
                                if i not in queued_roles:
                                    generation_queue.append(role_config)
                                    queued_roles.add(i)
                                pos = len(generation_queue)
                                print(f"\033[94m⏳ Role {display_name} added to queue, position {pos}/{pos}\033[0m")
                                return None, f"{display_name}: Waiting in queue, position {pos}/{pos}", gr.Button(interactive=False)
                            else:
                                is_generating.set()
                                queued_roles.add(i)
                                cancel_generation.clear()
                                print(f"\033[94m🚀 Single role generation: {display_name}, from line {start_idx}\033[0m")
                                file_lines = []
                                if last_uploaded_file:
                                    try:
                                        _, file_lines = parse_roles_from_rpy_file(last_uploaded_file, "")
                                    except Exception as e:
                                        # Release the busy state taken above. Otherwise
                                        # every later click is queued and nothing ever
                                        # drains the queue.
                                        queued_roles.discard(i)
                                        is_generating.clear()
                                        return None, f"File parsing failed: {str(e)}", gr.Button(interactive=True)
                                # NOTE(review): this reads the textbox component's
                                # .value attribute rather than a value submitted with
                                # the click - confirm it reflects edits made in the page.
                                role_map = parse_role_mappings(role_mapping_input.value)
                                result = generate_all_lines_audio(
                                    role, role_data.get("lines", []), file_lines, int(start_idx), role_map,
                                    spk_emb_seed_id, ref_wav_path, role_index=i, button=button, status_output=status
                                )
                                return result
                        return inner

                    make_bulk_functions[i] = make_bulk(i, bulk_btn)

                    bulk_btn.click(
                        fn=make_bulk_functions[i],
                        inputs=[text_input, start_index, spk_emb_seed_id, ref_wav, bulk_btn],
                        outputs=[audio_output, status, bulk_btn]
                    )

                    visibility_states[i].change(
                        fn=lambda x, g=group: {"__type__": "update", "visible": x},
                        inputs=visibility_states[i],
                        outputs=group
                    )

        # Flat output list for "Process File": per role the six data components,
        # the visibility State and the button, then the overall status box.
        outputs = []
        for i in range(MAX_ROLES):
            for comp in role_components[i][:-1]:
                outputs.append(comp)
            outputs.append(visibility_states[i])
            outputs.append(role_components[i][-1])
        outputs.append(all_status)

        process_button.click(
            fn=process_file,
            inputs=[file_input, role_mapping_input, filter_input],
            outputs=outputs
        )

        # Target the real per-role buttons so that enabled-button updates have a
        # visible effect; State placeholders would just absorb them.
        stop_button.click(
            fn=stop_all_generation,
            inputs=[],
            outputs=[all_status, *[comp[4] for comp in role_components], *[comp[-1] for comp in role_components]]
        )

    demo.css = """
    body {
        background-color: #808080 !important;
    }
    .compact-group {
        width: 250px !important;
        min-width: 230px !important;
        padding: 5px !important;
        margin: 5px !important;
    }
    .compact-header {
        font-size: 14px !important;
        margin: 2px 0 !important;
    }
    .compact-textbox {
        font-size: 12px !important;
        line-height: 1.2 !important;
        padding: 2px !important;
        margin: 2px 0 !important;
    }
    .compact-number {
        width: 80px !important;
        font-size: 12px !important;
        padding: 2px !important;
        margin: 2px 0 !important;
    }
    .compact-button {
        font-size: 12px !important;
        padding: 4px !important;
        margin: 2px 0 !important;
    }
    .compact-audio {
        max-width: 220px !important;
        min-width: 200px !important;
        font-size: 12px !important;
        margin: 2px 0 !important;
        overflow: visible !important;
    }
    """

    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int, default=7860)
    # The same command line was already parsed at import time by a parser that
    # also accepts --verbose and --model_dir. Ignore options unknown to this
    # parser instead of exiting with "unrecognized arguments".
    args, _unknown_args = parser.parse_known_args()

    os.environ["GRADIO_SERVER_NAME"] = args.host
    os.environ["GRADIO_SERVER_PORT"] = str(args.port)

    demo.launch(server_name=args.host, server_port=args.port, share=False)
|
|
# Script entry point: build the UI and start the server only when this file is
# executed directly.
if __name__ == '__main__':
    main()
| |