# Chatterbox / app.py
# Last updated by n8n485 — commit 6190e8e ("Update app.py", verified)
import functools
import random
import numpy as np
import torch
from chatterbox.src.chatterbox.tts import ChatterboxTTS
import gradio as gr
import spaces
# ─── Global patch to fix CUDA deserialization error on CPU ───
# When CUDA is unavailable, every torch.load call that does not choose its own
# map_location is defaulted to map_location=cpu, so checkpoints saved from a GPU
# process can still be deserialized on a CPU-only host.
original_torch_load = torch.load


@functools.wraps(original_torch_load)
def patched_torch_load(*args, **kwargs):
    """Call the original ``torch.load``, defaulting ``map_location`` to CPU when CUDA is unavailable.

    A ``map_location`` supplied by the caller (as a keyword or as the second
    positional argument) is respected; an explicit keyword ``None`` is treated
    the same as "not given", matching torch.load's own default.
    """
    # map_location is torch.load's second positional parameter. Injecting it as a
    # keyword when it was already passed positionally would make torch.load raise
    # "got multiple values for argument 'map_location'".
    passed_positionally = len(args) >= 2
    if (
        not passed_positionally
        and kwargs.get("map_location") is None
        and not torch.cuda.is_available()
    ):
        kwargs["map_location"] = torch.device("cpu")
    return original_torch_load(*args, **kwargs)


torch.load = patched_torch_load
# ─────────────────────────────────────────────────────────────
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"πŸš€ Running on device: {DEVICE}")
# --- Global Model Initialization ---
MODEL = None
def get_or_load_model():
    """Return the process-wide ChatterboxTTS model, loading it on first use.

    The model is created with ``ChatterboxTTS.from_pretrained(DEVICE)``. If it
    exposes a ``to`` method and its reported ``device`` differs from ``DEVICE``
    (or it reports none), it is moved to ``DEVICE``. The global ``MODEL`` is only
    set after loading and any move have succeeded. A failed attempt therefore
    leaves ``MODEL`` as None and is retried on the next call, instead of caching
    a half-initialised model.

    Returns:
        The cached model instance.

    Raises:
        Exception: anything raised by ``from_pretrained`` or ``.to``; it is
            printed and then re-raised unchanged.
    """
    global MODEL
    if MODEL is not None:
        return MODEL

    print("Model not loaded, initializing...")
    try:
        model = ChatterboxTTS.from_pretrained(DEVICE)
        # On CPU, .to(DEVICE) is usually redundant after loading with map_location,
        # but we keep it for safety / future GPU support. getattr avoids an
        # AttributeError on objects that have .to but no .device attribute.
        if hasattr(model, "to") and str(getattr(model, "device", None)) != DEVICE:
            model.to(DEVICE)
    except Exception as e:
        print(f"Error loading model: {e}")
        raise

    MODEL = model
    print(f"Model loaded successfully. Internal device: {getattr(MODEL, 'device', 'N/A')}")
    return MODEL
def _warm_model_cache():
    """Load the model eagerly at import time so load errors appear early in the logs.

    Any exception is reported and swallowed so the app still starts;
    request handlers call get_or_load_model() themselves.
    """
    try:
        get_or_load_model()
    except Exception as startup_error:
        print(f"CRITICAL: Failed to load model on startup. Application may not function. Error: {startup_error}")


_warm_model_cache()
def set_seed(seed: int):
    """Seed every RNG used here with the same value for reproducible runs.

    Seeds are applied in this order: torch, then the CUDA generators (only when
    DEVICE is "cuda"), then Python's ``random`` module, then NumPy's global RNG.

    Args:
        seed: Integer seed passed unchanged to every generator.
    """
    seeders = [torch.manual_seed]
    if DEVICE == "cuda":
        seeders += [torch.cuda.manual_seed, torch.cuda.manual_seed_all]
    seeders += [random.seed, np.random.seed]
    for apply_seed in seeders:
        apply_seed(seed)
