oicui committed on
Commit
f8b6238
·
verified ·
1 Parent(s): c22869a

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +200 -0
app.py ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import random
import re
import numpy as np
import torch
import torchaudio  # NOTE(review): imported but never referenced in this file — confirm before removing
from src.chatterbox.mtl_tts import ChatterboxMultilingualTTS, SUPPORTED_LANGUAGES
import gradio as gr
import spaces

# Prefer the GPU when available; the model and all seeds use this device.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"🚀 Running on device: {DEVICE}")

# Lazily initialized module-level singleton holding the loaded TTS model.
MODEL = None
14
+
15
# Per-language demo defaults shown in the UI: a hosted reference-audio clip
# ("audio") and a sample sentence to synthesize ("text"). Languages without an
# entry here simply get no default prompt/text.
LANGUAGE_CONFIG = {
    "ar": {"audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/ar_f/ar_prompts2.flac",
           "text": "في الشهر الماضي، وصلنا إلى معلم جديد بمليارين من المشاهدات على قناتنا على يوتيوب."},
    "en": {"audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/en_f1.flac",
           "text": "Last month, we reached a new milestone with two billion views on our YouTube channel."},
    "fr": {"audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/fr_f1.flac",
           "text": "Le mois dernier, nous avons atteint un nouveau jalon avec deux milliards de vues sur notre chaîne YouTube."},
    "hi": {"audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/hi_f1.flac",
           "text": "पिछले महीने हमने एक नया मील का पत्थर छुआ: हमारे YouTube चैनल पर दो अरब व्यूज़।"},
    "tr": {"audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/tr_m.flac",
           "text": "Geçen ay YouTube kanalımızda iki milyar görüntüleme ile yeni bir dönüm noktasına ulaştık."},
    "zh": {"audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/zh_f2.flac",
           "text": "上个月,我们达到了一个新的里程碑。 我们的YouTube频道观看次数达到了二十亿次,这绝对令人难以置信。"},
}
29
+
30
def default_audio_for_ui(lang: str) -> str | None:
    """Return the demo reference-audio URL for *lang*, or None when none is configured."""
    entry = LANGUAGE_CONFIG.get(lang)
    return entry.get("audio") if entry else None
32
+
33
def default_text_for_ui(lang: str) -> str:
    """Return the sample sentence for *lang*, or an empty string when none is configured."""
    entry = LANGUAGE_CONFIG.get(lang)
    return entry.get("text", "") if entry else ""
35
+
36
def get_supported_languages_display() -> str:
    """Render the supported-language table as a markdown section split over two rows."""
    entries = [
        f"**{name}** (`{code}`)"
        for code, name in sorted(SUPPORTED_LANGUAGES.items())
    ]
    half = len(entries) // 2
    header = f"### 🌍 Supported Languages ({len(SUPPORTED_LANGUAGES)} total)\n"
    top_row = " • ".join(entries[:half])
    bottom_row = " • ".join(entries[half:])
    return f"{header}{top_row}\n\n{bottom_row}"
41
+
42
def get_or_load_model():
    """Return the global ChatterboxMultilingualTTS instance, loading it on first call.

    The loaded model is cached in the module-level ``MODEL`` so later calls are
    cheap; the load target is the module-level ``DEVICE``.
    """
    global MODEL
    if MODEL is None:
        print("Model not loaded, initializing...")
        MODEL = ChatterboxMultilingualTTS.from_pretrained(DEVICE)
        # Move explicitly when the object supports it — from_pretrained's own
        # placement behavior is not visible from this file.
        if hasattr(MODEL, "to"):
            MODEL.to(DEVICE)
        print(f"✅ Model loaded successfully on {DEVICE}")
    return MODEL
51
+
52
# Warm the model at import time so the first request doesn't pay the load cost.
# A failure here is logged but deliberately non-fatal: generate_tts_audio calls
# get_or_load_model() again and will retry the load.
try:
    get_or_load_model()
except Exception as e:
    print(f"CRITICAL: Failed to load model. Error: {e}")
56
+
57
def set_seed(seed: int):
    """Make generation reproducible by seeding every RNG the pipeline touches."""
    # Python and NumPy generators.
    random.seed(seed)
    np.random.seed(seed)
    # Torch CPU generator, plus every CUDA generator when running on GPU.
    torch.manual_seed(seed)
    if DEVICE == "cuda":
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
64
+
65
def resolve_audio_prompt(language_id: str, provided_path: str | None) -> str | None:
    """Prefer a caller-supplied prompt path; else fall back to the language default.

    A path that is empty or whitespace-only counts as "not provided".
    """
    has_user_prompt = bool(provided_path) and bool(str(provided_path).strip())
    if has_user_prompt:
        return provided_path
    return LANGUAGE_CONFIG.get(language_id, {}).get("audio")
