mamakoRtechs commited on
Commit
d51446f
·
verified ·
1 Parent(s): af25078

app-share.py

Browse files
Files changed (1) hide show
  1. app (1).py +148 -0
app (1).py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ import numpy as np
3
+ import torch
4
+ from chatterbox.src.chatterbox.tts import ChatterboxTTS
5
+ import gradio as gr
6
+ import spaces
7
+
8
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
9
+ print(f"🚀 Running on device: {DEVICE}")
10
+
11
+ # --- Global Model Initialization ---
12
+ MODEL = None
13
+
14
+ def get_or_load_model():
15
+ """Loads the ChatterboxTTS model if it hasn't been loaded already,
16
+ and ensures it's on the correct device."""
17
+ global MODEL
18
+ if MODEL is None:
19
+ print("Model not loaded, initializing...")
20
+ try:
21
+ MODEL = ChatterboxTTS.from_pretrained(DEVICE)
22
+ if hasattr(MODEL, 'to') and str(MODEL.device) != DEVICE:
23
+ MODEL.to(DEVICE)
24
+ print(f"Model loaded successfully. Internal device: {getattr(MODEL, 'device', 'N/A')}")
25
+ except Exception as e:
26
+ print(f"Error loading model: {e}")
27
+ raise
28
+ return MODEL
29
+
30
+ # Attempt to load the model at startup.
31
+ try:
32
+ get_or_load_model()
33
+ except Exception as e:
34
+ print(f"CRITICAL: Failed to load model on startup. Application may not function. Error: {e}")
35
+
36
+ def set_seed(seed: int):
37
+ """Sets the random seed for reproducibility across torch, numpy, and random."""
38
+ torch.manual_seed(seed)
39
+ if DEVICE == "cuda":
40
+ torch.cuda.manual_seed(seed)
41
+ torch.cuda.manual_seed_all(seed)
42
+ random.seed(seed)
43
+ np.random.seed(seed)
44
+
45
+ @spaces.GPU
46
+ def generate_tts_audio(
47
+ text_input: str,
48
+ audio_prompt_path_input: str = None,
49
+ exaggeration_input: float = 0.5,
50
+ temperature_input: float = 0.8,
51
+ seed_num_input: int = 0,
52
+ cfgw_input: float = 0.5
53
+ ) -> tuple[int, np.ndarray]:
54
+ """
55
+ Generate high-quality speech audio from text using ChatterboxTTS model with optional reference audio styling.
56
+
57
+ This tool synthesizes natural-sounding speech from input text. When a reference audio file
58
+ is provided, it captures the speaker's voice characteristics and speaking style. The generated audio
59
+ maintains the prosody, tone, and vocal qualities of the reference speaker, or uses default voice if no reference is provided.
60
+
61
+ Args:
62
+ text_input (str): The text to synthesize into speech (maximum 300 characters)
63
+ audio_prompt_path_input (str, optional): File path or URL to the reference audio file that defines the target voice style. Defaults to None.
64
+ exaggeration_input (float, optional): Controls speech expressiveness (0.25-2.0, neutral=0.5, extreme values may be unstable). Defaults to 0.5.
65
+ temperature_input (float, optional): Controls randomness in generation (0.05-5.0, higher=more varied). Defaults to 0.8.
66
+ seed_num_input (int, optional): Random seed for reproducible results (0 for random generation). Defaults to 0.
67
+ cfgw_input (float, optional): CFG/Pace weight controlling generation guidance (0.2-1.0). Defaults to 0.5.
68
+
69
+ Returns:
70
+ tuple[int, np.ndarray]: A tuple containing the sample rate (int) and the generated audio waveform (numpy.ndarray)
71
+ """
72
+ current_model = get_or_load_model()
73
+
74
+ if current_model is None:
75
+ raise RuntimeError("TTS model is not loaded.")
76
+
77
+ if seed_num_input != 0:
78
+ set_seed(int(seed_num_input))
79
+
80
+ print(f"Generating audio for text: '{text_input[:50]}...'")
81
+
82
+ # Handle optional audio prompt
83
+ generate_kwargs = {
84
+ "exaggeration": exaggeration_input,
85
+ "temperature": temperature_input,
86
+ "cfg_weight": cfgw_input,
87
+ }
88
+
89
+ if audio_prompt_path_input:
90
+ generate_kwargs["audio_prompt_path"] = audio_prompt_path_input
91
+
92
+ wav = current_model.generate(
93
+ text_input[:300], # Truncate text to max chars
94
+ **generate_kwargs
95
+ )
96
+ print("Audio generation complete.")
97
+ return (current_model.sr, wav.squeeze(0).numpy())
98
+
99
+ with gr.Blocks() as demo:
100
+ gr.Markdown(
101
+ """
102
+ # Chatterbox TTS Demo
103
+ Generate high-quality speech from text with reference audio styling.
104
+ """
105
+ )
106
+ with gr.Row():
107
+ with gr.Column():
108
+ text = gr.Textbox(
109
+ value="Now let's make my mum's favourite. So three mars bars into the pan. Then we add the tuna and just stir for a bit, just let the chocolate and fish infuse. A sprinkle of olive oil and some tomato ketchup. Now smell that. Oh boy this is going to be incredible.",
110
+ label="Text to synthesize (max chars 300)",
111
+ max_lines=5
112
+ )
113
+ ref_wav = gr.Audio(
114
+ sources=["upload", "microphone"],
115
+ type="filepath",
116
+ label="Reference Audio File (Optional)",
117
+ value="https://storage.googleapis.com/chatterbox-demo-samples/prompts/female_shadowheart4.flac"
118
+ )
119
+ exaggeration = gr.Slider(
120
+ 0.25, 2, step=.05, label="Exaggeration (Neutral = 0.5, extreme values can be unstable)", value=.5
121
+ )
122
+ cfg_weight = gr.Slider(
123
+ 0.2, 1, step=.05, label="CFG/Pace", value=0.5
124
+ )
125
+
126
+ with gr.Accordion("More options", open=False):
127
+ seed_num = gr.Number(value=0, label="Random seed (0 for random)")
128
+ temp = gr.Slider(0.05, 5, step=.05, label="Temperature", value=.8)
129
+
130
+ run_btn = gr.Button("Generate", variant="primary")
131
+
132
+ with gr.Column():
133
+ audio_output = gr.Audio(label="Output Audio")
134
+
135
+ run_btn.click(
136
+ fn=generate_tts_audio,
137
+ inputs=[
138
+ text,
139
+ ref_wav,
140
+ exaggeration,
141
+ temp,
142
+ seed_num,
143
+ cfg_weight,
144
+ ],
145
+ outputs=[audio_output],
146
+ )
147
+
148
+ demo.launch(mcp_server=True)