import random
import re

import numpy as np
import torch
import gradio as gr
import spaces

from chatterbox.src.chatterbox.tts import ChatterboxTTS

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"🚀 Running on device: {DEVICE}")

# ---------------------------------------
# GLOBAL MODEL LOAD
# ---------------------------------------
MODEL = None


def get_or_load_model():
    """Lazily initialize the global ChatterboxTTS model.

    Returns the cached MODEL if already loaded; otherwise loads it from the
    pretrained checkpoint, moves it to DEVICE if needed, and caches it.
    Re-raises any load error after logging it.
    """
    global MODEL
    if MODEL is None:
        print("Model not loaded, initializing...")
        try:
            MODEL = ChatterboxTTS.from_pretrained(DEVICE)
            # Some builds report a device attribute; only move when it differs.
            if hasattr(MODEL, "to") and str(MODEL.device) != DEVICE:
                MODEL.to(DEVICE)
            print("Model loaded successfully.")
        except Exception as e:
            print(f"Error loading model: {e}")
            raise
    return MODEL


# Warm the model at startup so the first request doesn't pay the load cost.
# A failure here is logged but not fatal: generate_tts_audio retries the load.
try:
    get_or_load_model()
except Exception as e:
    print(f"CRITICAL startup load failed: {e}")

# ---------------------------------------
# UTILITIES
# ---------------------------------------


def set_seed(seed: int):
    """Seed torch (CPU and, when available, all CUDA devices), random, and numpy
    so generation is reproducible for a given seed."""
    torch.manual_seed(seed)
    if DEVICE == "cuda":
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    random.seed(seed)
    np.random.seed(seed)


# --- SMART CHUNKING ---
def smart_chunk_text(text: str, chunk_size: int):
    """Split *text* into chunks of at most ~chunk_size characters, breaking
    only at sentence boundaries (., !, ?, …, ;).

    A single sentence longer than chunk_size becomes its own oversized chunk.
    Empty chunks (e.g. from whitespace-only input) are dropped.
    """
    sentences = re.split(r"(?<=[\.\!\?…;])\s+", text)
    chunks = []
    current = ""
    for sentence in sentences:
        if len(current) + len(sentence) > chunk_size:
            if current:
                chunks.append(current.strip())
            current = sentence + " "
        else:
            current += sentence + " "
    if current:
        chunks.append(current.strip())
    # FIX: never emit empty chunks — they would be sent to model.generate("").
    return [c for c in chunks if c]


def concat_audio(chunks):
    """Concatenate 1-D audio arrays along the sample axis; None if empty."""
    if not chunks:
        return None
    return np.concatenate(chunks, axis=-1)


# ---------------------------------------
# MAIN TTS FUNCTION
# ---------------------------------------
@spaces.GPU
def generate_tts_audio(
    text_input: str,
    audio_prompt_path_input: str = None,
    exaggeration_input: float = 0.5,
    temperature_input: float = 0.8,
    seed_num_input: int = 0,
    cfgw_input: float = 0.5,
    vad_trim_input: bool = False,
    enable_chunking: bool = False,
    chunk_size_value: int = 250,
):
    """Synthesize speech for *text_input* with the Chatterbox TTS model.

    Returns a tuple ((sample_rate, waveform_ndarray), used_seed) — the first
    element is the Gradio-compatible audio value, the second the seed that was
    actually used (so the UI can display it when 0 = random was requested).

    Raises RuntimeError if the model could not be loaded.
    """
    current_model = get_or_load_model()
    if current_model is None:
        raise RuntimeError("TTS model is not loaded.")

    # -------------------------
    # SEED HANDLING
    # -------------------------
    if seed_num_input == 0:
        used_seed = random.randint(1, 2**31 - 1)
    else:
        used_seed = int(seed_num_input)
    print(f"Using seed: {used_seed}")
    set_seed(used_seed)

    print(f"Generating audio for text (preview): '{text_input[:50]}...'")

    generate_kwargs = {
        "exaggeration": exaggeration_input,
        "temperature": temperature_input,
        "cfg_weight": cfgw_input,
        "vad_trim": vad_trim_input,
    }
    if audio_prompt_path_input:
        generate_kwargs["audio_prompt_path"] = audio_prompt_path_input

    # -------------------------
    # SMART CHUNK PROCESSING
    # -------------------------
    if enable_chunking:
        print(f"Smart chunking enabled — chunk size = {chunk_size_value}")
        text_chunks = smart_chunk_text(text_input, int(chunk_size_value))
    else:
        text_chunks = [text_input]

    audio_segments = []
    for i, chunk in enumerate(text_chunks):
        print(f"Rendering chunk {i+1}/{len(text_chunks)}...")
        wav = current_model.generate(chunk, **generate_kwargs)
        # FIX: move the tensor to CPU before .numpy() — calling .numpy() on a
        # CUDA tensor raises; detach() also guards against grad-tracked output.
        audio_segments.append(wav.squeeze(0).detach().cpu().numpy())

    final_audio = concat_audio(audio_segments)
    print("Audio generation complete.")

    # FIXED OUTPUT FORMAT (Gradio-compatible)
    return (current_model.sr, final_audio), used_seed


# ---------------------------------------
# UI
# ---------------------------------------
with gr.Blocks() as demo:
    gr.Markdown(
        """
        # Chatterbox TTS Demo — Enhanced Version
        Supports unlimited text, smart chunking & random seed viewer.
        """
    )
    with gr.Row():
        with gr.Column():
            text = gr.Textbox(
                value="Now let's make my mum's favourite...",
                label="Text to synthesize",
                max_lines=10
            )
            ref_wav = gr.Audio(
                sources=["upload", "microphone"],
                type="filepath",
                label="Reference Audio File (Optional)",
                value="https://storage.googleapis.com/chatterbox-demo-samples/prompts/female_shadowheart4.flac"
            )
            exaggeration = gr.Slider(0.25, 2, step=.05, label="Exaggeration", value=.5)
            cfg_weight = gr.Slider(0.2, 1, step=.05, label="CFG/Pace", value=0.5)

            with gr.Accordion("More options", open=False):
                seed_num = gr.Number(value=0, label="Random seed (0 = random)")
                seed_display = gr.Textbox(
                    value="",
                    label="Seed Used (auto-filled)",
                    interactive=False
                )
                temp = gr.Slider(0.05, 5, step=.05, label="Temperature", value=.8)
                vad_trim = gr.Checkbox(label="Ref VAD trimming", value=False)
                enable_chunking = gr.Checkbox(
                    label="Enable Smart Text Chunking",
                    value=False
                )
                chunk_size = gr.Slider(
                    minimum=100,
                    maximum=2000,
                    value=250,
                    step=10,
                    label="Chunk Size (characters)"
                )

            run_btn = gr.Button("Generate", variant="primary")

        with gr.Column():
            audio_output = gr.Audio(label="Output Audio")

    # CONNECT BUTTON
    run_btn.click(
        fn=generate_tts_audio,
        inputs=[
            text,
            ref_wav,
            exaggeration,
            temp,
            seed_num,
            cfg_weight,
            vad_trim,
            enable_chunking,
            chunk_size,
        ],
        outputs=[
            audio_output,
            seed_display,
        ],
    )


if __name__ == "__main__":
    # Guarded so importing this module for its functions doesn't start a server.
    demo.launch(mcp_server=True, share=True)