# Chatterbox / app.py
# Author: oicui — last commit 5329297 ("Update app.py", verified)
import random
import numpy as np
import torch
import gradio as gr
import spaces
import re
from chatterbox.src.chatterbox.tts import ChatterboxTTS
# Pick the GPU when torch can see one; the rest of the module reads DEVICE.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print(f"🚀 Running on device: {DEVICE}")
# ---------------------------------------
# GLOBAL MODEL LOAD
# ---------------------------------------
# Process-wide model cache; stays None until get_or_load_model() fills it.
MODEL = None
def get_or_load_model():
    """Return the process-wide ChatterboxTTS instance, loading it on first use.

    The model is created with ``ChatterboxTTS.from_pretrained(DEVICE)``. If it
    exposes ``.to`` and reports a device other than ``DEVICE``, it is moved there.

    The global ``MODEL`` is assigned only after loading and the device move have
    both succeeded. A failed attempt therefore leaves ``MODEL`` as ``None``, and
    the next call retries instead of returning a half-initialised model.

    Returns:
        The cached model instance.

    Raises:
        Exception: whatever ``from_pretrained`` or ``.to`` raised. It is printed
            first and then re-raised unchanged.
    """
    global MODEL
    if MODEL is not None:
        return MODEL

    print("Model not loaded, initializing...")
    try:
        model = ChatterboxTTS.from_pretrained(DEVICE)
        # Move only when the model reports a different device. getattr() avoids
        # an AttributeError for implementations with `.to` but no `.device`.
        if hasattr(model, "to") and str(getattr(model, "device", DEVICE)) != DEVICE:
            model.to(DEVICE)
        print("Model loaded successfully.")
    except Exception as e:
        print(f"Error loading model: {e}")
        raise

    # Publish to the global only once the model is fully ready.
    MODEL = model
    return MODEL
# Load the model eagerly at import time. A failure here is reported but not
# fatal: get_or_load_model() tries again on its next call while MODEL is None.
try:
    get_or_load_model()
except Exception as e:  # deliberate top-level best-effort
    print(f"CRITICAL startup load failed: {e}")
# ---------------------------------------
# UTILITIES
# ---------------------------------------
def set_seed(seed: int):
    """Seed every RNG the generation path may draw from.

    Seeding order: torch (plus, when DEVICE is "cuda", the current GPU and all
    GPUs), then Python's ``random``, then NumPy.
    """
    seeders = [torch.manual_seed]
    if DEVICE == "cuda":
        seeders += [torch.cuda.manual_seed, torch.cuda.manual_seed_all]
    seeders += [random.seed, np.random.seed]
    for seeder in seeders:
        seeder(seed)
# --- SMART CHUNKING ---
# Sentence boundary: whitespace that follows ., !, ?, … or ;.
_SENTENCE_SPLIT_RE = re.compile(r"(?<=[\.\!\?…;])\s+")


def _split_long_piece(piece: str, limit: int) -> list:
    """Split *piece* into parts of at most *limit* chars, preferring word boundaries."""
    parts = []
    current = ""
    for word in piece.split():
        # A single word longer than the limit has no boundary to use: hard-cut it.
        while len(word) > limit:
            if current:
                parts.append(current)
                current = ""
            parts.append(word[:limit])
            word = word[limit:]
        if not current:
            current = word
        elif len(current) + 1 + len(word) <= limit:
            current = f"{current} {word}"
        else:
            parts.append(current)
            current = word
    if current:
        parts.append(current)
    return parts


def smart_chunk_text(text: str, chunk_size: int):
    """Group the sentences of *text* into chunks of at most *chunk_size* characters.

    Whole sentences are packed greedily, joined by a single space. A sentence
    longer than *chunk_size* is split at word boundaries, and hard-cut only when
    a single word exceeds the limit.

    Args:
        text: Text to split.
        chunk_size: Maximum chunk length in characters; must be positive.

    Returns:
        A list of non-empty, stripped chunks. It is empty when *text* is blank.

    Raises:
        ValueError: if *chunk_size* is not positive.
    """
    limit = int(chunk_size)
    if limit <= 0:
        raise ValueError("chunk_size must be a positive integer")

    chunks = []
    current = ""
    for sentence in _SENTENCE_SPLIT_RE.split(text):
        sentence = sentence.strip()
        if not sentence:  # e.g. the empty tail after trailing whitespace
            continue
        pieces = [sentence] if len(sentence) <= limit else _split_long_piece(sentence, limit)
        for piece in pieces:
            if not current:
                current = piece
            elif len(current) + 1 + len(piece) <= limit:
                current = f"{current} {piece}"
            else:
                chunks.append(current)
                current = piece
    if current:
        chunks.append(current)
    return chunks
def concat_audio(chunks):
    """Join per-chunk waveforms end to end along the last axis.

    Returns None when *chunks* is empty (or otherwise falsy), so callers can
    tell that nothing was rendered.
    """
    return np.concatenate(chunks, axis=-1) if chunks else None