69
+
70
# --- text splitter ---
def split_text_into_chunks(text: str, max_chars: int = 500) -> list[str]:
    """Normalize whitespace and split *text* into sentence-aligned chunks.

    Short inputs are returned as a single chunk. Longer inputs are split on
    sentence-ending punctuation (including Devanagari danda and Arabic comma)
    and greedily packed so each chunk stays under *max_chars* where possible;
    a single over-long sentence is still emitted as its own chunk.
    """
    normalized = re.sub(r"\s+", " ", text.strip())
    if len(normalized) <= max_chars:
        return [normalized]

    sentences = re.split(r'(?<=[.!?।،])\s+', normalized)

    packed: list[str] = []
    buffer = ""
    for sentence in sentences:
        if len(buffer) + len(sentence) < max_chars:
            buffer = f"{buffer} {sentence}"
        else:
            packed.append(buffer.strip())
            buffer = sentence
    if buffer:
        packed.append(buffer.strip())

    # Drop any empties produced by the very first flush.
    return [chunk for chunk in packed if chunk]
89
+
90
+
91
@spaces.GPU
def generate_tts_audio(
    text_input: str,
    language_id: str,
    audio_prompt_path_input: str | None = None,
    exaggeration_input: float = 0.5,
    temperature_input: float = 0.8,
    seed_num_input: int = 0,
    cfgw_input: float = 0.5
):
    """Synthesize speech for *text_input* in *language_id*, chunk by chunk.

    Returns a 2-tuple: ``((sample_rate, audio_ndarray), seed_str)`` matching the
    Gradio Audio/Textbox outputs. Raises RuntimeError when the model is unavailable.
    A seed of 0 means "pick a random seed"; the seed actually used is returned.
    """

    current_model = get_or_load_model()
    if current_model is None:
        raise RuntimeError("TTS model not loaded.")

    # --- SEED LOGIC ---
    if seed_num_input == 0:
        seed_num_input = random.randint(1, 2**32 - 1)
        print(f"🌱 Random seed generated: {seed_num_input}")
    else:
        print(f"🌱 Using provided seed: {seed_num_input}")

    set_seed(int(seed_num_input))

    # Fall back to the per-language default prompt when the user supplied none.
    # NOTE(review): this bypasses resolve_audio_prompt(), so a whitespace-only
    # path is treated as provided here — confirm which behavior is intended.
    chosen_prompt = audio_prompt_path_input or default_audio_for_ui(language_id)
    generate_kwargs = {
        "exaggeration": exaggeration_input,
        "temperature": temperature_input,
        "cfg_weight": cfgw_input,
    }
    if chosen_prompt:
        generate_kwargs["audio_prompt_path"] = chosen_prompt

    # Long text is synthesized per chunk, then concatenated on the time axis.
    chunks = split_text_into_chunks(text_input)
    all_audio = []

    for chunk in chunks:
        wav = current_model.generate(chunk, language_id=language_id, **generate_kwargs)
        all_audio.append(wav.squeeze(0).cpu())

    final_audio = torch.cat(all_audio, dim=-1)

    # RETURN AUDIO + SEED
    return (current_model.sr, final_audio.numpy()), str(seed_num_input)
135
+
136
+
137
+
138
# ============================
# GRADIO UI
# ============================

with gr.Blocks() as demo:
    gr.Markdown("""
    # 🎙️ Multi Language Realistic Voice Cloner
    Generate long-form multilingual speech with reference audio styling and auto-chunking.
    """)

    gr.Markdown(get_supported_languages_display())

    with gr.Row():
        # INPUT COLUMN: text, language, reference audio and generation knobs.
        with gr.Column():
            initial_lang = "en"
            text = gr.Textbox(
                value=default_text_for_ui(initial_lang),
                label="Text to synthesize",
                lines=8
            )
            language_id = gr.Dropdown(
                choices=list(ChatterboxMultilingualTTS.get_supported_languages().keys()),
                value=initial_lang,
                label="Language"
            )
            ref_wav = gr.Audio(
                sources=["upload", "microphone"],
                type="filepath",
                label="Reference Audio (Optional)",
                value=default_audio_for_ui(initial_lang)
            )
            exaggeration = gr.Slider(0.25, 2, step=.05, label="Exaggeration", value=.5)
            cfg_weight = gr.Slider(0.2, 1, step=.05, label="CFG Weight", value=0.5)

            with gr.Accordion("Advanced", open=False):
                seed_num = gr.Number(value=0, label="Random Seed (0=random)")
                temp = gr.Slider(0.05, 5, step=.05, label="Temperature", value=.8)

            run_btn = gr.Button("Generate", variant="primary")

        # OUTPUT COLUMN
        with gr.Column():
            audio_output = gr.Audio(label="Output Audio")
            seed_output = gr.Textbox(label="Seed Used", interactive=False)

    def on_lang_change(lang, current_ref, current_text):
        # Reset prompt audio and sample text to the new language's defaults.
        # current_ref/current_text are received but intentionally unused.
        return default_audio_for_ui(lang), default_text_for_ui(lang)

    language_id.change(
        fn=on_lang_change,
        inputs=[language_id, ref_wav, text],
        outputs=[ref_wav, text],
        show_progress=False
    )

    # CONNECT BUTTON
    run_btn.click(
        fn=generate_tts_audio,
        inputs=[text, language_id, ref_wav, exaggeration, temp, seed_num, cfg_weight],
        outputs=[audio_output, seed_output],
    )

# share=True exposes a public tunnel; mcp_server=True also serves the MCP endpoint.
demo.launch(mcp_server=True, share=True)