File size: 5,444 Bytes
9d593b2
 
 
9386371
9d593b2
0b3e025
9d593b2
6190e8e
 
 
 
 
 
 
 
 
 
 
 
9d593b2
0b3e025
 
 
 
 
 
9386371
 
0b3e025
 
9386371
0b3e025
 
6190e8e
 
9386371
0b3e025
9386371
0b3e025
9386371
0b3e025
 
 
6190e8e
0b3e025
 
 
9386371
9d593b2
 
9386371
9d593b2
0b3e025
3dab9c0
 
9d593b2
 
 
6190e8e
9386371
 
af25078
 
 
 
1afd111
 
9386371
 
af25078
9386371
 
0b3e025
 
9386371
0b3e025
 
 
 
9386371
1afd111
af25078
 
 
 
 
1afd111
af25078
1afd111
af25078
 
1afd111
0b3e025
9386371
af25078
9d593b2
3dab9c0
0b3e025
9d593b2
 
9386371
 
 
 
 
 
9d593b2
 
9386371
 
 
 
 
 
 
 
 
bf4bbc3
9386371
 
 
 
 
 
 
58ffee2
9d593b2
 
9386371
1afd111
9d593b2
 
 
 
 
 
 
9386371
9d593b2
 
 
 
 
 
58ffee2
1afd111
9d593b2
9386371
9d593b2
 
6190e8e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import random
import numpy as np
import torch
from chatterbox.src.chatterbox.tts import ChatterboxTTS
import gradio as gr
import spaces

# ─── Global patch to fix CUDA deserialization error on CPU ───
# This forces map_location='cpu' on all torch.load calls when CUDA is unavailable
original_torch_load = torch.load

def patched_torch_load(*args, **kwargs):
    """Wrap ``torch.load`` so CUDA-saved checkpoints deserialize on CPU-only hosts.

    Injects ``map_location='cpu'`` only when CUDA is unavailable AND the caller
    did not already supply a map_location — either as a keyword or as the
    second positional argument of ``torch.load(f, map_location, ...)``.
    """
    # len(args) < 2 guards the positional form torch.load(f, map_location);
    # without it, injecting the kwarg would raise a "multiple values" TypeError.
    if len(args) < 2 and 'map_location' not in kwargs and not torch.cuda.is_available():
        kwargs['map_location'] = torch.device('cpu')
    return original_torch_load(*args, **kwargs)

torch.load = patched_torch_load
# ─────────────────────────────────────────────────────────────

# Device selection: use the GPU when one is visible, otherwise fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"πŸš€ Running on device: {DEVICE}")

# --- Global Model Initialization ---
# Lazily populated by get_or_load_model(); stays None until the first load.
MODEL = None

def get_or_load_model():
    """Return the process-wide ChatterboxTTS instance, loading it on first use.

    The instance is cached in the module-level ``MODEL`` and moved to
    ``DEVICE`` if its internal device differs. Load failures are logged
    and re-raised.
    """
    global MODEL
    # Fast path: already loaded.
    if MODEL is not None:
        return MODEL

    print("Model not loaded, initializing...")
    try:
        MODEL = ChatterboxTTS.from_pretrained(DEVICE)
        # On CPU, .to(DEVICE) is usually redundant after loading with map_location
        # but we keep it for safety / future GPU support
        needs_move = hasattr(MODEL, 'to') and str(MODEL.device) != DEVICE
        if needs_move:
            MODEL.to(DEVICE)
        print(f"Model loaded successfully. Internal device: {getattr(MODEL, 'device', 'N/A')}")
    except Exception as e:
        print(f"Error loading model: {e}")
        raise
    return MODEL

# Attempt to load the model at startup (helps catch errors early in logs)
try:
    get_or_load_model()
except Exception as e:
    # Deliberately swallowed: the app still starts so the failure is visible
    # in the UI/logs; generate_tts_audio will raise again on first request.
    print(f"CRITICAL: Failed to load model on startup. Application may not function. Error: {e}")