# ---------------------------------------
# MAIN TTS FUNCTION
# ---------------------------------------
@spaces.GPU
def generate_tts_audio(
    text_input: str,
    audio_prompt_path_input: str = None,
    exaggeration_input: float = 0.5,
    temperature_input: float = 0.8,
    seed_num_input: int = 0,
    cfgw_input: float = 0.5,
    vad_trim_input: bool = False,
    enable_chunking: bool = False,
    chunk_size_value: int = 250,
):
    """Synthesize speech for *text_input* and report the seed that was used.

    Args:
        text_input: Text to synthesize; must contain a non-whitespace character.
        audio_prompt_path_input: Optional reference clip path. It is forwarded
            as ``audio_prompt_path`` only when truthy.
        exaggeration_input, temperature_input, cfgw_input, vad_trim_input:
            Forwarded to ``generate`` as ``exaggeration``, ``temperature``,
            ``cfg_weight`` and ``vad_trim``.
        seed_num_input: ``0`` (or ``None`` / empty) picks a fresh random seed.
            Any other value is converted with ``int()`` and used as-is.
        enable_chunking: When true, split the text with ``smart_chunk_text`` and
            render each chunk separately before concatenating the audio.
        chunk_size_value: Target chunk length in characters when chunking.

    Returns:
        ``((sample_rate, waveform), used_seed)``. The waveform comes from
        ``concat_audio`` (a NumPy array).

    Raises:
        gr.Error: if the text is blank, the seed is out of range, or no audio
            was produced.
        RuntimeError: if the model is not loaded.
    """
    # Reject blank input before doing any GPU work.
    if text_input is None or not text_input.strip():
        raise gr.Error("Please enter some text to synthesize.")

    current_model = get_or_load_model()
    if current_model is None:
        raise RuntimeError("TTS model is not loaded.")

    # -------------------------
    # SEED HANDLING
    # -------------------------
    # gr.Number may deliver 0, 0.0 or None for "random"; all of them are falsy.
    if not seed_num_input:
        used_seed = random.randint(1, 2**31 - 1)
    else:
        used_seed = int(seed_num_input)
    # set_seed() calls np.random.seed, which only accepts 0 <= seed < 2**32.
    # Fail with a readable message instead of a NumPy stack trace.
    if not 0 <= used_seed < 2**32:
        raise gr.Error(f"Seed must be 0 (random) or an integer between 1 and {2**32 - 1}.")
    print(f"Using seed: {used_seed}")
    set_seed(used_seed)

    print(f"Generating audio for text (preview): '{text_input[:50]}...'")
    generate_kwargs = {
        "exaggeration": exaggeration_input,
        "temperature": temperature_input,
        "cfg_weight": cfgw_input,
        "vad_trim": vad_trim_input,
    }
    if audio_prompt_path_input:
        generate_kwargs["audio_prompt_path"] = audio_prompt_path_input

    # -------------------------
    # SMART CHUNK PROCESSING
    # -------------------------
    if enable_chunking:
        print(f"Smart chunking enabled — chunk size = {chunk_size_value}")
        text_chunks = smart_chunk_text(text_input, int(chunk_size_value))
    else:
        text_chunks = [text_input]

    audio_segments = []
    for i, chunk in enumerate(text_chunks):
        print(f"Rendering chunk {i+1}/{len(text_chunks)}...")
        wav = current_model.generate(chunk, **generate_kwargs)
        # .numpy() fails on CUDA or grad-tracking tensors. detach()/cpu() make
        # the conversion safe and are cheap for a CPU tensor.
        audio_segments.append(wav.squeeze(0).detach().cpu().numpy())

    final_audio = concat_audio(audio_segments)
    if final_audio is None:
        raise gr.Error("No audio was generated for the given text.")
    print("Audio generation complete.")
    # FIXED OUTPUT FORMAT (Gradio-compatible): (sample_rate, waveform)
    return (current_model.sr, final_audio), used_seed
# ---------------------------------------
# UI
# ---------------------------------------
with gr.Blocks() as demo:
    gr.Markdown(
        """
# Chatterbox TTS Demo — Enhanced Version
Supports unlimited text, smart chunking & random seed viewer.
"""
    )
    with gr.Row():
        # Left column: every generation input plus the Generate button.
        with gr.Column():
            text = gr.Textbox(
                value="Now let's make my mum's favourite...",
                label="Text to synthesize",
                max_lines=10,
            )
            ref_wav = gr.Audio(
                sources=["upload", "microphone"],
                type="filepath",
                label="Reference Audio File (Optional)",
                value="https://storage.googleapis.com/chatterbox-demo-samples/prompts/female_shadowheart4.flac",
            )
            exaggeration = gr.Slider(0.25, 2, step=.05, label="Exaggeration", value=.5)
            cfg_weight = gr.Slider(0.2, 1, step=.05, label="CFG/Pace", value=0.5)
            with gr.Accordion("More options", open=False):
                seed_num = gr.Number(value=0, label="Random seed (0 = random)")
                # Read-only; filled from generate_tts_audio's second return value.
                seed_display = gr.Textbox(
                    value="",
                    label="Seed Used (auto-filled)",
                    interactive=False,
                )
                temp = gr.Slider(0.05, 5, step=.05, label="Temperature", value=.8)
                vad_trim = gr.Checkbox(label="Ref VAD trimming", value=False)
                enable_chunking = gr.Checkbox(
                    label="Enable Smart Text Chunking",
                    value=False,
                )
                chunk_size = gr.Slider(
                    minimum=100,
                    maximum=2000,
                    value=250,
                    step=10,
                    label="Chunk Size (characters)",
                )
            run_btn = gr.Button("Generate", variant="primary")
        # Right column: the synthesized audio.
        with gr.Column():
            audio_output = gr.Audio(label="Output Audio")

    # CONNECT BUTTON
    # The order of `inputs` must match generate_tts_audio's positional parameters.
    run_btn.click(
        fn=generate_tts_audio,
        inputs=[
            text,
            ref_wav,
            exaggeration,
            temp,
            seed_num,
            cfg_weight,
            vad_trim,
            enable_chunking,
            chunk_size,
        ],
        outputs=[
            audio_output,
            seed_display,
        ],
    )

# Launch only when executed as a script, so importing this module (tooling,
# tests) builds `demo` without starting a server as a side effect.
# NOTE(review): assumes the host starts the app via `python app.py` — confirm.
if __name__ == "__main__":
    demo.launch(mcp_server=True, share=True)