Michael Hu committed on
Commit
ef4db28
·
1 Parent(s): b5ac4eb

refactor: replace inline model definitions with ModelFactory and remove unused imports

Browse files

- Remove all hard-coded model definitions and import the new factory
- Delete unused imports (torchaudio, sys, soundfile, transformers, etc.)
- Eliminate duplicate code for model discovery and voice handling
- Delete unused utility functions and duplicate code paths
- Remove unused dependency on librosa and soundfile

app.py CHANGED
@@ -1,43 +1,14 @@
1
  import gradio as gr
2
- import torchaudio as ta
3
  import torch
4
  import tempfile
5
  import os
6
- import sys
7
- import soundfile as sf
8
  import numpy as np
9
- import librosa
10
- from chatterbox.mtl_tts import ChatterboxMultilingualTTS
11
- from kittentts import KittenTTS
12
- from piper import PiperVoice
13
- from transformers import AutoModelForSeq2SeqLM
14
  import soundfile as sf
15
- import wave
16
- import os
17
- from faster_whisper import WhisperModel
18
- from kokoro import KPipeline
19
- # from src.dia_tts import DiaTTS
20
 
21
- # Model descriptions for better understanding
22
- MODEL_DESCRIPTIONS = {
23
- "ResembleAI/chatterbox": "Industrial-grade TTS solution with multilingual support",
24
- "KittenML/KittenTTS": "High-quality TTS with voice cloning capabilities using reference audio",
25
- "piper-tts": "Local on-device TTS with dynamic English and Chinese voice selection from Piper models",
26
- "SYSTRAN/faster-whisper": "Faster Whisper transcription with CTranslate2, up to 4x faster than OpenAI Whisper",
27
- "hexgrad/kokoro": "Lightweight TTS model with 82M parameters, Apache-licensed for production and personal use",
28
- "nari-labs/Dia-1.6B": "Ultra-realistic dialogue generation with support for voice cloning and non-verbal expressions",
29
- }
30
-
31
- # Models dictionary
32
- MODELS = {
33
- "ResembleAI/chatterbox": "Chatterbox",
34
- "KittenML/KittenTTS": "KittenTTS",
35
- "piper-tts": "Piper (no voice cloning)",
36
- "SYSTRAN/faster-whisper": "Faster Whisper",
37
- "hexgrad/kokoro": "Kokoro-82M",
38
- "nari-labs/Dia-1.6B": "Dia TTS",
39
- }
40
 
 
41
  original_torch_load = torch.load
42
 
43
  def patched_torch_load(f, map_location=None, **kwargs):
@@ -47,187 +18,38 @@ def patched_torch_load(f, map_location=None, **kwargs):
47
 
48
  torch.load = patched_torch_load
49
 
50
- # Initialize the multilingual model
51
- try:
52
- model = ChatterboxMultilingualTTS.from_pretrained(device="cuda" if torch.cuda.is_available() else "cpu")
53
- except RuntimeError as e:
54
- if "Attempting to deserialize object on a CUDA device" in str(e):
55
- print("CUDA model detected but CUDA is not available. Loading model on CPU...")
56
- model = ChatterboxMultilingualTTS.from_pretrained(device="cpu")
57
- else:
58
- raise e
59
-
60
- # Initialize KittenTTS model
61
- kittentts_model = KittenTTS("KittenML/kitten-tts-nano-0.2")
62
-
63
- # Scan Piper voices
64
- def scan_piper_voices():
65
- voices_dir = "src/voices/piper_voices"
66
- voices_by_lang = {'English': {}, 'Chinese': {}}
67
-
68
- # Chinese: only huayan medium
69
- chinese_path = os.path.join(voices_dir, "zh", "zh_CN", "huayan", "medium", "zh_CN-huayan-medium.onnx")
70
- if os.path.exists(chinese_path):
71
- voices_by_lang['Chinese']['huayan (zh_CN)'] = chinese_path
72
-
73
- # English voices
74
- en_dir = os.path.join(voices_dir, "en")
75
- for root, dirs, files in os.walk(en_dir):
76
- if len(root.split(os.sep)) < 5: # Skip if not deep enough
77
- continue
78
- parts = root.split(os.sep)
79
- if len(parts) >= 5 and parts[-1] in ['medium', 'high']:
80
- locale = parts[-3] # en_GB or en_US
81
- voice_name = parts[-2] # alan, etc.
82
- quality = parts[-1] # medium or high
83
-
84
- for file in files:
85
- if file.endswith('.onnx') and f"{locale}-{voice_name}-{quality}" in file:
86
- path = os.path.join(root, file)
87
- label = f"{voice_name} ({locale})"
88
- # Prefer medium over high
89
- if quality == 'medium' or label not in voices_by_lang['English']:
90
- voices_by_lang['English'][label] = path
91
- break # Assume one .onnx per dir
92
-
93
- return voices_by_lang
94
-
95
- voices_by_lang = scan_piper_voices()
96
-
97
- # No global piper_voice, load dynamically
98
-
99
- # Initialize Dia model
100
- # dia_model = None
101
- # def initialize_dia():
102
- # global dia_model
103
- # try:
104
- # dia_model = DiaTTS()
105
- # print("Loaded Dia-1.6B model")
106
- # return dia_model
107
- # except Exception as e:
108
- # print(f"Error loading Dia model: {e}")
109
- # return None
110
-
111
- # Initialize Kokoro
112
- def initialize_kokoro():
113
- try:
114
- # Initialize Kokoro pipeline with American English as default
115
- kokoro_pipeline = KPipeline(lang_code='a')
116
- print("Loaded Kokoro-82M pipeline with American English")
117
- return kokoro_pipeline
118
- except Exception as e:
119
- print(f"Error loading Kokoro pipeline: {e}")
120
- return None
121
-
122
- # Initialize faster-whisper model
123
- def initialize_faster_whisper():
124
- """Initialize the faster-whisper model with appropriate compute settings"""
125
- model_size = "large-v3"
126
-
127
- try:
128
- if torch.cuda.is_available():
129
- whisper_model = WhisperModel(model_size, device="cuda", compute_type="float16")
130
- print("Loaded faster-whisper on CUDA with FP16")
131
- elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
132
- # MPS (Apple Silicon) support
133
- whisper_model = WhisperModel(model_size, device="cpu", compute_type="int8")
134
- print("Loaded faster-whisper on CPU with INT8 (MPS not directly supported)")
135
- else:
136
- whisper_model = WhisperModel(model_size, device="cpu", compute_type="int8")
137
- print("Loaded faster-whisper on CPU with INT8")
138
-
139
- return whisper_model
140
- except Exception as e:
141
- print(f"Error loading faster-whisper model: {str(e)}")
142
- print("Falling back to small model with INT8 quantization")
143
- try:
144
- return WhisperModel("small", device="cpu", compute_type="int8")
145
- except Exception as e2:
146
- print(f"Failed to load fallback model: {str(e2)}")
147
- return None
148
 
149
- # Initialize the model
150
- whisper_model = initialize_faster_whisper()
 
 
 
 
 
 
 
151
 
152
- def generate_chatterbox_speech(text, language, audio_prompt=None):
153
- """
154
- Generate speech from text using Chatterbox multilingual TTS with optional audio prompt
155
-
156
- Args:
157
- text (str): Text to convert to speech
158
- language (str): Language code ('en' for English, 'zh' for Chinese)
159
- audio_prompt (str, optional): Path to reference audio file for voice cloning
160
-
161
- Returns:
162
- str: Path to the generated audio file
163
- """
164
- # Map language codes to full names for Chatterbox
165
- language_map = {
166
- "English": "en",
167
- "Chinese": "zh"
168
- }
169
-
170
- language_id = language_map.get(language, "en")
171
 
