# Gradio demo app for ChatterboxTTS (Hugging Face Space).
import random
import numpy as np
import torch
from chatterbox.src.chatterbox.tts import ChatterboxTTS
import gradio as gr
import spaces
# --- Compatibility shim: make torch.load safe on CPU-only hosts ---
# Checkpoints serialized on a CUDA machine fail to deserialize on machines
# without a GPU unless map_location='cpu' is supplied. Wrap torch.load so
# that argument is injected automatically whenever CUDA is unavailable.
_unpatched_torch_load = torch.load

def _torch_load_cpu_fallback(*args, **kwargs):
    """Delegate to the real torch.load, defaulting map_location to CPU when CUDA is absent."""
    if not torch.cuda.is_available() and 'map_location' not in kwargs:
        kwargs['map_location'] = torch.device('cpu')
    return _unpatched_torch_load(*args, **kwargs)

torch.load = _torch_load_cpu_fallback
# ------------------------------------------------------------------
# Select the compute device: prefer CUDA when available, otherwise CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# FIX: the original log message contained a mojibake-garbled emoji ("π");
# replaced with a plain ASCII message so logs render correctly everywhere.
print(f"Running on device: {DEVICE}")

# --- Global Model Initialization ---
# Lazily-populated singleton; populated by get_or_load_model().
MODEL = None
def get_or_load_model():
    """Return the shared ChatterboxTTS instance, loading it on first use.

    The instance is cached in the module-level MODEL global; subsequent
    calls return the cached object. Raises whatever the loader raises if
    initialization fails.
    """
    global MODEL
    # Fast path: already initialized.
    if MODEL is not None:
        return MODEL

    print("Model not loaded, initializing...")
    try:
        MODEL = ChatterboxTTS.from_pretrained(DEVICE)
        # After loading with map_location, .to(DEVICE) is usually redundant
        # on CPU; kept as a safeguard for future GPU deployments.
        if hasattr(MODEL, 'to') and str(MODEL.device) != DEVICE:
            MODEL.to(DEVICE)
        print(f"Model loaded successfully. Internal device: {getattr(MODEL, 'device', 'N/A')}")
    except Exception as e:
        print(f"Error loading model: {e}")
        raise
    return MODEL
# Warm up the model eagerly so that load failures surface in startup logs
# instead of on the first user request.
try:
    get_or_load_model()
except Exception as e:
    print(f"CRITICAL: Failed to load model on startup. Application may not function. Error: {e}")
def set_seed(seed: int) -> None:
    """Seed torch (CPU and, when present, all CUDA devices), numpy, and random.

    Makes generation reproducible for a given seed.

    Args:
        seed: The seed value applied to every RNG.
    """
    torch.manual_seed(seed)
    # FIX: check CUDA availability directly instead of reading the
    # module-level DEVICE global. Equivalent behavior (DEVICE is derived
    # from exactly this check) but the helper is now self-contained.
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    random.seed(seed)
    np.random.seed(seed)
@spaces.GPU  # harmless on CPU, ignored by HF when no GPU is allocated
def generate_tts_audio(
    text_input: str,
    audio_prompt_path_input: "str | None" = None,
    exaggeration_input: float = 0.5,
    temperature_input: float = 0.8,
    seed_num_input: int = 0,
    cfgw_input: float = 0.5,
    vad_trim_input: bool = False,
) -> tuple[int, np.ndarray]:
    """Generate speech audio from text with the ChatterboxTTS model.

    Args:
        text_input: Text to synthesize (truncated to 300 characters).
        audio_prompt_path_input: Optional path to a reference audio file
            used to style the generated voice.
        exaggeration_input: Emotion exaggeration control (0.5 = neutral).
        temperature_input: Sampling temperature.
        seed_num_input: Random seed; 0 means "do not reseed" (random output).
        cfgw_input: Classifier-free-guidance / pace weight.
        vad_trim_input: Whether to VAD-trim the reference audio.

    Returns:
        A (sample_rate, waveform) tuple with the waveform as a numpy array.

    Raises:
        RuntimeError: If the TTS model is not loaded.
    """
    current_model = get_or_load_model()
    if current_model is None:
        raise RuntimeError("TTS model is not loaded.")

    if seed_num_input != 0:
        set_seed(int(seed_num_input))

    print(f"Generating audio for text: '{text_input[:50]}...'")

    generate_kwargs = {
        "exaggeration": exaggeration_input,
        "temperature": temperature_input,
        "cfg_weight": cfgw_input,
        "vad_trim": vad_trim_input,
    }
    # Only forward the prompt when one was actually provided.
    if audio_prompt_path_input:
        generate_kwargs["audio_prompt_path"] = audio_prompt_path_input

    wav = current_model.generate(
        text_input[:300],  # enforce the UI's 300-char limit
        **generate_kwargs
    )
    print("Audio generation complete.")
    # FIX: .detach().cpu() makes the numpy conversion safe when the model
    # runs on CUDA or returns a grad-attached tensor (.numpy() raises in
    # both cases); it is a no-op for a plain CPU tensor, so CPU behavior
    # is unchanged.
    return (current_model.sr, wav.squeeze(0).detach().cpu().numpy())
# --- Gradio UI definition and app entry point ---
with gr.Blocks() as demo:
    gr.Markdown(
        """
        # Chatterbox TTS Demo
        Generate high-quality speech from text with reference audio styling.
        """
    )
    with gr.Row():
        with gr.Column():
            # Input text; the generation function truncates to 300 chars.
            text_box = gr.Textbox(
                value="Now let's make my mum's favourite. So three mars bars into the pan. Then we add the tuna and just stir for a bit, just let the chocolate and fish infuse. A sprinkle of olive oil and some tomato ketchup. Now smell that. Oh boy this is going to be incredible.",
                label="Text to synthesize (max chars 300)",
                max_lines=5,
            )
            # Optional reference clip used to style the generated voice.
            reference_audio = gr.Audio(
                sources=["upload", "microphone"],
                type="filepath",
                label="Reference Audio File (Optional)",
                value="https://storage.googleapis.com/chatterbox-demo-samples/prompts/female_shadowheart4.flac",
            )
            exaggeration_slider = gr.Slider(
                0.25,
                2,
                step=0.05,
                label="Exaggeration (Neutral = 0.5, extreme values can be unstable)",
                value=0.5,
            )
            cfg_slider = gr.Slider(0.2, 1, step=0.05, label="CFG/Pace", value=0.5)

            with gr.Accordion("More options", open=False):
                seed_input = gr.Number(value=0, label="Random seed (0 for random)")
                temperature_slider = gr.Slider(
                    0.05, 5, step=0.05, label="Temperature", value=0.8
                )
                vad_checkbox = gr.Checkbox(label="Ref VAD trimming", value=False)

            generate_button = gr.Button("Generate", variant="primary")

        with gr.Column():
            output_audio = gr.Audio(label="Output Audio")

    # Wire the button to the TTS function; the input order must match the
    # generate_tts_audio parameter order.
    generate_button.click(
        fn=generate_tts_audio,
        inputs=[
            text_box,
            reference_audio,
            exaggeration_slider,
            temperature_slider,
            seed_input,
            cfg_slider,
            vad_checkbox,
        ],
        outputs=[output_audio],
    )

demo.launch(mcp_server=True)