@spaces.GPU  # harmless on CPU, ignored by HF when no GPU is allocated
def generate_tts_audio(
    text_input: str,
    audio_prompt_path_input: str = None,
    exaggeration_input: float = 0.5,
    temperature_input: float = 0.8,
    seed_num_input: int = 0,
    cfgw_input: float = 0.5,
    vad_trim_input: bool = False,
) -> tuple[int, np.ndarray]:
    """
    Generate high-quality speech audio from text using ChatterboxTTS model with optional reference audio styling.

    Args:
        text_input: Text to synthesize; only the first 300 characters are used.
        audio_prompt_path_input: Optional path to a reference audio file; passed to
            the model as ``audio_prompt_path`` only when non-empty.
        exaggeration_input: Passed to the model as ``exaggeration``.
        temperature_input: Sampling temperature, passed as ``temperature``.
        seed_num_input: Random seed. 0 or empty skips reseeding; other values are
            reduced modulo 2**32 so NumPy accepts them.
        cfgw_input: Passed to the model as ``cfg_weight``.
        vad_trim_input: Passed to the model as ``vad_trim``.

    Returns:
        A ``(sample_rate, waveform)`` tuple, where the waveform is a NumPy array.

    Raises:
        RuntimeError: if the TTS model is not loaded.
    """
    current_model = get_or_load_model()
    if current_model is None:
        raise RuntimeError("TTS model is not loaded.")

    # gr.Number may deliver a float or None (empty field). np.random.seed only
    # accepts 0 <= seed < 2**32, so fold the seed into that range rather than
    # crashing on negative or very large values.
    if seed_num_input:
        set_seed(int(seed_num_input) % 2**32)

    print(f"Generating audio for text: '{text_input[:50]}...'")

    # Handle optional audio prompt
    generate_kwargs = {
        "exaggeration": exaggeration_input,
        "temperature": temperature_input,
        "cfg_weight": cfgw_input,
        "vad_trim": vad_trim_input,
    }
    if audio_prompt_path_input:
        generate_kwargs["audio_prompt_path"] = audio_prompt_path_input

    wav = current_model.generate(
        text_input[:300],  # Truncate text to max chars
        **generate_kwargs
    )
    print("Audio generation complete.")
    # detach()/cpu() are no-ops for a CPU tensor without grad, and make .numpy()
    # safe if the model ever returns a GPU or grad-tracking tensor.
    return (current_model.sr, wav.squeeze(0).detach().cpu().numpy())
with gr.Blocks() as demo:
    gr.Markdown(
        """
        # Chatterbox TTS Demo
        Generate high-quality speech from text with reference audio styling.
        """
    )
    with gr.Row():
        with gr.Column():
            text = gr.Textbox(
                value="Now let's make my mum's favourite. So three mars bars into the pan. Then we add the tuna and just stir for a bit, just let the chocolate and fish infuse. A sprinkle of olive oil and some tomato ketchup. Now smell that. Oh boy this is going to be incredible.",
                label="Text to synthesize (max chars 300)",
                max_lines=5
            )
            ref_wav = gr.Audio(
                sources=["upload", "microphone"],
                type="filepath",
                label="Reference Audio File (Optional)",
                value="https://storage.googleapis.com/chatterbox-demo-samples/prompts/female_shadowheart4.flac"
            )
            exaggeration = gr.Slider(
                0.25, 2, step=.05, label="Exaggeration (Neutral = 0.5, extreme values can be unstable)", value=.5
            )
            cfg_weight = gr.Slider(
                0.2, 1, step=.05, label="CFG/Pace", value=0.5
            )
            with gr.Accordion("More options", open=False):
                # precision=0 delivers an int. The bounds match what np.random.seed
                # accepts, so an out-of-range seed is rejected in the UI instead of
                # crashing inside generation.
                seed_num = gr.Number(
                    value=0,
                    label="Random seed (0 for random)",
                    precision=0,
                    minimum=0,
                    maximum=2**32 - 1,
                )
                temp = gr.Slider(0.05, 5, step=.05, label="Temperature", value=.8)
                vad_trim = gr.Checkbox(label="Ref VAD trimming", value=False)
            run_btn = gr.Button("Generate", variant="primary")
        with gr.Column():
            audio_output = gr.Audio(label="Output Audio")

    # Input order must match generate_tts_audio's positional parameters.
    run_btn.click(
        fn=generate_tts_audio,
        inputs=[
            text,
            ref_wav,
            exaggeration,
            temp,
            seed_num,
            cfg_weight,
            vad_trim,
        ],
        outputs=[audio_output],
    )

# Launch only when run as a script (the Space runs `python app.py`), so that
# importing this module does not start a server.
if __name__ == "__main__":
    demo.launch(mcp_server=True)