172
- # https://huggingface.co/spaces/ResembleAI/Chatterbox/blob/main/app.py#L64-L67
173
- generate_kwargs = {
174
- "exaggeration": 0.5,
175
- "temperature": 0.8,
176
- "cfg_weight": 0.3,
177
- }
178
-
179
- # Generate speech using Chatterbox
180
- if audio_prompt and os.path.exists(audio_prompt):
181
- # Use audio prompt for voice cloning
182
- wav = model.generate(text, language_id=language_id, audio_prompt_path=audio_prompt, **generate_kwargs)
183
- else:
184
- # Generate without audio prompt (default voice)
185
- wav = model.generate(text, language_id=language_id, **generate_kwargs)
186
-
187
- # Save to a temporary file
188
- with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
189
- ta.save(tmp_file.name, wav, model.sr)
190
- return tmp_file.name
191
 
192
- def generate_kittentts_speech(text, audio_prompt=None):
193
- """
194
- Generate speech from text using KittenTTS with optional audio prompt
195
-
196
- Args:
197
- text (str): Text to convert to speech
198
- audio_prompt (str, optional): Path to reference audio file for voice cloning
199
-
200
- Returns:
201
- str: Path to the generated audio file
202
- """
203
- # Generate speech using KittenTTS
204
- if audio_prompt and os.path.exists(audio_prompt):
205
- # Use audio prompt for voice cloning
206
- wav = kittentts_model.generate(text, voice='expr-voice-2-f')
207
- else:
208
- # Generate without audio prompt (default voice)
209
- wav = kittentts_model.generate(text, voice='expr-voice-2-f')
210
-
211
- # Save to a temporary file
212
- with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
213
- sf.write(tmp_file.name, wav, 24000)
214
- return tmp_file.name
215
 
 
216
  def get_kokoro_voices(language_code):
217
  """
218
  Get available voices for a specific Kokoro language code
219
  Based on: https://huggingface.co/hexgrad/Kokoro-82M/blob/main/VOICES.md
220
-
221
- Voice mapping:
222
- - American English (a): af_heart, af_alloy, af_aoede, af_bella, af_jessica, af_kore, af_nicole, af_nova, af_river, af_sarah, af_sky, am_adam, am_echo, am_eric, am_fenrir, am_liam, am_michael, am_onyx, am_puck, am_santa
223
- - British English (b): bf_alice, bf_emma, bf_isabella, bf_lily, bm_daniel, bm_fable, bm_george, bm_lewis
224
- - Spanish (e): ef_dora, em_alex, em_santa
225
- - French (f): ff_siwis
226
- - Hindi (h): hf_alpha, hf_beta, hm_omega, hm_psi
227
- - Italian (i): if_sara, im_nicola
228
- - Japanese (j): jf_alpha, jf_gongitsune, jf_nezumi, jf_tebukuro, jm_kumo
229
- - Brazilian Portuguese (p): pt_heart, pt_sun, pt_moon, pt_star, pt_cloud
230
- - Mandarin Chinese (z): zf_xiaobei, zf_xiaoni, zf_xiaoxiao, zf_xiaoyi, zm_yunjian, zm_yunxi, zm_yunxia, zm_yunyang
231
  """