def set_seed(seed: int):
    """Seed torch (CPU and, when present, all CUDA devices), NumPy, and random.

    Makes generation reproducible for a fixed non-zero seed.
    """
    torch.manual_seed(seed)
    # Query CUDA availability directly rather than reading the module-level
    # DEVICE global: equivalent (DEVICE == "cuda" iff CUDA is available) but
    # self-contained and not dependent on import/definition order.
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    random.seed(seed)
    np.random.seed(seed)

@spaces.GPU  # harmless on CPU, ignored by HF when no GPU is allocated
def generate_tts_audio(
    text_input: str,
    audio_prompt_path_input: str = None,
    exaggeration_input: float = 0.5,
    temperature_input: float = 0.8,
    seed_num_input: int = 0,
    cfgw_input: float = 0.5,
    vad_trim_input: bool = False,
) -> tuple[int, np.ndarray]:
    """Synthesize speech for *text_input* with ChatterboxTTS.

    Optionally conditions generation on a reference audio file, and returns
    a Gradio-compatible ``(sample_rate, waveform)`` pair.
    """
    model = get_or_load_model()
    if model is None:
        raise RuntimeError("TTS model is not loaded.")

    # A seed of 0 means "leave the RNG state alone" (non-reproducible output).
    if seed_num_input != 0:
        set_seed(int(seed_num_input))

    print(f"Generating audio for text: '{text_input[:50]}...'")

    generate_kwargs = dict(
        exaggeration=exaggeration_input,
        temperature=temperature_input,
        cfg_weight=cfgw_input,
        vad_trim=vad_trim_input,
    )
    # Only forward the reference-audio path when one was actually provided.
    if audio_prompt_path_input:
        generate_kwargs["audio_prompt_path"] = audio_prompt_path_input

    # Truncate text to max chars (the UI advertises a 300-char limit).
    wav = model.generate(text_input[:300], **generate_kwargs)

    print("Audio generation complete.")
    return (model.sr, wav.squeeze(0).numpy())

# --- Gradio UI: text + optional reference audio in, synthesized audio out ---
with gr.Blocks() as demo:
    gr.Markdown(
        """
        # Chatterbox TTS Demo
        Generate high-quality speech from text with reference audio styling.
        """
    )
    with gr.Row():
        # Left column: all generation inputs.
        with gr.Column():
            text = gr.Textbox(
                value="Now let's make my mum's favourite. So three mars bars into the pan. Then we add the tuna and just stir for a bit, just let the chocolate and fish infuse. A sprinkle of olive oil and some tomato ketchup. Now smell that. Oh boy this is going to be incredible.",
                label="Text to synthesize (max chars 300)",
                max_lines=5
            )
            # Reference audio for voice styling; pre-filled with a hosted sample.
            ref_wav = gr.Audio(
                sources=["upload", "microphone"],
                type="filepath",
                label="Reference Audio File (Optional)",
                value="https://storage.googleapis.com/chatterbox-demo-samples/prompts/female_shadowheart4.flac"
            )
            exaggeration = gr.Slider(
                0.25, 2, step=.05, label="Exaggeration (Neutral = 0.5, extreme values can be unstable)", value=.5
            )
            cfg_weight = gr.Slider(
                0.2, 1, step=.05, label="CFG/Pace", value=0.5
            )

            # Advanced knobs, collapsed by default.
            with gr.Accordion("More options", open=False):
                seed_num = gr.Number(value=0, label="Random seed (0 for random)")
                temp = gr.Slider(0.05, 5, step=.05, label="Temperature", value=.8)
                vad_trim = gr.Checkbox(label="Ref VAD trimming", value=False)

            run_btn = gr.Button("Generate", variant="primary")

        # Right column: generated audio output.
        with gr.Column():
            audio_output = gr.Audio(label="Output Audio")

    # Input order here must match generate_tts_audio's parameter order.
    run_btn.click(
        fn=generate_tts_audio,
        inputs=[
            text,
            ref_wav,
            exaggeration,
            temp,
            seed_num,
            cfg_weight,
            vad_trim,
        ],
        outputs=[audio_output],
    )

# mcp_server=True additionally exposes the app's functions over MCP.
demo.launch(mcp_server=True)