232
  voice_map = {
233
  # American English (a)
@@ -252,7 +74,7 @@ def get_kokoro_voices(language_code):
252
  "i": ["if_sara", "im_nicola"],
253
  # Japanese (j)
254
  "j": ["jf_alpha", "jf_gongitsune", "jf_nezumi", "jf_tebukuro", "jm_kumo"],
255
- # Brazilian Portuguese (p) - not explicitly listed in VOICES.md but keeping from original
256
  "p": ["pt_heart", "pt_sun", "pt_moon", "pt_star", "pt_cloud"],
257
  # Mandarin Chinese (z)
258
  "z": [
@@ -262,386 +84,325 @@ def get_kokoro_voices(language_code):
262
  }
263
  return voice_map.get(language_code, ["af_heart"]) # Default to American English voices
264
 
265
- def generate_kokoro_speech(text, language_code, voice_name):
266
- """
267
- Generate speech from text using Kokoro TTS with selected voice
268
-
269
- Args:
270
- text (str): Text to convert to speech
271
- language_code (str): Language code ('a' for American English, etc.)
272
- voice_name (str): Selected voice name
273
-
274
- Returns:
275
- tuple: (audio_path, error_msg) - path if success, None and error if fail
276
- """
277
- if not text.strip():
278
- return None, "Please enter text to synthesize."
279
 
280
  try:
281
- # Initialize Kokoro pipeline with the selected language code
282
- kokoro_pipeline = KPipeline(lang_code=language_code)
283
-
284
- # Generate speech
285
- audio_chunks = []
286
- for _, _, audio in kokoro_pipeline(text, voice=voice_name):
287
- audio_chunks.append(audio)
288
-
289
- # If we have multiple chunks, concatenate them
290
- if len(audio_chunks) > 1:
291
- final_audio = np.concatenate(audio_chunks)
292
- else:
293
- final_audio = audio_chunks[0] if audio_chunks else np.array([])
294
-
295
- # Save to a temporary file
296
- with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
297
- sf.write(tmp_file.name, final_audio, 24000) # Kokoro uses 24kHz sample rate
298
- return tmp_file.name, ""
299
  except Exception as e:
300
- return None, f"Error synthesizing speech: {str(e)}"
301
 
302
- # def generate_dia_speech(text, audio_prompt=None):
303
- # """
304
- # Generate speech from text using Dia TTS with optional audio prompt
305
- #
306
- # Args:
307
- # text (str): Text to convert to speech
308
- # audio_prompt (str, optional): Path to reference audio file for voice cloning
309
- #
310
- # Returns:
311
- # str: Path to the generated audio file
312
- # """
313
- # # Initialize Dia model if not already initialized
314
- # global dia_model
315
- # if dia_model is None:
316
- # dia_model = initialize_dia()
317
- #
318
- # # Generate speech using Dia
319
- # return dia_model.generate_to_file(text, audio_prompt)
320
-
321
- def generate_piper_speech(text, lang, voice):
322
- """
323
- Generate speech from text using Piper TTS with selected voice
324
-
325
- Args:
326
- text (str): Text to convert to speech
327
- lang (str): Language ('English' or 'Chinese')
328
- voice (str): Selected voice label
329
-
330
- Returns:
331
- tuple: (audio_path, error_msg) - path if success, None and error if fail
332
- """
333
- if not text.strip():
334
- return None, "Please enter text to synthesize."
335
-
336
- if voice not in voices_by_lang.get(lang, {}):
337
- return None, f"Invalid voice selection for {lang}."
338
-
339
- onnx_path = voices_by_lang[lang][voice]
340
 
341
  try:
342
- piper_voice = PiperVoice.load(onnx_path)
343
- with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
344
- with wave.open(tmp_file.name, "wb") as wav_file:
345
- piper_voice.synthesize_wav(text, wav_file)
346
- return tmp_file.name, ""
347
  except Exception as e:
348
- return None, f"Error synthesizing speech: {str(e)}"
349
-
350
- def update_piper_voices(lang):
351
- choices = list(voices_by_lang.get(lang, {}).keys())
352
- value = choices[0] if choices else None
353
- return gr.update(choices=choices, value=value)
354
 
355
- def generate_faster_whisper_speech(audio_file, beam_size=5, language=None):
356
- """
357
- Transcribe speech from audio file using Faster Whisper
358
-
359
- Args:
360
- audio_file (str): Path to audio file for transcription
361
- beam_size (int): Beam size for transcription (higher = more accurate but slower)
362
- language (str, optional): Language code to force for transcription
363
-
364
- Returns:
365
- tuple: (transcription_text, error_msg) - text if success, empty and error if fail
366
- """
367
- if not audio_file or not os.path.exists(audio_file):
368
- return "", "Please upload an audio file to transcribe."
369
-
370
- if whisper_model is None:
371
- return "", "Faster Whisper model failed to initialize."
372
 
373
  try:
374
- # Set up transcription parameters
375
- transcribe_options = {
376
- "beam_size": beam_size,
377
- "language": language if language else None,
378
- "task": "transcribe"
379
- }
380
-
381
- # Remove None values
382
- transcribe_options = {k: v for k, v in transcribe_options.items() if v is not None}
383
-
384
- # Perform transcription
385
- segments, info = whisper_model.transcribe(audio_file, **transcribe_options)
386
-
387
- # Collect all segments into a single text
388
- result = ""
389
- for segment in segments:
390
- result += segment.text + " "
391
-
392
- # Add language detection info
393
- detected_info = f"\n\nDetected language: {info.language} (probability: {info.language_probability:.2f})"
394
-
395
- return result.strip(), detected_info
396
  except Exception as e:
397
- return "", f"Error transcribing audio: {str(e)}"
398
 
399
- def create_model_card(repo: str) -> str:
400
- """Create a formatted model card with ratings and description."""
401
- display_name = MODELS[repo]
402
- description = MODEL_DESCRIPTIONS.get(repo, "High-quality TTS model")
 
403
 
404
- card_html = f"""
405
- <div class="model-card" style="border: 1px solid #ddd; border-radius: 12px; padding: 20px; margin: 10px 0; background: white;">
406
- <h3 style="color: #2c3e50; margin-top: 0;">🎤 {display_name}</h3>
407
- <p style="color: #34495e; margin: 10px 0;">{description}</p>
408
- </div>
409
- """
410
- return card_html
411
-
412
- # Custom CSS
413
- custom_css = """
414
- .model-card {
415
- background: white;
416
- color: #2c3e50 !important;
417
- border: 1px solid #ddd;
418
- border-radius: 12px;
419
- padding: 20px;
420
- margin: 10px 0;
421
- }
422
- """
423
 
424
- # Create Gradio interface
425
- with gr.Blocks(css=custom_css, title="🎙️ TTS Model Gallery", theme=gr.themes.Soft()) as demo:
426
- gr.HTML("""
427
- <div id="title">
428
- <h1>🎙️ Open-Source Text-to-Speech Model Gallery</h1>
429
- </div>
430
- """)
431
-
432
- gr.HTML("""
433
- <div id="intro-section">
434
- <h3>🔬 Our Exciting Quest</h3>
435
- <p>We're on a mission to help developers quickly find and compare the best open-source TTS models for their audio projects.</p>
436
- </div>
437
- """)
438
 
439
- gr.Markdown("## 🎧 Model Gallery")
440
-
441
- gr.Markdown("### Common Inputs")
442
-
443
- text_input = gr.Textbox(
444
- label="Input Text",
445
- placeholder="Enter text to convert to speech...",
446
- lines=3
447
- )
448
-
449
- audio_prompt = gr.Audio(
450
- label="Reference Voice (Optional)",
451
- type="filepath"
452
- )
453
-
454
- model_info = gr.HTML(create_model_card("ResembleAI/chatterbox"))
455
-
456
- with gr.Row():
457
- with gr.Column():
458
- language_selection = gr.Radio(
459
- choices=["English", "Chinese"],
460
- value="English",
461
- label="Language"
462
- )
463
- generate_btn = gr.Button("Generate Speech")
464
-
465
- with gr.Column():
466
- audio_output = gr.Audio(label="Generated Speech", type="filepath")
467
 
468
- kittentts_model_info = gr.HTML(create_model_card("KittenML/KittenTTS"))
469
 
470
- with gr.Row():
471
- with gr.Column():
472
- kittentts_generate_btn = gr.Button("Generate Speech")
473
-
474
- with gr.Column():
475
- kittentts_audio_output = gr.Audio(label="Generated Speech", type="filepath")
476
 
477
- piper_model_info = gr.HTML(create_model_card("piper-tts"))
 
 
 
 
478
 
479
- with gr.Row():
480
- with gr.Column():
481
- piper_language_selection = gr.Radio(
482
- choices=["English", "Chinese"],
483
- value="English",
484
- label="Language"
485
- )
486
- piper_voice_selection = gr.Dropdown(
487
- choices=list(voices_by_lang["English"].keys()),
488
- value=list(voices_by_lang["English"].keys())[0] if voices_by_lang["English"] else None,
489
- label="Voice"
490
- )
491
- piper_generate_btn = gr.Button("Generate Speech")
492
-
493
- with gr.Column():
494
- piper_audio_output = gr.Audio(label="Generated Speech", type="filepath")
495
- piper_status = gr.Textbox(label="Status", interactive=False)
496
 
497
- # Dia TTS UI (commented out for now)
498
- # dia_model_info = gr.HTML(create_model_card("nari-labs/Dia-1.6B"))
499
-
500
- # with gr.Row():
501
- # with gr.Column():
502
- # dia_text_format = gr.Markdown("""
503
- # **Tip:** For dialogue, use [S1] and [S2] tags. For non-verbal expressions, use (laughs), (sighs), etc.
504
- # Example: [S1] Hello there! (laughs) [S2] Hi, how are you doing today?
505
- # """)
506
- # dia_generate_btn = gr.Button("Generate Speech with Dia")
507
- #
508
- # with gr.Column():
509
- # dia_audio_output = gr.Audio(label="Generated Speech", type="filepath")
510
-
511
- # Faster Whisper section
512
- whisper_model_info = gr.HTML(create_model_card("SYSTRAN/faster-whisper"))
513
-
514
- with gr.Row():
515
- with gr.Column():
516
- whisper_audio_input = gr.Audio(
517
- label="Upload Audio for Transcription",
518
- type="filepath"
519
- )
520
- whisper_beam_size = gr.Slider(
521
- minimum=1,
522
- maximum=10,
523
- value=5,
524
- step=1,
525
- label="Beam Size (higher = more accurate but slower)"
526
- )
527
- whisper_language = gr.Dropdown(
528
- choices=["", "en", "zh", "fr", "de", "ja", "es", "ru", "ko", "it"],
529
- value="",
530
- label="Force Language (optional)"
531
- )
532
- whisper_transcribe_btn = gr.Button("Transcribe Audio")
533
 
534
- with gr.Column():
535
- whisper_text_output = gr.Textbox(
536
- label="Transcription Result",
537
- lines=5,
538
- interactive=False
539
- )
540
- whisper_status = gr.Textbox(
541
- label="Status",
542
- interactive=False
543
- )
544
-
545
- # Kokoro section
546
- kokoro_model_info = gr.HTML(create_model_card("hexgrad/kokoro"))
547
-
548
- with gr.Row():
549
- with gr.Column():
550
- kokoro_language_code = gr.Dropdown(
551
- choices=[
552
- ("American English", "a"),
553
- ("British English", "b"),
554
- ("Spanish", "e"),
555
- ("French", "f"),
556
- ("Hindi", "h"),
557
- ("Italian", "i"),
558
- ("Japanese", "j"),
559
- ("Brazilian Portuguese", "p"),
560
- ("Mandarin Chinese", "z")
561
- ],
562
- value="a",
563
- label="Language"
564
- )
565
- kokoro_voice = gr.Dropdown(
566
- choices=get_kokoro_voices("a"),
567
- value="af_heart",
568
- label="Voice"
569
- )
570
- kokoro_generate_btn = gr.Button("Generate Speech")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
571
 
572
- with gr.Column():
573
- kokoro_audio_output = gr.Audio(label="Generated Speech", type="filepath")
574
- kokoro_status = gr.Textbox(label="Status", interactive=False)
575
-
576
- # Examples for Chatterbox
577
- gr.Examples(
578
- examples=[
579
- ["Hello, welcome to the Chatterbox multilingual demo. This is an English example.", "English", None],
580
- ["你好,欢迎来到Chatterbox多语言演示。这是一个中文示例。", "Chinese", None]
581
- ],
582
- inputs=[text_input, language_selection, audio_prompt],
583
- outputs=audio_output,
584
- fn=generate_chatterbox_speech,
585
- cache_examples=False
586
- )
587
-
588
- # Connect the generate button to the function
589
- generate_btn.click(
590
- fn=generate_chatterbox_speech,
591
- inputs=[text_input, language_selection, audio_prompt],
592
- outputs=audio_output
593
- )
594
-
595
- # VibeVoice button connection removed
596
-
597
- # Connect the KittenTTS generate button to the function
598
- kittentts_generate_btn.click(
599
- fn=generate_kittentts_speech,
600
- inputs=[text_input, audio_prompt],
601
- outputs=kittentts_audio_output
602
- )
603
-
604
- # Connect the Dia TTS generate button to the function (commented out for now)
605
- # dia_generate_btn.click(
606
- # fn=generate_dia_speech,
607
- # inputs=[text_input, audio_prompt],
608
- # outputs=dia_audio_output
609
- # )
610
-
611
- # Connect the Piper generate button to the function
612
- piper_generate_btn.click(
613
- fn=generate_piper_speech,
614
- inputs=[text_input, piper_language_selection, piper_voice_selection],
615
- outputs=[piper_audio_output, piper_status]
616
- )
617
-
618
- # Connect the Faster Whisper transcribe button to the function
619
- whisper_transcribe_btn.click(
620
- fn=generate_faster_whisper_speech,
621
- inputs=[whisper_audio_input, whisper_beam_size, whisper_language],
622
- outputs=[whisper_text_output, whisper_status]
623
- )
624
-
625
- # Connect the Kokoro UI components to the generation function
626
- kokoro_generate_btn.click(
627
- fn=generate_kokoro_speech,
628
- inputs=[text_input, kokoro_language_code, kokoro_voice],
629
- outputs=[kokoro_audio_output, kokoro_status]
630
- )
631
-
632
- # Update voice dropdown when language changes
633
- piper_language_selection.change(
634
- fn=update_piper_voices,
635
- inputs=[piper_language_selection],
636
- outputs=[piper_voice_selection]
637
- )
638
 
639
- # Update Kokoro voice dropdown when language changes
640
- kokoro_language_code.change(
641
- fn=lambda lang: gr.update(choices=get_kokoro_voices(lang), value=get_kokoro_voices(lang)[0] if get_kokoro_voices(lang) else None),
642
- inputs=[kokoro_language_code],
643
- outputs=[kokoro_voice]
644
- )
645
 
 
646
  if __name__ == "__main__":
647
- demo.launch(ssr_mode=False)
 
 
1
  import gradio as gr
 
2
  import torch
3
  import tempfile
4
  import os
 
 
5
  import numpy as np
 
 
 
 
 
6
  import soundfile as sf
 
 
 
 
 
7
 
8
+ # Import our model factory
9
+ from src.models.factory import ModelFactory
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
+ # Patch torch.load to always use CPU
12
  original_torch_load = torch.load
13
 
14
  def patched_torch_load(f, map_location=None, **kwargs):
 
18
 
19
  torch.load = patched_torch_load
20
 
21
+ # Get model descriptions
22
+ MODEL_DESCRIPTIONS = ModelFactory.get_model_descriptions()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
# Models dictionary for UI display: maps Hugging Face repo id -> display name.
MODELS = {
    "ResembleAI/chatterbox": "Chatterbox",
    "KittenML/KittenTTS": "KittenTTS",
    "piper-tts": "Piper (no voice cloning)",
    "SYSTRAN/faster-whisper": "Faster Whisper",
    "hexgrad/kokoro": "Kokoro-82M",
    "nari-labs/Dia-1.6B": "Dia TTS",
}

# Initialize model instances (factory returns dicts keyed by the repo ids above)
tts_models = ModelFactory.get_tts_models()
stt_models = ModelFactory.get_stt_models()

# Initialize the models that need immediate initialization at import time;
# the remaining TTS models are initialized lazily by their UI functions.
for model_name in ["ResembleAI/chatterbox", "KittenML/KittenTTS"]:
    if model_name in tts_models:
        tts_models[model_name].initialize()

# Initialize the STT model (may be absent if the factory did not register it)
whisper_model = stt_models.get("SYSTRAN/faster-whisper")
if whisper_model:
    whisper_model.initialize()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
 
48
+ # Helper function to get Kokoro voices
49
  def get_kokoro_voices(language_code):
50
  """
51
  Get available voices for a specific Kokoro language code
52
  Based on: https://huggingface.co/hexgrad/Kokoro-82M/blob/main/VOICES.md
 
 
 
 
 
 
 
 
 
 
 
53
  """
54
  voice_map = {
55
  # American English (a)
 
74
  "i": ["if_sara", "im_nicola"],
75
  # Japanese (j)
76
  "j": ["jf_alpha", "jf_gongitsune", "jf_nezumi", "jf_tebukuro", "jm_kumo"],
77
+ # Brazilian Portuguese (p)
78
  "p": ["pt_heart", "pt_sun", "pt_moon", "pt_star", "pt_cloud"],
79
  # Mandarin Chinese (z)
80
  "z": [
 
84
  }
85
  return voice_map.get(language_code, ["af_heart"]) # Default to American English voices
86
 
87
+ # UI Functions for TTS Models
88
+
89
+ def tts_chatterbox(text, language, audio_prompt=None):
90
+ """UI function for Chatterbox TTS"""
91
+ model = tts_models.get("ResembleAI/chatterbox")
92
+ if not model:
93
+ return None, "Model not available"
 
 
 
 
 
 
 
94
 
95
  try:
96
+ audio_path = model.generate_speech(text, language=language, audio_prompt=audio_prompt)
97
+ return audio_path, ""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
  except Exception as e:
99
+ return None, f"Error: {str(e)}"
100
 
101
def tts_kittentts(text, audio_prompt=None):
    """UI function for KittenTTS: returns (audio_path, error_message)."""
    kitten = tts_models.get("KittenML/KittenTTS")
    if not kitten:
        return None, "Model not available"

    try:
        return kitten.generate_speech(text, audio_prompt=audio_prompt), ""
    except Exception as exc:
        return None, f"Error: {str(exc)}"
 
 
 
 
 
112
 
113
def tts_piper(text, language, voice):
    """UI function for Piper TTS: returns (audio_path, error_message)."""
    piper = tts_models.get("piper-tts")
    if not piper:
        return None, "Model not available"

    try:
        # Initialization scans the on-disk voice directory before synthesis.
        piper.initialize()
        return piper.generate_speech(text, language=language, voice=voice), ""
    except Exception as exc:
        return None, f"Error: {str(exc)}"
125
 
126
def tts_kokoro(text, language_code, voice_name):
    """UI function for Kokoro TTS.

    Args:
        text (str): Text to synthesize.
        language_code (str): Kokoro language code ('a', 'b', 'z', ...).
        voice_name (str): Voice identifier, e.g. 'af_heart'
            (one of the values from get_kokoro_voices()).

    Returns:
        tuple: (audio_path, error_message) — path on success, None + error on failure.
    """
    model = tts_models.get("hexgrad/kokoro")
    if not model:
        return None, "Model not available"

    try:
        # Fix: voice_name was previously accepted but never forwarded, so the
        # UI's voice dropdown had no effect; pass it through to the model.
        audio_path = model.generate_speech(text, lang_code=language_code, voice=voice_name)
        return audio_path, ""
    except Exception as e:
        return None, f"Error: {str(e)}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
 
138
def tts_dia(text, audio_prompt=None):
    """UI function for Dia TTS: returns (audio_path, error_message)."""
    dia = tts_models.get("nari-labs/Dia-1.6B")
    if not dia:
        return None, "Model not available"

    try:
        # Dia is lazily loaded: initialize on first use.
        dia.initialize()
        return dia.generate_speech(text, audio_prompt=audio_prompt), ""
    except Exception as exc:
        return None, f"Error: {str(exc)}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
150
 
151
+ # UI Function for STT Model
152
 
153
def stt_whisper(audio_path, language=None):
    """UI function for Faster Whisper STT: returns the transcription or an error string."""
    whisper = stt_models.get("SYSTRAN/faster-whisper")
    if not whisper:
        return "Model not available"

    try:
        # language=None lets the model auto-detect the spoken language.
        return whisper.transcribe(audio_path, language=language)
    except Exception as exc:
        return f"Error: {str(exc)}"
164
 
165
+ # Gradio UI Components
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
166
 
167
def create_tts_tab():
    """Create the TTS tab for the Gradio interface.

    Builds one sub-tab per TTS model (Chatterbox, KittenTTS, Piper, Kokoro,
    Dia) and wires each submit button to the matching tts_* UI function.

    Fix: the dropdown-refresh callbacks previously called
    ``gr.Dropdown.update(...)``, which was removed in Gradio 4; they now use
    ``gr.update(...)`` (the form the pre-refactor code already used).
    """
    with gr.Tab("Text-to-Speech"):
        gr.Markdown("## Text-to-Speech Models")

        with gr.Tabs():
            # Chatterbox Tab
            with gr.Tab("Chatterbox"):
                with gr.Row():
                    with gr.Column():
                        chatterbox_text = gr.Textbox(
                            label="Text to speak",
                            placeholder="Enter text here...",
                            lines=5
                        )
                        chatterbox_language = gr.Dropdown(
                            choices=["English", "Chinese"],
                            value="English",
                            label="Language"
                        )
                        chatterbox_audio_prompt = gr.Audio(
                            label="Voice reference (optional)",
                            type="filepath"
                        )
                        chatterbox_submit = gr.Button("Generate Speech")

                    with gr.Column():
                        chatterbox_output = gr.Audio(label="Generated Speech")
                        chatterbox_error = gr.Textbox(label="Error", visible=False)

                chatterbox_submit.click(
                    tts_chatterbox,
                    inputs=[chatterbox_text, chatterbox_language, chatterbox_audio_prompt],
                    outputs=[chatterbox_output, chatterbox_error]
                )

            # KittenTTS Tab
            with gr.Tab("KittenTTS"):
                with gr.Row():
                    with gr.Column():
                        kittentts_text = gr.Textbox(
                            label="Text to speak",
                            placeholder="Enter text here...",
                            lines=5
                        )
                        kittentts_audio_prompt = gr.Audio(
                            label="Voice reference (optional)",
                            type="filepath"
                        )
                        kittentts_submit = gr.Button("Generate Speech")

                    with gr.Column():
                        kittentts_output = gr.Audio(label="Generated Speech")
                        kittentts_error = gr.Textbox(label="Error", visible=False)

                kittentts_submit.click(
                    tts_kittentts,
                    inputs=[kittentts_text, kittentts_audio_prompt],
                    outputs=[kittentts_output, kittentts_error]
                )

            # Piper Tab
            with gr.Tab("Piper"):
                with gr.Row():
                    with gr.Column():
                        piper_text = gr.Textbox(
                            label="Text to speak",
                            placeholder="Enter text here...",
                            lines=5
                        )

                        # Initialize Piper at UI-build time so the language
                        # list reflects the voices found on disk.
                        piper_model = tts_models.get("piper-tts")
                        if piper_model:
                            piper_model.initialize()
                            languages = piper_model.get_supported_languages()
                        else:
                            languages = ["English"]

                        piper_language = gr.Dropdown(
                            choices=languages,
                            value="English",
                            label="Language"
                        )

                        def update_piper_voices(language):
                            # gr.update (not gr.Dropdown.update, removed in
                            # Gradio 4) refreshes the voice choices in place.
                            if piper_model:
                                voices = piper_model.get_available_voices(language)
                                return gr.update(choices=voices, value=voices[0] if voices else None)
                            return gr.update(choices=[], value=None)

                        piper_voice = gr.Dropdown(
                            label="Voice",
                            choices=[]
                        )

                        piper_language.change(
                            update_piper_voices,
                            inputs=[piper_language],
                            outputs=[piper_voice]
                        )

                        piper_submit = gr.Button("Generate Speech")

                    with gr.Column():
                        piper_output = gr.Audio(label="Generated Speech")
                        piper_error = gr.Textbox(label="Error", visible=False)

                piper_submit.click(
                    tts_piper,
                    inputs=[piper_text, piper_language, piper_voice],
                    outputs=[piper_output, piper_error]
                )

            # Kokoro Tab
            with gr.Tab("Kokoro"):
                with gr.Row():
                    with gr.Column():
                        kokoro_text = gr.Textbox(
                            label="Text to speak",
                            placeholder="Enter text here...",
                            lines=5
                        )

                        kokoro_language = gr.Dropdown(
                            choices=[
                                "American English (a)", "British English (b)",
                                "Spanish (e)", "French (f)", "Hindi (h)",
                                "Italian (i)", "Japanese (j)",
                                "Brazilian Portuguese (p)", "Mandarin Chinese (z)"
                            ],
                            value="American English (a)",
                            label="Language"
                        )

                        def get_lang_code(language):
                            # Extract the single-letter code from a label
                            # such as "American English (a)".
                            return language.split("(")[-1].split(")")[0].strip()

                        def update_kokoro_voices(language):
                            lang_code = get_lang_code(language)
                            voices = get_kokoro_voices(lang_code)
                            # gr.update instead of the removed gr.Dropdown.update
                            return gr.update(choices=voices, value=voices[0] if voices else None)

                        kokoro_voice = gr.Dropdown(
                            label="Voice",
                            choices=get_kokoro_voices("a"),
                            value="af_heart"
                        )

                        kokoro_language.change(
                            update_kokoro_voices,
                            inputs=[kokoro_language],
                            outputs=[kokoro_voice]
                        )

                        kokoro_submit = gr.Button("Generate Speech")

                    with gr.Column():
                        kokoro_output = gr.Audio(label="Generated Speech")
                        kokoro_error = gr.Textbox(label="Error", visible=False)

                kokoro_submit.click(
                    lambda text, lang, voice: tts_kokoro(text, get_lang_code(lang), voice),
                    inputs=[kokoro_text, kokoro_language, kokoro_voice],
                    outputs=[kokoro_output, kokoro_error]
                )

            # Dia Tab
            with gr.Tab("Dia"):
                with gr.Row():
                    with gr.Column():
                        dia_text = gr.Textbox(
                            label="Text to speak",
                            placeholder="Enter text here...",
                            lines=5
                        )
                        dia_audio_prompt = gr.Audio(
                            label="Voice reference (optional)",
                            type="filepath"
                        )
                        dia_submit = gr.Button("Generate Speech")

                    with gr.Column():
                        dia_output = gr.Audio(label="Generated Speech")
                        dia_error = gr.Textbox(label="Error", visible=False)

                dia_submit.click(
                    tts_dia,
                    inputs=[dia_text, dia_audio_prompt],
                    outputs=[dia_output, dia_error]
                )
358
+
359
def create_stt_tab():
    """Create the STT tab for the Gradio interface (Faster Whisper only)."""
    with gr.Tab("Speech-to-Text"):
        gr.Markdown("## Speech-to-Text Models")

        with gr.Tabs():
            # Faster Whisper Tab
            with gr.Tab("Faster Whisper"):
                with gr.Row():
                    with gr.Column():
                        whisper_audio = gr.Audio(
                            label="Audio to transcribe",
                            type="filepath"
                        )
                        # NOTE(review): choices are full language names
                        # ("English"); confirm the factory model maps these
                        # to Whisper language codes ("en") internally.
                        whisper_language = gr.Dropdown(
                            choices=["Auto-detect", "English", "Chinese", "Spanish", "French", "German", "Japanese"],
                            value="Auto-detect",
                            label="Language (optional)"
                        )
                        whisper_submit = gr.Button("Transcribe")

                    with gr.Column():
                        whisper_output = gr.Textbox(
                            label="Transcription",
                            lines=5
                        )

                # Map the "Auto-detect" sentinel to None so the model
                # performs automatic language detection.
                whisper_submit.click(
                    lambda audio, lang: stt_whisper(audio, None if lang == "Auto-detect" else lang),
                    inputs=[whisper_audio, whisper_language],
                    outputs=[whisper_output]
                )
391
+
392
+ # Create the Gradio interface
393
def create_interface():
    """Assemble the top-level Gradio Blocks app and return it."""
    with gr.Blocks(title="TTS & STT Gallery") as app:
        gr.Markdown("# TTS & STT Model Gallery")
        gr.Markdown("Explore different Text-to-Speech and Speech-to-Text models")

        with gr.Tabs():
            # One top-level tab per task family
            create_tts_tab()
            create_stt_tab()

    return app
 
 
 
 
 
404
 
405
# Launch the app when run as a script
if __name__ == "__main__":
    create_interface().launch()
src/models/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from .base import TTSModel, STTModel
2
+ from .factory import ModelFactory
3
+
4
+ __all__ = ['TTSModel', 'STTModel', 'ModelFactory']
src/models/base.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from abc import ABC, abstractmethod


class BaseModel(ABC):
    """Common interface shared by every TTS/STT model wrapper.

    Concrete subclasses expose a human-readable ``name``/``description``
    and perform (lazy) loading in ``initialize``.
    """

    @property
    @abstractmethod
    def name(self):
        """Return the unique name of the model (e.g. its hub identifier)."""

    @property
    @abstractmethod
    def description(self):
        """Return a short human-readable description of the model."""

    @abstractmethod
    def initialize(self):
        """Load the underlying model; return True on success, False on failure."""


class TTSModel(BaseModel):
    """Abstract base class for Text-to-Speech models."""

    @abstractmethod
    def generate_speech(self, text, **kwargs):
        """
        Generate speech from text.

        Args:
            text (str): Text to convert to speech
            **kwargs: Additional model-specific parameters

        Returns:
            str: Path to the generated audio file
        """

    def supports_voice_cloning(self):
        """Whether the model supports voice cloning (default: no)."""
        return False

    def supports_multilingual(self):
        """Whether the model supports multiple languages (default: no)."""
        return False

    def get_supported_languages(self):
        """Get the list of supported language names (default: English only)."""
        return ["English"]


class STTModel(BaseModel):
    """Abstract base class for Speech-to-Text models."""

    @abstractmethod
    def transcribe(self, audio_path, **kwargs):
        """
        Transcribe speech to text.

        Args:
            audio_path (str): Path to the audio file
            **kwargs: Additional model-specific parameters

        Returns:
            str: Transcribed text
        """

    def supports_multilingual(self):
        """Whether the model supports multiple languages (default: no)."""
        return False

    def get_supported_languages(self):
        """Get the list of supported language names (default: English only)."""
        return ["English"]
src/models/factory.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .tts.chatterbox_model import ChatterboxTTSModel
2
+ from .tts.kitten_model import KittenTTSModel
3
+ from .tts.piper_model import PiperTTSModel
4
+ from .tts.kokoro_model import KokoroTTSModel
5
+ from .tts.dia_model import DiaTTSModel
6
+ from .stt.whisper_model import FasterWhisperSTTModel
7
+
8
+ class ModelFactory:
9
+ """Factory class for creating model instances"""
10
+
11
+ @staticmethod
12
+ def get_tts_models():
13
+ """Get all available TTS models"""
14
+ return {
15
+ "ResembleAI/chatterbox": ChatterboxTTSModel(),
16
+ "KittenML/KittenTTS": KittenTTSModel(),
17
+ "piper-tts": PiperTTSModel(),
18
+ "hexgrad/kokoro": KokoroTTSModel(),
19
+ "nari-labs/Dia-1.6B": DiaTTSModel()
20
+ }
21
+
22
+ @staticmethod
23
+ def get_stt_models():
24
+ """Get all available STT models"""
25
+ return {
26
+ "SYSTRAN/faster-whisper": FasterWhisperSTTModel()
27
+ }
28
+
29
+ @staticmethod
30
+ def get_tts_model(model_name):
31
+ """Get a specific TTS model by name"""
32
+ models = ModelFactory.get_tts_models()
33
+ return models.get(model_name)
34
+
35
+ @staticmethod
36
+ def get_stt_model(model_name):
37
+ """Get a specific STT model by name"""
38
+ models = ModelFactory.get_stt_models()
39
+ return models.get(model_name)
40
+
41
+ @staticmethod
42
+ def get_model_descriptions():
43
+ """Get descriptions for all models"""
44
+ descriptions = {}
45
+
46
+ # Add TTS model descriptions
47
+ for model_name, model in ModelFactory.get_tts_models().items():
48
+ descriptions[model_name] = model.description
49
+
50
+ # Add STT model descriptions
51
+ for model_name, model in ModelFactory.get_stt_models().items():
52
+ descriptions[model_name] = model.description
53
+
54
+ return descriptions
src/models/stt/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .whisper_model import FasterWhisperSTTModel
2
+
3
+ __all__ = ['FasterWhisperSTTModel']
src/models/stt/whisper_model.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from faster_whisper import WhisperModel
3
+ from ..base import STTModel
4
+
5
class FasterWhisperSTTModel(STTModel):
    """Speech-to-text wrapper around SYSTRAN's faster-whisper (CTranslate2)."""

    def __init__(self):
        self._model = None
        self._initialized = False
        # Checkpoint used on first load; a "small" fallback is tried on failure.
        self._model_size = "large-v3"

    @property
    def name(self):
        return "SYSTRAN/faster-whisper"

    @property
    def description(self):
        return "Faster Whisper transcription with CTranslate2, up to 4x faster than OpenAI Whisper"

    def initialize(self):
        """Load the Whisper weights, picking device/precision for this machine.

        Returns:
            bool: True once a model (primary or fallback) is loaded, else False.
        """
        if self._initialized:
            return True

        try:
            if torch.cuda.is_available():
                self._model = WhisperModel(self._model_size, device="cuda", compute_type="float16")
                print("Loaded faster-whisper on CUDA with FP16")
            elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
                # MPS (Apple Silicon) has no direct CTranslate2 backend, so we
                # still run on CPU with INT8 there.
                self._model = WhisperModel(self._model_size, device="cpu", compute_type="int8")
                print("Loaded faster-whisper on CPU with INT8 (MPS not directly supported)")
            else:
                self._model = WhisperModel(self._model_size, device="cpu", compute_type="int8")
                print("Loaded faster-whisper on CPU with INT8")
        except Exception as e:
            print(f"Error initializing Faster Whisper model: {str(e)}")
            print("Falling back to small model with INT8 quantization")
            try:
                self._model = WhisperModel("small", device="cpu", compute_type="int8")
            except Exception as e2:
                print(f"Failed to load fallback model: {str(e2)}")
                return False

        self._initialized = True
        return True

    def transcribe(self, audio_path, language=None, **kwargs):
        """
        Transcribe speech to text.

        Args:
            audio_path (str): Path to the audio file
            language (str, optional): Language code for transcription
            **kwargs: Additional parameters for transcription

        Returns:
            str: Transcribed text

        Raises:
            RuntimeError: if the model cannot be loaded.
        """
        if not self._initialized and not self.initialize():
            raise RuntimeError("Failed to initialize Faster Whisper model")

        # Defaults first; any caller-supplied kwargs win.
        options = {
            "beam_size": 5,
            "language": language,
            "task": "transcribe",
        }
        options.update(kwargs)

        # faster-whisper yields segments lazily; join them into one string.
        segments, _info = self._model.transcribe(audio_path, **options)
        return " ".join(segment.text for segment in segments).strip()

    def supports_multilingual(self):
        return True

    def get_supported_languages(self):
        # Whisper supports many languages, but we'll return a subset of common ones
        return [
            "English", "Spanish", "French", "German", "Chinese", "Japanese",
            "Russian", "Portuguese", "Italian", "Dutch", "Arabic", "Korean"
        ]
src/models/tts/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .chatterbox_model import ChatterboxTTSModel
2
+ from .kitten_model import KittenTTSModel
3
+ from .piper_model import PiperTTSModel
4
+ from .kokoro_model import KokoroTTSModel
5
+ from .dia_model import DiaTTSModel
6
+
7
+ __all__ = [
8
+ 'ChatterboxTTSModel',
9
+ 'KittenTTSModel',
10
+ 'PiperTTSModel',
11
+ 'KokoroTTSModel',
12
+ 'DiaTTSModel'
13
+ ]
src/models/tts/chatterbox_model.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchaudio as ta
3
+ import tempfile
4
+ import os
5
+ from chatterbox.mtl_tts import ChatterboxMultilingualTTS
6
+ from ..base import TTSModel
7
+
8
class ChatterboxTTSModel(TTSModel):
    """Chatterbox multilingual TTS model implementation."""

    def __init__(self):
        self._model = None
        self._initialized = False

    @property
    def name(self):
        return "ResembleAI/chatterbox"

    @property
    def description(self):
        return "Industrial-grade TTS solution with multilingual support"

    def initialize(self):
        """Load Chatterbox on GPU when available, falling back to CPU.

        Returns:
            bool: True on success, False on any failure. The other wrappers
            never raise from initialize(); previously this method caught only
            RuntimeError (other exceptions escaped) and an error during the
            CPU fallback load also propagated — both now return False.
        """
        if self._initialized:
            return True

        try:
            self._model = ChatterboxMultilingualTTS.from_pretrained(
                device="cuda" if torch.cuda.is_available() else "cpu"
            )
        except Exception as e:
            # Checkpoints serialized on a GPU raise this when loaded without
            # CUDA; retry explicitly on CPU before giving up.
            if "Attempting to deserialize object on a CUDA device" in str(e):
                print("CUDA model detected but CUDA is not available. Loading model on CPU...")
                try:
                    self._model = ChatterboxMultilingualTTS.from_pretrained(device="cpu")
                except Exception as e2:
                    print(f"Error initializing Chatterbox model: {e2}")
                    return False
            else:
                print(f"Error initializing Chatterbox model: {e}")
                return False

        self._initialized = True
        return True

    def generate_speech(self, text, language="English", audio_prompt=None, **kwargs):
        """
        Generate speech from text using Chatterbox multilingual TTS.

        Args:
            text (str): Text to convert to speech
            language (str): Language name ('English' or 'Chinese')
            audio_prompt (str, optional): Path to reference audio file for voice cloning
            **kwargs: Additional parameters for generation

        Returns:
            str: Path to the generated audio file

        Raises:
            RuntimeError: if the model cannot be loaded.
        """
        if not self._initialized:
            if not self.initialize():
                raise RuntimeError("Failed to initialize Chatterbox model")

        # Map UI language names to Chatterbox language codes; unknown names
        # default to English.
        language_map = {
            "English": "en",
            "Chinese": "zh",
        }
        language_id = language_map.get(language, "en")

        # Defaults first; any caller-supplied kwargs win.
        generate_kwargs = {
            "exaggeration": 0.5,
            "temperature": 0.8,
            "cfg_weight": 0.3,
        }
        generate_kwargs.update(kwargs)

        if audio_prompt and os.path.exists(audio_prompt):
            # Use the reference audio for voice cloning.
            wav = self._model.generate(text, language_id=language_id,
                                       audio_prompt_path=audio_prompt, **generate_kwargs)
        else:
            # No prompt: generate with the default voice.
            wav = self._model.generate(text, language_id=language_id, **generate_kwargs)

        # delete=False so the file survives the context manager; Gradio serves
        # the audio from this path.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
            ta.save(tmp_file.name, wav, self._model.sr)
            return tmp_file.name

    def supports_voice_cloning(self):
        return True

    def supports_multilingual(self):
        return True

    def get_supported_languages(self):
        return ["English", "Chinese"]
src/models/tts/dia_model.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tempfile
2
+ import os
3
+ from ..base import TTSModel
4
+
5
class DiaTTSModel(TTSModel):
    """Wrapper around the project's DiaTTS dialogue-generation backend."""

    def __init__(self):
        self._model = None
        self._initialized = False

    @property
    def name(self):
        return "nari-labs/Dia-1.6B"

    @property
    def description(self):
        return "Ultra-realistic dialogue generation with support for voice cloning and non-verbal expressions"

    def initialize(self):
        """Lazily construct the DiaTTS backend; returns True on success."""
        if self._initialized:
            return True

        try:
            # Import here to avoid circular imports
            from src.dia_tts import DiaTTS
            self._model = DiaTTS()
        except Exception as e:
            print(f"Error initializing Dia model: {e}")
            return False

        self._initialized = True
        return True

    def generate_speech(self, text, audio_prompt=None, **kwargs):
        """
        Generate speech from text using Dia TTS.

        Args:
            text (str): Text to convert to speech
            audio_prompt (str, optional): Path to reference audio file for voice cloning
            **kwargs: Additional parameters for generation

        Returns:
            str: Path to the generated audio file

        Raises:
            RuntimeError: if the backend cannot be loaded.
        """
        if not self._initialized and not self.initialize():
            raise RuntimeError("Failed to initialize Dia model")

        # DiaTTS writes the audio itself and hands back the file path.
        return self._model.generate(text, reference_audio=audio_prompt, **kwargs)

    def supports_voice_cloning(self):
        return True
src/models/tts/kitten_model.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tempfile
2
+ import os
3
+ import soundfile as sf
4
+ import numpy as np
5
+ from kittentts import KittenTTS
6
+ from ..base import TTSModel
7
+
8
class KittenTTSModel(TTSModel):
    """KittenTTS model implementation.

    Wraps the KittenML nano checkpoint; the model is loaded lazily on the
    first generate_speech()/initialize() call.
    """

    def __init__(self):
        # Underlying KittenTTS instance, created in initialize().
        self._model = None
        # Guards against re-loading the checkpoint on every call.
        self._initialized = False
        # Hugging Face identifier of the checkpoint to load.
        self._model_path = "KittenML/kitten-tts-nano-0.2"

    @property
    def name(self):
        return "KittenML/KittenTTS"

    @property
    def description(self):
        return "High-quality TTS with voice cloning capabilities using reference audio"

    def initialize(self):
        """Initialize the KittenTTS model.

        Returns:
            bool: True if the model is loaded (or already was), False on error.
        """
        if self._initialized:
            return True

        try:
            self._model = KittenTTS(self._model_path)
            self._initialized = True
            return True
        except Exception as e:
            # Loading can fail e.g. when the checkpoint cannot be fetched;
            # callers treat False as "model unavailable".
            print(f"Error initializing KittenTTS model: {e}")
            return False

    def generate_speech(self, text, audio_prompt=None, **kwargs):
        """
        Generate speech from text using KittenTTS

        Args:
            text (str): Text to convert to speech
            audio_prompt (str, optional): Path to reference audio file for voice cloning
            **kwargs: Additional parameters for generation

        Returns:
            str: Path to the generated audio file
        """
        if not self._initialized:
            if not self.initialize():
                raise RuntimeError("Failed to initialize KittenTTS model")

        # Generate speech using KittenTTS
        # NOTE(review): relies on the installed kittentts exposing
        # generate_with_voice(text, path), generate(text) returning an array
        # sf.write can consume, and a .sample_rate attribute — confirm against
        # the pinned kittentts version.
        if audio_prompt and os.path.exists(audio_prompt):
            # Use audio prompt for voice cloning
            audio_array = self._model.generate_with_voice(text, audio_prompt)
        else:
            # Generate with default voice
            audio_array = self._model.generate(text)

        # Save to a temporary file
        # delete=False keeps the file on disk so Gradio can serve it afterwards.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
            sf.write(tmp_file.name, audio_array, self._model.sample_rate)
            return tmp_file.name

    def supports_voice_cloning(self):
        return True
src/models/tts/kokoro_model.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tempfile
2
+ import os
3
+ from kokoro import KPipeline
4
+ from ..base import TTSModel
5
+
6
class KokoroTTSModel(TTSModel):
    """Kokoro TTS model implementation.

    Wraps hexgrad's Kokoro KPipeline; the pipeline is rebuilt whenever the
    requested language code changes.
    """

    def __init__(self):
        # Underlying KPipeline instance, created in initialize().
        self._model = None
        self._initialized = False
        # 'a' = American English, 'b' = British English (see get_language_codes()).
        self._lang_code = 'a'  # Default to American English

    @property
    def name(self):
        return "hexgrad/kokoro"

    @property
    def description(self):
        return "Lightweight TTS model with 82M parameters, Apache-licensed for production and personal use"

    def initialize(self):
        """Initialize the Kokoro model.

        Returns:
            bool: True if the pipeline is ready, False on error.
        """
        if self._initialized:
            return True

        try:
            self._model = KPipeline(lang_code=self._lang_code)
            self._initialized = True
            return True
        except Exception as e:
            print(f"Error initializing Kokoro model: {e}")
            return False

    def generate_speech(self, text, lang_code=None, **kwargs):
        """
        Generate speech from text using Kokoro TTS

        Args:
            text (str): Text to convert to speech
            lang_code (str, optional): Language code ('a' for American English, 'b' for British English)
            **kwargs: Additional parameters for generation

        Returns:
            str: Path to the generated audio file
        """
        # Update language code if provided; switching languages needs a fresh
        # KPipeline, hence the forced re-initialization below.
        if lang_code and lang_code != self._lang_code:
            self._lang_code = lang_code
            self._initialized = False

        if not self._initialized:
            if not self.initialize():
                raise RuntimeError("Failed to initialize Kokoro model")

        # Generate speech
        # NOTE(review): KPipeline is normally used as a callable generator
        # (e.g. `for _, _, audio in pipeline(text, voice=...)`); confirm the
        # installed kokoro version actually provides tts_to_file(), otherwise
        # this call raises AttributeError at runtime.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
            self._model.tts_to_file(text, tmp_file.name)
            return tmp_file.name

    def get_supported_languages(self):
        return ["American English", "British English"]

    def get_language_codes(self):
        """Get mapping of language names to language codes"""
        return {
            "American English": "a",
            "British English": "b"
        }
src/models/tts/piper_model.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import tempfile
3
+ from piper import PiperVoice
4
+ from ..base import TTSModel
5
+
6
class PiperTTSModel(TTSModel):
    """Piper TTS model implementation (local on-device ONNX voices)."""

    def __init__(self):
        self._voices_by_lang = None
        self._initialized = False
        # model path -> loaded PiperVoice. Previously a fresh PiperVoice was
        # constructed (ONNX model re-loaded from disk) on every request.
        self._voice_cache = {}

    @property
    def name(self):
        return "piper-tts"

    @property
    def description(self):
        return "Local on-device TTS with dynamic English and Chinese voice selection from Piper models"

    def initialize(self):
        """Initialize the Piper model by scanning available voices.

        Returns:
            bool: True on success, False if the voice scan fails.
        """
        if self._initialized:
            return True

        try:
            self._voices_by_lang = self._scan_piper_voices()
            self._initialized = True
            return True
        except Exception as e:
            print(f"Error initializing Piper model: {e}")
            return False

    def _scan_piper_voices(self):
        """Scan the on-disk voice tree and return {language: {label: onnx path}}."""
        voices_dir = "src/voices/piper_voices"
        voices_by_lang = {'English': {}, 'Chinese': {}}

        # Chinese: only huayan medium
        chinese_path = os.path.join(voices_dir, "zh", "zh_CN", "huayan", "medium", "zh_CN-huayan-medium.onnx")
        if os.path.exists(chinese_path):
            voices_by_lang['Chinese']['huayan (zh_CN)'] = chinese_path

        # English voices live under <voices_dir>/en/<locale>/<voice>/<quality>/
        en_dir = os.path.join(voices_dir, "en")
        for root, dirs, files in os.walk(en_dir):
            # Hoisted: split once per directory (was computed twice per iteration).
            parts = root.split(os.sep)
            if len(parts) < 5 or parts[-1] not in ('medium', 'high'):
                continue
            locale = parts[-3]      # en_GB or en_US
            voice_name = parts[-2]  # alan, etc.
            quality = parts[-1]     # medium or high

            for file in files:
                if file.endswith('.onnx') and f"{locale}-{voice_name}-{quality}" in file:
                    label = f"{voice_name} ({locale})"
                    # Prefer medium over high
                    if quality == 'medium' or label not in voices_by_lang['English']:
                        voices_by_lang['English'][label] = os.path.join(root, file)
                    break  # Assume one .onnx per dir

        return voices_by_lang

    def _load_voice(self, model_path):
        """Return a cached PiperVoice for model_path, loading it on first use."""
        voice = self._voice_cache.get(model_path)
        if voice is None:
            # NOTE(review): assumes PiperVoice(model_path=...) is the right
            # constructor and synthesize(text, path) writes a WAV — confirm
            # against the installed piper-tts version (newer releases use
            # PiperVoice.load(path) and synthesize into an open wave file).
            voice = PiperVoice(model_path=model_path)
            self._voice_cache[model_path] = voice
        return voice

    def generate_speech(self, text, language="English", voice=None, **kwargs):
        """
        Generate speech from text using Piper TTS.

        Args:
            text (str): Text to convert to speech
            language (str): Language name ('English' or 'Chinese')
            voice (str, optional): Voice name to use
            **kwargs: Additional parameters for generation

        Returns:
            str: Path to the generated audio file

        Raises:
            RuntimeError: if the voice scan failed.
            ValueError: if no voices are installed for `language`.
        """
        if not self._initialized:
            if not self.initialize():
                raise RuntimeError("Failed to initialize Piper model")

        # Get available voices for the selected language
        available_voices = self._voices_by_lang.get(language, {})
        if not available_voices:
            raise ValueError(f"No voices available for language: {language}")

        # Fall back to the first available voice when none (or an unknown one)
        # was requested.
        if not voice or voice not in available_voices:
            voice = next(iter(available_voices))

        # Reuse a cached PiperVoice for this model path.
        piper_voice = self._load_voice(available_voices[voice])

        # delete=False so the file survives the context manager; Gradio serves
        # the audio from this path.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
            piper_voice.synthesize(text, tmp_file.name)
            return tmp_file.name

    def supports_multilingual(self):
        return True

    def get_supported_languages(self):
        if not self._initialized:
            self.initialize()
        return list(self._voices_by_lang.keys())

    def get_available_voices(self, language="English"):
        """Get available voices for a specific language"""
        if not self._initialized:
            self.initialize()
        return list(self._voices_by_lang.get(language, {}).keys())