Spaces:
Sleeping
Sleeping
Michael Hu
commited on
Commit
·
ef4db28
1
Parent(s):
b5ac4eb
refactor: replace inline model definitions with ModelFactory and remove unused imports
Browse files- Remove all hard-coded model definitions and import the new factory
- Delete unused imports (torchaudio, sys, soundfile, transformers, etc.)
- Eliminate duplicate code for model discovery and voice handling
- Delete unused utility functions and duplicate code paths
- Remove unused dependency on librosa and soundfile
- app.py +324 -563
- src/models/__init__.py +4 -0
- src/models/base.py +77 -0
- src/models/factory.py +54 -0
- src/models/stt/__init__.py +3 -0
- src/models/stt/whisper_model.py +93 -0
- src/models/tts/__init__.py +13 -0
- src/models/tts/chatterbox_model.py +99 -0
- src/models/tts/dia_model.py +56 -0
- src/models/tts/kitten_model.py +67 -0
- src/models/tts/kokoro_model.py +69 -0
- src/models/tts/piper_model.py +115 -0
app.py
CHANGED
|
@@ -1,43 +1,14 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
-
import torchaudio as ta
|
| 3 |
import torch
|
| 4 |
import tempfile
|
| 5 |
import os
|
| 6 |
-
import sys
|
| 7 |
-
import soundfile as sf
|
| 8 |
import numpy as np
|
| 9 |
-
import librosa
|
| 10 |
-
from chatterbox.mtl_tts import ChatterboxMultilingualTTS
|
| 11 |
-
from kittentts import KittenTTS
|
| 12 |
-
from piper import PiperVoice
|
| 13 |
-
from transformers import AutoModelForSeq2SeqLM
|
| 14 |
import soundfile as sf
|
| 15 |
-
import wave
|
| 16 |
-
import os
|
| 17 |
-
from faster_whisper import WhisperModel
|
| 18 |
-
from kokoro import KPipeline
|
| 19 |
-
# from src.dia_tts import DiaTTS
|
| 20 |
|
| 21 |
-
#
|
| 22 |
-
|
| 23 |
-
"ResembleAI/chatterbox": "Industrial-grade TTS solution with multilingual support",
|
| 24 |
-
"KittenML/KittenTTS": "High-quality TTS with voice cloning capabilities using reference audio",
|
| 25 |
-
"piper-tts": "Local on-device TTS with dynamic English and Chinese voice selection from Piper models",
|
| 26 |
-
"SYSTRAN/faster-whisper": "Faster Whisper transcription with CTranslate2, up to 4x faster than OpenAI Whisper",
|
| 27 |
-
"hexgrad/kokoro": "Lightweight TTS model with 82M parameters, Apache-licensed for production and personal use",
|
| 28 |
-
"nari-labs/Dia-1.6B": "Ultra-realistic dialogue generation with support for voice cloning and non-verbal expressions",
|
| 29 |
-
}
|
| 30 |
-
|
| 31 |
-
# Models dictionary
|
| 32 |
-
MODELS = {
|
| 33 |
-
"ResembleAI/chatterbox": "Chatterbox",
|
| 34 |
-
"KittenML/KittenTTS": "KittenTTS",
|
| 35 |
-
"piper-tts": "Piper (no voice cloning)",
|
| 36 |
-
"SYSTRAN/faster-whisper": "Faster Whisper",
|
| 37 |
-
"hexgrad/kokoro": "Kokoro-82M",
|
| 38 |
-
"nari-labs/Dia-1.6B": "Dia TTS",
|
| 39 |
-
}
|
| 40 |
|
|
|
|
| 41 |
original_torch_load = torch.load
|
| 42 |
|
| 43 |
def patched_torch_load(f, map_location=None, **kwargs):
|
|
@@ -47,187 +18,38 @@ def patched_torch_load(f, map_location=None, **kwargs):
|
|
| 47 |
|
| 48 |
torch.load = patched_torch_load
|
| 49 |
|
| 50 |
-
#
|
| 51 |
-
|
| 52 |
-
model = ChatterboxMultilingualTTS.from_pretrained(device="cuda" if torch.cuda.is_available() else "cpu")
|
| 53 |
-
except RuntimeError as e:
|
| 54 |
-
if "Attempting to deserialize object on a CUDA device" in str(e):
|
| 55 |
-
print("CUDA model detected but CUDA is not available. Loading model on CPU...")
|
| 56 |
-
model = ChatterboxMultilingualTTS.from_pretrained(device="cpu")
|
| 57 |
-
else:
|
| 58 |
-
raise e
|
| 59 |
-
|
| 60 |
-
# Initialize KittenTTS model
|
| 61 |
-
kittentts_model = KittenTTS("KittenML/kitten-tts-nano-0.2")
|
| 62 |
-
|
| 63 |
-
# Scan Piper voices
|
| 64 |
-
def scan_piper_voices():
|
| 65 |
-
voices_dir = "src/voices/piper_voices"
|
| 66 |
-
voices_by_lang = {'English': {}, 'Chinese': {}}
|
| 67 |
-
|
| 68 |
-
# Chinese: only huayan medium
|
| 69 |
-
chinese_path = os.path.join(voices_dir, "zh", "zh_CN", "huayan", "medium", "zh_CN-huayan-medium.onnx")
|
| 70 |
-
if os.path.exists(chinese_path):
|
| 71 |
-
voices_by_lang['Chinese']['huayan (zh_CN)'] = chinese_path
|
| 72 |
-
|
| 73 |
-
# English voices
|
| 74 |
-
en_dir = os.path.join(voices_dir, "en")
|
| 75 |
-
for root, dirs, files in os.walk(en_dir):
|
| 76 |
-
if len(root.split(os.sep)) < 5: # Skip if not deep enough
|
| 77 |
-
continue
|
| 78 |
-
parts = root.split(os.sep)
|
| 79 |
-
if len(parts) >= 5 and parts[-1] in ['medium', 'high']:
|
| 80 |
-
locale = parts[-3] # en_GB or en_US
|
| 81 |
-
voice_name = parts[-2] # alan, etc.
|
| 82 |
-
quality = parts[-1] # medium or high
|
| 83 |
-
|
| 84 |
-
for file in files:
|
| 85 |
-
if file.endswith('.onnx') and f"{locale}-{voice_name}-{quality}" in file:
|
| 86 |
-
path = os.path.join(root, file)
|
| 87 |
-
label = f"{voice_name} ({locale})"
|
| 88 |
-
# Prefer medium over high
|
| 89 |
-
if quality == 'medium' or label not in voices_by_lang['English']:
|
| 90 |
-
voices_by_lang['English'][label] = path
|
| 91 |
-
break # Assume one .onnx per dir
|
| 92 |
-
|
| 93 |
-
return voices_by_lang
|
| 94 |
-
|
| 95 |
-
voices_by_lang = scan_piper_voices()
|
| 96 |
-
|
| 97 |
-
# No global piper_voice, load dynamically
|
| 98 |
-
|
| 99 |
-
# Initialize Dia model
|
| 100 |
-
# dia_model = None
|
| 101 |
-
# def initialize_dia():
|
| 102 |
-
# global dia_model
|
| 103 |
-
# try:
|
| 104 |
-
# dia_model = DiaTTS()
|
| 105 |
-
# print("Loaded Dia-1.6B model")
|
| 106 |
-
# return dia_model
|
| 107 |
-
# except Exception as e:
|
| 108 |
-
# print(f"Error loading Dia model: {e}")
|
| 109 |
-
# return None
|
| 110 |
-
|
| 111 |
-
# Initialize Kokoro
|
| 112 |
-
def initialize_kokoro():
|
| 113 |
-
try:
|
| 114 |
-
# Initialize Kokoro pipeline with American English as default
|
| 115 |
-
kokoro_pipeline = KPipeline(lang_code='a')
|
| 116 |
-
print("Loaded Kokoro-82M pipeline with American English")
|
| 117 |
-
return kokoro_pipeline
|
| 118 |
-
except Exception as e:
|
| 119 |
-
print(f"Error loading Kokoro pipeline: {e}")
|
| 120 |
-
return None
|
| 121 |
-
|
| 122 |
-
# Initialize faster-whisper model
|
| 123 |
-
def initialize_faster_whisper():
|
| 124 |
-
"""Initialize the faster-whisper model with appropriate compute settings"""
|
| 125 |
-
model_size = "large-v3"
|
| 126 |
-
|
| 127 |
-
try:
|
| 128 |
-
if torch.cuda.is_available():
|
| 129 |
-
whisper_model = WhisperModel(model_size, device="cuda", compute_type="float16")
|
| 130 |
-
print("Loaded faster-whisper on CUDA with FP16")
|
| 131 |
-
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
|
| 132 |
-
# MPS (Apple Silicon) support
|
| 133 |
-
whisper_model = WhisperModel(model_size, device="cpu", compute_type="int8")
|
| 134 |
-
print("Loaded faster-whisper on CPU with INT8 (MPS not directly supported)")
|
| 135 |
-
else:
|
| 136 |
-
whisper_model = WhisperModel(model_size, device="cpu", compute_type="int8")
|
| 137 |
-
print("Loaded faster-whisper on CPU with INT8")
|
| 138 |
-
|
| 139 |
-
return whisper_model
|
| 140 |
-
except Exception as e:
|
| 141 |
-
print(f"Error loading faster-whisper model: {str(e)}")
|
| 142 |
-
print("Falling back to small model with INT8 quantization")
|
| 143 |
-
try:
|
| 144 |
-
return WhisperModel("small", device="cpu", compute_type="int8")
|
| 145 |
-
except Exception as e2:
|
| 146 |
-
print(f"Failed to load fallback model: {str(e2)}")
|
| 147 |
-
return None
|
| 148 |
|
| 149 |
-
#
|
| 150 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 151 |
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
Args:
|
| 157 |
-
text (str): Text to convert to speech
|
| 158 |
-
language (str): Language code ('en' for English, 'zh' for Chinese)
|
| 159 |
-
audio_prompt (str, optional): Path to reference audio file for voice cloning
|
| 160 |
-
|
| 161 |
-
Returns:
|
| 162 |
-
str: Path to the generated audio file
|
| 163 |
-
"""
|
| 164 |
-
# Map language codes to full names for Chatterbox
|
| 165 |
-
language_map = {
|
| 166 |
-
"English": "en",
|
| 167 |
-
"Chinese": "zh"
|
| 168 |
-
}
|
| 169 |
-
|
| 170 |
-
language_id = language_map.get(language, "en")
|
| 171 |
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
"cfg_weight": 0.3,
|
| 177 |
-
}
|
| 178 |
-
|
| 179 |
-
# Generate speech using Chatterbox
|
| 180 |
-
if audio_prompt and os.path.exists(audio_prompt):
|
| 181 |
-
# Use audio prompt for voice cloning
|
| 182 |
-
wav = model.generate(text, language_id=language_id, audio_prompt_path=audio_prompt, **generate_kwargs)
|
| 183 |
-
else:
|
| 184 |
-
# Generate without audio prompt (default voice)
|
| 185 |
-
wav = model.generate(text, language_id=language_id, **generate_kwargs)
|
| 186 |
-
|
| 187 |
-
# Save to a temporary file
|
| 188 |
-
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
|
| 189 |
-
ta.save(tmp_file.name, wav, model.sr)
|
| 190 |
-
return tmp_file.name
|
| 191 |
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
Args:
|
| 197 |
-
text (str): Text to convert to speech
|
| 198 |
-
audio_prompt (str, optional): Path to reference audio file for voice cloning
|
| 199 |
-
|
| 200 |
-
Returns:
|
| 201 |
-
str: Path to the generated audio file
|
| 202 |
-
"""
|
| 203 |
-
# Generate speech using KittenTTS
|
| 204 |
-
if audio_prompt and os.path.exists(audio_prompt):
|
| 205 |
-
# Use audio prompt for voice cloning
|
| 206 |
-
wav = kittentts_model.generate(text, voice='expr-voice-2-f')
|
| 207 |
-
else:
|
| 208 |
-
# Generate without audio prompt (default voice)
|
| 209 |
-
wav = kittentts_model.generate(text, voice='expr-voice-2-f')
|
| 210 |
-
|
| 211 |
-
# Save to a temporary file
|
| 212 |
-
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
|
| 213 |
-
sf.write(tmp_file.name, wav, 24000)
|
| 214 |
-
return tmp_file.name
|
| 215 |
|
|
|
|
| 216 |
def get_kokoro_voices(language_code):
|
| 217 |
"""
|
| 218 |
Get available voices for a specific Kokoro language code
|
| 219 |
Based on: https://huggingface.co/hexgrad/Kokoro-82M/blob/main/VOICES.md
|
| 220 |
-
|
| 221 |
-
Voice mapping:
|
| 222 |
-
- American English (a): af_heart, af_alloy, af_aoede, af_bella, af_jessica, af_kore, af_nicole, af_nova, af_river, af_sarah, af_sky, am_adam, am_echo, am_eric, am_fenrir, am_liam, am_michael, am_onyx, am_puck, am_santa
|
| 223 |
-
- British English (b): bf_alice, bf_emma, bf_isabella, bf_lily, bm_daniel, bm_fable, bm_george, bm_lewis
|
| 224 |
-
- Spanish (e): ef_dora, em_alex, em_santa
|
| 225 |
-
- French (f): ff_siwis
|
| 226 |
-
- Hindi (h): hf_alpha, hf_beta, hm_omega, hm_psi
|
| 227 |
-
- Italian (i): if_sara, im_nicola
|
| 228 |
-
- Japanese (j): jf_alpha, jf_gongitsune, jf_nezumi, jf_tebukuro, jm_kumo
|
| 229 |
-
- Brazilian Portuguese (p): pt_heart, pt_sun, pt_moon, pt_star, pt_cloud
|
| 230 |
-
- Mandarin Chinese (z): zf_xiaobei, zf_xiaoni, zf_xiaoxiao, zf_xiaoyi, zm_yunjian, zm_yunxi, zm_yunxia, zm_yunyang
|
| 231 |
"""
|
| 232 |
voice_map = {
|
| 233 |
# American English (a)
|
|
@@ -252,7 +74,7 @@ def get_kokoro_voices(language_code):
|
|
| 252 |
"i": ["if_sara", "im_nicola"],
|
| 253 |
# Japanese (j)
|
| 254 |
"j": ["jf_alpha", "jf_gongitsune", "jf_nezumi", "jf_tebukuro", "jm_kumo"],
|
| 255 |
-
# Brazilian Portuguese (p)
|
| 256 |
"p": ["pt_heart", "pt_sun", "pt_moon", "pt_star", "pt_cloud"],
|
| 257 |
# Mandarin Chinese (z)
|
| 258 |
"z": [
|
|
@@ -262,386 +84,325 @@ def get_kokoro_voices(language_code):
|
|
| 262 |
}
|
| 263 |
return voice_map.get(language_code, ["af_heart"]) # Default to American English voices
|
| 264 |
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
voice_name (str): Selected voice name
|
| 273 |
-
|
| 274 |
-
Returns:
|
| 275 |
-
tuple: (audio_path, error_msg) - path if success, None and error if fail
|
| 276 |
-
"""
|
| 277 |
-
if not text.strip():
|
| 278 |
-
return None, "Please enter text to synthesize."
|
| 279 |
|
| 280 |
try:
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
# Generate speech
|
| 285 |
-
audio_chunks = []
|
| 286 |
-
for _, _, audio in kokoro_pipeline(text, voice=voice_name):
|
| 287 |
-
audio_chunks.append(audio)
|
| 288 |
-
|
| 289 |
-
# If we have multiple chunks, concatenate them
|
| 290 |
-
if len(audio_chunks) > 1:
|
| 291 |
-
final_audio = np.concatenate(audio_chunks)
|
| 292 |
-
else:
|
| 293 |
-
final_audio = audio_chunks[0] if audio_chunks else np.array([])
|
| 294 |
-
|
| 295 |
-
# Save to a temporary file
|
| 296 |
-
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
|
| 297 |
-
sf.write(tmp_file.name, final_audio, 24000) # Kokoro uses 24kHz sample rate
|
| 298 |
-
return tmp_file.name, ""
|
| 299 |
except Exception as e:
|
| 300 |
-
return None, f"Error
|
| 301 |
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
# text (str): Text to convert to speech
|
| 308 |
-
# audio_prompt (str, optional): Path to reference audio file for voice cloning
|
| 309 |
-
#
|
| 310 |
-
# Returns:
|
| 311 |
-
# str: Path to the generated audio file
|
| 312 |
-
# """
|
| 313 |
-
# # Initialize Dia model if not already initialized
|
| 314 |
-
# global dia_model
|
| 315 |
-
# if dia_model is None:
|
| 316 |
-
# dia_model = initialize_dia()
|
| 317 |
-
#
|
| 318 |
-
# # Generate speech using Dia
|
| 319 |
-
# return dia_model.generate_to_file(text, audio_prompt)
|
| 320 |
-
|
| 321 |
-
def generate_piper_speech(text, lang, voice):
|
| 322 |
-
"""
|
| 323 |
-
Generate speech from text using Piper TTS with selected voice
|
| 324 |
-
|
| 325 |
-
Args:
|
| 326 |
-
text (str): Text to convert to speech
|
| 327 |
-
lang (str): Language ('English' or 'Chinese')
|
| 328 |
-
voice (str): Selected voice label
|
| 329 |
-
|
| 330 |
-
Returns:
|
| 331 |
-
tuple: (audio_path, error_msg) - path if success, None and error if fail
|
| 332 |
-
"""
|
| 333 |
-
if not text.strip():
|
| 334 |
-
return None, "Please enter text to synthesize."
|
| 335 |
-
|
| 336 |
-
if voice not in voices_by_lang.get(lang, {}):
|
| 337 |
-
return None, f"Invalid voice selection for {lang}."
|
| 338 |
-
|
| 339 |
-
onnx_path = voices_by_lang[lang][voice]
|
| 340 |
|
| 341 |
try:
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
with wave.open(tmp_file.name, "wb") as wav_file:
|
| 345 |
-
piper_voice.synthesize_wav(text, wav_file)
|
| 346 |
-
return tmp_file.name, ""
|
| 347 |
except Exception as e:
|
| 348 |
-
return None, f"Error
|
| 349 |
-
|
| 350 |
-
def update_piper_voices(lang):
|
| 351 |
-
choices = list(voices_by_lang.get(lang, {}).keys())
|
| 352 |
-
value = choices[0] if choices else None
|
| 353 |
-
return gr.update(choices=choices, value=value)
|
| 354 |
|
| 355 |
-
def
|
| 356 |
-
"""
|
| 357 |
-
|
| 358 |
-
|
| 359 |
-
|
| 360 |
-
audio_file (str): Path to audio file for transcription
|
| 361 |
-
beam_size (int): Beam size for transcription (higher = more accurate but slower)
|
| 362 |
-
language (str, optional): Language code to force for transcription
|
| 363 |
-
|
| 364 |
-
Returns:
|
| 365 |
-
tuple: (transcription_text, error_msg) - text if success, empty and error if fail
|
| 366 |
-
"""
|
| 367 |
-
if not audio_file or not os.path.exists(audio_file):
|
| 368 |
-
return "", "Please upload an audio file to transcribe."
|
| 369 |
-
|
| 370 |
-
if whisper_model is None:
|
| 371 |
-
return "", "Faster Whisper model failed to initialize."
|
| 372 |
|
| 373 |
try:
|
| 374 |
-
#
|
| 375 |
-
|
| 376 |
-
|
| 377 |
-
"language": language if language else None,
|
| 378 |
-
"task": "transcribe"
|
| 379 |
-
}
|
| 380 |
-
|
| 381 |
-
# Remove None values
|
| 382 |
-
transcribe_options = {k: v for k, v in transcribe_options.items() if v is not None}
|
| 383 |
-
|
| 384 |
-
# Perform transcription
|
| 385 |
-
segments, info = whisper_model.transcribe(audio_file, **transcribe_options)
|
| 386 |
-
|
| 387 |
-
# Collect all segments into a single text
|
| 388 |
-
result = ""
|
| 389 |
-
for segment in segments:
|
| 390 |
-
result += segment.text + " "
|
| 391 |
-
|
| 392 |
-
# Add language detection info
|
| 393 |
-
detected_info = f"\n\nDetected language: {info.language} (probability: {info.language_probability:.2f})"
|
| 394 |
-
|
| 395 |
-
return result.strip(), detected_info
|
| 396 |
except Exception as e:
|
| 397 |
-
return
|
| 398 |
|
| 399 |
-
def
|
| 400 |
-
"""
|
| 401 |
-
|
| 402 |
-
|
|
|
|
| 403 |
|
| 404 |
-
|
| 405 |
-
|
| 406 |
-
|
| 407 |
-
|
| 408 |
-
|
| 409 |
-
"""
|
| 410 |
-
return card_html
|
| 411 |
-
|
| 412 |
-
# Custom CSS
|
| 413 |
-
custom_css = """
|
| 414 |
-
.model-card {
|
| 415 |
-
background: white;
|
| 416 |
-
color: #2c3e50 !important;
|
| 417 |
-
border: 1px solid #ddd;
|
| 418 |
-
border-radius: 12px;
|
| 419 |
-
padding: 20px;
|
| 420 |
-
margin: 10px 0;
|
| 421 |
-
}
|
| 422 |
-
"""
|
| 423 |
|
| 424 |
-
|
| 425 |
-
|
| 426 |
-
|
| 427 |
-
|
| 428 |
-
|
| 429 |
-
</div>
|
| 430 |
-
""")
|
| 431 |
-
|
| 432 |
-
gr.HTML("""
|
| 433 |
-
<div id="intro-section">
|
| 434 |
-
<h3>🔬 Our Exciting Quest</h3>
|
| 435 |
-
<p>We're on a mission to help developers quickly find and compare the best open-source TTS models for their audio projects.</p>
|
| 436 |
-
</div>
|
| 437 |
-
""")
|
| 438 |
|
| 439 |
-
|
| 440 |
-
|
| 441 |
-
|
| 442 |
-
|
| 443 |
-
|
| 444 |
-
|
| 445 |
-
placeholder="Enter text to convert to speech...",
|
| 446 |
-
lines=3
|
| 447 |
-
)
|
| 448 |
-
|
| 449 |
-
audio_prompt = gr.Audio(
|
| 450 |
-
label="Reference Voice (Optional)",
|
| 451 |
-
type="filepath"
|
| 452 |
-
)
|
| 453 |
-
|
| 454 |
-
model_info = gr.HTML(create_model_card("ResembleAI/chatterbox"))
|
| 455 |
-
|
| 456 |
-
with gr.Row():
|
| 457 |
-
with gr.Column():
|
| 458 |
-
language_selection = gr.Radio(
|
| 459 |
-
choices=["English", "Chinese"],
|
| 460 |
-
value="English",
|
| 461 |
-
label="Language"
|
| 462 |
-
)
|
| 463 |
-
generate_btn = gr.Button("Generate Speech")
|
| 464 |
-
|
| 465 |
-
with gr.Column():
|
| 466 |
-
audio_output = gr.Audio(label="Generated Speech", type="filepath")
|
| 467 |
|
| 468 |
-
|
| 469 |
|
| 470 |
-
|
| 471 |
-
|
| 472 |
-
|
| 473 |
-
|
| 474 |
-
|
| 475 |
-
kittentts_audio_output = gr.Audio(label="Generated Speech", type="filepath")
|
| 476 |
|
| 477 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 478 |
|
| 479 |
-
|
| 480 |
-
with gr.Column():
|
| 481 |
-
piper_language_selection = gr.Radio(
|
| 482 |
-
choices=["English", "Chinese"],
|
| 483 |
-
value="English",
|
| 484 |
-
label="Language"
|
| 485 |
-
)
|
| 486 |
-
piper_voice_selection = gr.Dropdown(
|
| 487 |
-
choices=list(voices_by_lang["English"].keys()),
|
| 488 |
-
value=list(voices_by_lang["English"].keys())[0] if voices_by_lang["English"] else None,
|
| 489 |
-
label="Voice"
|
| 490 |
-
)
|
| 491 |
-
piper_generate_btn = gr.Button("Generate Speech")
|
| 492 |
-
|
| 493 |
-
with gr.Column():
|
| 494 |
-
piper_audio_output = gr.Audio(label="Generated Speech", type="filepath")
|
| 495 |
-
piper_status = gr.Textbox(label="Status", interactive=False)
|
| 496 |
|
| 497 |
-
|
| 498 |
-
|
| 499 |
-
|
| 500 |
-
|
| 501 |
-
# with gr.Column():
|
| 502 |
-
# dia_text_format = gr.Markdown("""
|
| 503 |
-
# **Tip:** For dialogue, use [S1] and [S2] tags. For non-verbal expressions, use (laughs), (sighs), etc.
|
| 504 |
-
# Example: [S1] Hello there! (laughs) [S2] Hi, how are you doing today?
|
| 505 |
-
# """)
|
| 506 |
-
# dia_generate_btn = gr.Button("Generate Speech with Dia")
|
| 507 |
-
#
|
| 508 |
-
# with gr.Column():
|
| 509 |
-
# dia_audio_output = gr.Audio(label="Generated Speech", type="filepath")
|
| 510 |
-
|
| 511 |
-
# Faster Whisper section
|
| 512 |
-
whisper_model_info = gr.HTML(create_model_card("SYSTRAN/faster-whisper"))
|
| 513 |
-
|
| 514 |
-
with gr.Row():
|
| 515 |
-
with gr.Column():
|
| 516 |
-
whisper_audio_input = gr.Audio(
|
| 517 |
-
label="Upload Audio for Transcription",
|
| 518 |
-
type="filepath"
|
| 519 |
-
)
|
| 520 |
-
whisper_beam_size = gr.Slider(
|
| 521 |
-
minimum=1,
|
| 522 |
-
maximum=10,
|
| 523 |
-
value=5,
|
| 524 |
-
step=1,
|
| 525 |
-
label="Beam Size (higher = more accurate but slower)"
|
| 526 |
-
)
|
| 527 |
-
whisper_language = gr.Dropdown(
|
| 528 |
-
choices=["", "en", "zh", "fr", "de", "ja", "es", "ru", "ko", "it"],
|
| 529 |
-
value="",
|
| 530 |
-
label="Force Language (optional)"
|
| 531 |
-
)
|
| 532 |
-
whisper_transcribe_btn = gr.Button("Transcribe Audio")
|
| 533 |
|
| 534 |
-
with gr.
|
| 535 |
-
|
| 536 |
-
|
| 537 |
-
|
| 538 |
-
|
| 539 |
-
|
| 540 |
-
|
| 541 |
-
|
| 542 |
-
|
| 543 |
-
|
| 544 |
-
|
| 545 |
-
|
| 546 |
-
|
| 547 |
-
|
| 548 |
-
|
| 549 |
-
|
| 550 |
-
|
| 551 |
-
|
| 552 |
-
|
| 553 |
-
|
| 554 |
-
|
| 555 |
-
(
|
| 556 |
-
|
| 557 |
-
|
| 558 |
-
|
| 559 |
-
|
| 560 |
-
|
| 561 |
-
|
| 562 |
-
|
| 563 |
-
|
| 564 |
-
|
| 565 |
-
|
| 566 |
-
|
| 567 |
-
|
| 568 |
-
|
| 569 |
-
|
| 570 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 571 |
|
| 572 |
-
with gr.
|
| 573 |
-
|
| 574 |
-
|
| 575 |
-
|
| 576 |
-
|
| 577 |
-
|
| 578 |
-
|
| 579 |
-
|
| 580 |
-
|
| 581 |
-
|
| 582 |
-
|
| 583 |
-
|
| 584 |
-
|
| 585 |
-
|
| 586 |
-
|
| 587 |
-
|
| 588 |
-
|
| 589 |
-
|
| 590 |
-
|
| 591 |
-
|
| 592 |
-
|
| 593 |
-
|
| 594 |
-
|
| 595 |
-
|
| 596 |
-
|
| 597 |
-
|
| 598 |
-
|
| 599 |
-
|
| 600 |
-
|
| 601 |
-
|
| 602 |
-
|
| 603 |
-
|
| 604 |
-
|
| 605 |
-
|
| 606 |
-
|
| 607 |
-
|
| 608 |
-
|
| 609 |
-
|
| 610 |
-
|
| 611 |
-
# Connect the Piper generate button to the function
|
| 612 |
-
piper_generate_btn.click(
|
| 613 |
-
fn=generate_piper_speech,
|
| 614 |
-
inputs=[text_input, piper_language_selection, piper_voice_selection],
|
| 615 |
-
outputs=[piper_audio_output, piper_status]
|
| 616 |
-
)
|
| 617 |
-
|
| 618 |
-
# Connect the Faster Whisper transcribe button to the function
|
| 619 |
-
whisper_transcribe_btn.click(
|
| 620 |
-
fn=generate_faster_whisper_speech,
|
| 621 |
-
inputs=[whisper_audio_input, whisper_beam_size, whisper_language],
|
| 622 |
-
outputs=[whisper_text_output, whisper_status]
|
| 623 |
-
)
|
| 624 |
-
|
| 625 |
-
# Connect the Kokoro UI components to the generation function
|
| 626 |
-
kokoro_generate_btn.click(
|
| 627 |
-
fn=generate_kokoro_speech,
|
| 628 |
-
inputs=[text_input, kokoro_language_code, kokoro_voice],
|
| 629 |
-
outputs=[kokoro_audio_output, kokoro_status]
|
| 630 |
-
)
|
| 631 |
-
|
| 632 |
-
# Update voice dropdown when language changes
|
| 633 |
-
piper_language_selection.change(
|
| 634 |
-
fn=update_piper_voices,
|
| 635 |
-
inputs=[piper_language_selection],
|
| 636 |
-
outputs=[piper_voice_selection]
|
| 637 |
-
)
|
| 638 |
|
| 639 |
-
|
| 640 |
-
kokoro_language_code.change(
|
| 641 |
-
fn=lambda lang: gr.update(choices=get_kokoro_voices(lang), value=get_kokoro_voices(lang)[0] if get_kokoro_voices(lang) else None),
|
| 642 |
-
inputs=[kokoro_language_code],
|
| 643 |
-
outputs=[kokoro_voice]
|
| 644 |
-
)
|
| 645 |
|
|
|
|
| 646 |
if __name__ == "__main__":
|
| 647 |
-
demo
|
|
|
|
|
|
| 1 |
import gradio as gr
|
|
|
|
| 2 |
import torch
|
| 3 |
import tempfile
|
| 4 |
import os
|
|
|
|
|
|
|
| 5 |
import numpy as np
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
import soundfile as sf
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
|
| 8 |
+
# Import our model factory
|
| 9 |
+
from src.models.factory import ModelFactory
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
+
# Patch torch.load to always use CPU
|
| 12 |
original_torch_load = torch.load
|
| 13 |
|
| 14 |
def patched_torch_load(f, map_location=None, **kwargs):
|
|
|
|
| 18 |
|
| 19 |
torch.load = patched_torch_load
|
| 20 |
|
| 21 |
+
# Get model descriptions
|
| 22 |
+
MODEL_DESCRIPTIONS = ModelFactory.get_model_descriptions()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
|
| 24 |
+
# Models dictionary for UI display
|
| 25 |
+
MODELS = {
|
| 26 |
+
"ResembleAI/chatterbox": "Chatterbox",
|
| 27 |
+
"KittenML/KittenTTS": "KittenTTS",
|
| 28 |
+
"piper-tts": "Piper (no voice cloning)",
|
| 29 |
+
"SYSTRAN/faster-whisper": "Faster Whisper",
|
| 30 |
+
"hexgrad/kokoro": "Kokoro-82M",
|
| 31 |
+
"nari-labs/Dia-1.6B": "Dia TTS",
|
| 32 |
+
}
|
| 33 |
|
| 34 |
+
# Initialize model instances
|
| 35 |
+
tts_models = ModelFactory.get_tts_models()
|
| 36 |
+
stt_models = ModelFactory.get_stt_models()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
|
| 38 |
+
# Initialize the models that need immediate initialization
|
| 39 |
+
for model_name in ["ResembleAI/chatterbox", "KittenML/KittenTTS"]:
|
| 40 |
+
if model_name in tts_models:
|
| 41 |
+
tts_models[model_name].initialize()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
|
| 43 |
+
# Initialize the STT model
|
| 44 |
+
whisper_model = stt_models.get("SYSTRAN/faster-whisper")
|
| 45 |
+
if whisper_model:
|
| 46 |
+
whisper_model.initialize()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
|
| 48 |
+
# Helper function to get Kokoro voices
|
| 49 |
def get_kokoro_voices(language_code):
|
| 50 |
"""
|
| 51 |
Get available voices for a specific Kokoro language code
|
| 52 |
Based on: https://huggingface.co/hexgrad/Kokoro-82M/blob/main/VOICES.md
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
"""
|
| 54 |
voice_map = {
|
| 55 |
# American English (a)
|
|
|
|
| 74 |
"i": ["if_sara", "im_nicola"],
|
| 75 |
# Japanese (j)
|
| 76 |
"j": ["jf_alpha", "jf_gongitsune", "jf_nezumi", "jf_tebukuro", "jm_kumo"],
|
| 77 |
+
# Brazilian Portuguese (p)
|
| 78 |
"p": ["pt_heart", "pt_sun", "pt_moon", "pt_star", "pt_cloud"],
|
| 79 |
# Mandarin Chinese (z)
|
| 80 |
"z": [
|
|
|
|
| 84 |
}
|
| 85 |
return voice_map.get(language_code, ["af_heart"]) # Default to American English voices
|
| 86 |
|
| 87 |
+
# UI Functions for TTS Models
|
| 88 |
+
|
| 89 |
+
def tts_chatterbox(text, language, audio_prompt=None):
|
| 90 |
+
"""UI function for Chatterbox TTS"""
|
| 91 |
+
model = tts_models.get("ResembleAI/chatterbox")
|
| 92 |
+
if not model:
|
| 93 |
+
return None, "Model not available"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
|
| 95 |
try:
|
| 96 |
+
audio_path = model.generate_speech(text, language=language, audio_prompt=audio_prompt)
|
| 97 |
+
return audio_path, ""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 98 |
except Exception as e:
|
| 99 |
+
return None, f"Error: {str(e)}"
|
| 100 |
|
| 101 |
+
def tts_kittentts(text, audio_prompt=None):
|
| 102 |
+
"""UI function for KittenTTS"""
|
| 103 |
+
model = tts_models.get("KittenML/KittenTTS")
|
| 104 |
+
if not model:
|
| 105 |
+
return None, "Model not available"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 106 |
|
| 107 |
try:
|
| 108 |
+
audio_path = model.generate_speech(text, audio_prompt=audio_prompt)
|
| 109 |
+
return audio_path, ""
|
|
|
|
|
|
|
|
|
|
| 110 |
except Exception as e:
|
| 111 |
+
return None, f"Error: {str(e)}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
|
| 113 |
+
def tts_piper(text, language, voice):
|
| 114 |
+
"""UI function for Piper TTS"""
|
| 115 |
+
model = tts_models.get("piper-tts")
|
| 116 |
+
if not model:
|
| 117 |
+
return None, "Model not available"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 118 |
|
| 119 |
try:
|
| 120 |
+
model.initialize() # Ensure voices are scanned
|
| 121 |
+
audio_path = model.generate_speech(text, language=language, voice=voice)
|
| 122 |
+
return audio_path, ""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 123 |
except Exception as e:
|
| 124 |
+
return None, f"Error: {str(e)}"
|
| 125 |
|
| 126 |
+
def tts_kokoro(text, language_code, voice_name):
|
| 127 |
+
"""UI function for Kokoro TTS"""
|
| 128 |
+
model = tts_models.get("hexgrad/kokoro")
|
| 129 |
+
if not model:
|
| 130 |
+
return None, "Model not available"
|
| 131 |
|
| 132 |
+
try:
|
| 133 |
+
audio_path = model.generate_speech(text, lang_code=language_code)
|
| 134 |
+
return audio_path, ""
|
| 135 |
+
except Exception as e:
|
| 136 |
+
return None, f"Error: {str(e)}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 137 |
|
| 138 |
+
def tts_dia(text, audio_prompt=None):
|
| 139 |
+
"""UI function for Dia TTS"""
|
| 140 |
+
model = tts_models.get("nari-labs/Dia-1.6B")
|
| 141 |
+
if not model:
|
| 142 |
+
return None, "Model not available"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 143 |
|
| 144 |
+
try:
|
| 145 |
+
model.initialize() # Ensure model is loaded
|
| 146 |
+
audio_path = model.generate_speech(text, audio_prompt=audio_prompt)
|
| 147 |
+
return audio_path, ""
|
| 148 |
+
except Exception as e:
|
| 149 |
+
return None, f"Error: {str(e)}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 150 |
|
| 151 |
+
# UI Function for STT Model
|
| 152 |
|
| 153 |
+
def stt_whisper(audio_path, language=None):
|
| 154 |
+
"""UI function for Faster Whisper STT"""
|
| 155 |
+
model = stt_models.get("SYSTRAN/faster-whisper")
|
| 156 |
+
if not model:
|
| 157 |
+
return "Model not available"
|
|
|
|
| 158 |
|
| 159 |
+
try:
|
| 160 |
+
transcription = model.transcribe(audio_path, language=language)
|
| 161 |
+
return transcription
|
| 162 |
+
except Exception as e:
|
| 163 |
+
return f"Error: {str(e)}"
|
| 164 |
|
| 165 |
+
# Gradio UI Components
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 166 |
|
| 167 |
+
def create_tts_tab():
    """Create the TTS tab for the Gradio interface.

    Builds one sub-tab per TTS backend (Chatterbox, KittenTTS, Piper,
    Kokoro, Dia) and wires each "Generate Speech" button to the matching
    UI callback.
    """
    with gr.Tab("Text-to-Speech"):
        gr.Markdown("## Text-to-Speech Models")

        with gr.Tabs():
            # Chatterbox Tab
            with gr.Tab("Chatterbox"):
                with gr.Row():
                    with gr.Column():
                        chatterbox_text = gr.Textbox(
                            label="Text to speak",
                            placeholder="Enter text here...",
                            lines=5
                        )
                        chatterbox_language = gr.Dropdown(
                            choices=["English", "Chinese"],
                            value="English",
                            label="Language"
                        )
                        chatterbox_audio_prompt = gr.Audio(
                            label="Voice reference (optional)",
                            type="filepath"
                        )
                        chatterbox_submit = gr.Button("Generate Speech")

                    with gr.Column():
                        chatterbox_output = gr.Audio(label="Generated Speech")
                        chatterbox_error = gr.Textbox(label="Error", visible=False)

                chatterbox_submit.click(
                    tts_chatterbox,
                    inputs=[chatterbox_text, chatterbox_language, chatterbox_audio_prompt],
                    outputs=[chatterbox_output, chatterbox_error]
                )

            # KittenTTS Tab
            with gr.Tab("KittenTTS"):
                with gr.Row():
                    with gr.Column():
                        kittentts_text = gr.Textbox(
                            label="Text to speak",
                            placeholder="Enter text here...",
                            lines=5
                        )
                        kittentts_audio_prompt = gr.Audio(
                            label="Voice reference (optional)",
                            type="filepath"
                        )
                        kittentts_submit = gr.Button("Generate Speech")

                    with gr.Column():
                        kittentts_output = gr.Audio(label="Generated Speech")
                        kittentts_error = gr.Textbox(label="Error", visible=False)

                kittentts_submit.click(
                    tts_kittentts,
                    inputs=[kittentts_text, kittentts_audio_prompt],
                    outputs=[kittentts_output, kittentts_error]
                )

            # Piper Tab
            with gr.Tab("Piper"):
                with gr.Row():
                    with gr.Column():
                        piper_text = gr.Textbox(
                            label="Text to speak",
                            placeholder="Enter text here...",
                            lines=5
                        )

                        # Initialize Piper model so voice lists are available
                        # while the UI is being built.
                        piper_model = tts_models.get("piper-tts")
                        if piper_model:
                            piper_model.initialize()
                            languages = piper_model.get_supported_languages()
                        else:
                            languages = ["English"]

                        piper_language = gr.Dropdown(
                            choices=languages,
                            value="English",
                            label="Language"
                        )

                        def update_piper_voices(language):
                            # BUG FIX: gr.Dropdown.update() was removed in
                            # Gradio 4; gr.update() works in both 3.x and 4.x.
                            if piper_model:
                                voices = piper_model.get_available_voices(language)
                                return gr.update(choices=voices, value=voices[0] if voices else None)
                            return gr.update(choices=[], value=None)

                        piper_voice = gr.Dropdown(
                            label="Voice",
                            # Populate the default language's voices up front
                            # instead of starting with an empty dropdown.
                            choices=piper_model.get_available_voices("English") if piper_model else []
                        )

                        piper_language.change(
                            update_piper_voices,
                            inputs=[piper_language],
                            outputs=[piper_voice]
                        )

                        piper_submit = gr.Button("Generate Speech")

                    with gr.Column():
                        piper_output = gr.Audio(label="Generated Speech")
                        piper_error = gr.Textbox(label="Error", visible=False)

                piper_submit.click(
                    tts_piper,
                    inputs=[piper_text, piper_language, piper_voice],
                    outputs=[piper_output, piper_error]
                )

            # Kokoro Tab
            with gr.Tab("Kokoro"):
                with gr.Row():
                    with gr.Column():
                        kokoro_text = gr.Textbox(
                            label="Text to speak",
                            placeholder="Enter text here...",
                            lines=5
                        )

                        kokoro_language = gr.Dropdown(
                            choices=[
                                "American English (a)", "British English (b)",
                                "Spanish (e)", "French (f)", "Hindi (h)",
                                "Italian (i)", "Japanese (j)",
                                "Brazilian Portuguese (p)", "Mandarin Chinese (z)"
                            ],
                            value="American English (a)",
                            label="Language"
                        )

                        def get_lang_code(language):
                            # "American English (a)" -> "a"
                            return language.split("(")[-1].split(")")[0].strip()

                        def update_kokoro_voices(language):
                            # BUG FIX: gr.Dropdown.update() was removed in
                            # Gradio 4; gr.update() works in both 3.x and 4.x.
                            lang_code = get_lang_code(language)
                            voices = get_kokoro_voices(lang_code)
                            return gr.update(choices=voices, value=voices[0] if voices else None)

                        kokoro_voice = gr.Dropdown(
                            label="Voice",
                            choices=get_kokoro_voices("a"),
                            value="af_heart"
                        )

                        kokoro_language.change(
                            update_kokoro_voices,
                            inputs=[kokoro_language],
                            outputs=[kokoro_voice]
                        )

                        kokoro_submit = gr.Button("Generate Speech")

                    with gr.Column():
                        kokoro_output = gr.Audio(label="Generated Speech")
                        kokoro_error = gr.Textbox(label="Error", visible=False)

                kokoro_submit.click(
                    lambda text, lang, voice: tts_kokoro(text, get_lang_code(lang), voice),
                    inputs=[kokoro_text, kokoro_language, kokoro_voice],
                    outputs=[kokoro_output, kokoro_error]
                )

            # Dia Tab
            with gr.Tab("Dia"):
                with gr.Row():
                    with gr.Column():
                        dia_text = gr.Textbox(
                            label="Text to speak",
                            placeholder="Enter text here...",
                            lines=5
                        )
                        dia_audio_prompt = gr.Audio(
                            label="Voice reference (optional)",
                            type="filepath"
                        )
                        dia_submit = gr.Button("Generate Speech")

                    with gr.Column():
                        dia_output = gr.Audio(label="Generated Speech")
                        dia_error = gr.Textbox(label="Error", visible=False)

                dia_submit.click(
                    tts_dia,
                    inputs=[dia_text, dia_audio_prompt],
                    outputs=[dia_output, dia_error]
                )
|
| 358 |
+
|
| 359 |
+
def create_stt_tab():
    """Create the STT tab for the Gradio interface."""
    with gr.Tab("Speech-to-Text"):
        gr.Markdown("## Speech-to-Text Models")

        with gr.Tabs():
            # Faster Whisper tab
            with gr.Tab("Faster Whisper"):
                with gr.Row():
                    with gr.Column():
                        audio_input = gr.Audio(
                            label="Audio to transcribe",
                            type="filepath"
                        )
                        language_choice = gr.Dropdown(
                            choices=["Auto-detect", "English", "Chinese", "Spanish", "French", "German", "Japanese"],
                            value="Auto-detect",
                            label="Language (optional)"
                        )
                        transcribe_btn = gr.Button("Transcribe")

                    with gr.Column():
                        transcript_box = gr.Textbox(
                            label="Transcription",
                            lines=5
                        )

                def _run_whisper(audio, lang):
                    # Map the "Auto-detect" sentinel to None so the model
                    # performs language detection itself.
                    return stt_whisper(audio, None if lang == "Auto-detect" else lang)

                transcribe_btn.click(
                    _run_whisper,
                    inputs=[audio_input, language_choice],
                    outputs=[transcript_box]
                )
|
| 391 |
+
|
| 392 |
+
# Create the Gradio interface
|
| 393 |
+
# Create the Gradio interface
def create_interface():
    """Assemble and return the top-level Gradio Blocks application.

    The app is two tabs — TTS and STT — each built by its own helper so
    this function stays a thin composition layer.
    """
    with gr.Blocks(title="TTS & STT Gallery") as demo:
        gr.Markdown("# TTS & STT Model Gallery")
        gr.Markdown("Explore different Text-to-Speech and Speech-to-Text models")

        with gr.Tabs():
            create_tts_tab()
            create_stt_tab()

    return demo
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 404 |
|
| 405 |
+
# Launch the app (guarded so importing this module does not start a server)
if __name__ == "__main__":
    demo = create_interface()
    demo.launch()
|
src/models/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .base import TTSModel, STTModel
|
| 2 |
+
from .factory import ModelFactory
|
| 3 |
+
|
| 4 |
+
__all__ = ['TTSModel', 'STTModel', 'ModelFactory']
|
src/models/base.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from abc import ABC, abstractmethod
|
| 2 |
+
import tempfile
|
| 3 |
+
import os
|
| 4 |
+
|
| 5 |
+
class BaseModel(ABC):
    """Common abstract interface shared by every TTS and STT model.

    Concrete subclasses must provide a display ``name``, a short
    ``description``, and an ``initialize`` method that loads any heavy
    resources lazily.
    """

    @property
    @abstractmethod
    def name(self):
        """Human-readable / registry identifier of the model."""
        ...

    @property
    @abstractmethod
    def description(self):
        """One-line description shown in the UI."""
        ...

    @abstractmethod
    def initialize(self):
        """Load weights or other resources; return True on success."""
        ...
|
| 24 |
+
|
| 25 |
+
class TTSModel(BaseModel):
    """Abstract base class for Text-to-Speech models."""

    @abstractmethod
    def generate_speech(self, text, **kwargs):
        """Synthesize *text* and return the path of the audio file.

        Args:
            text (str): Text to convert to speech.
            **kwargs: Model-specific generation options.

        Returns:
            str: Path to the generated audio file.
        """
        ...

    def supports_voice_cloning(self):
        """True if the model can clone a voice from reference audio."""
        return False

    def supports_multilingual(self):
        """True if the model can speak more than one language."""
        return False

    def get_supported_languages(self):
        """List of supported language display names (default: English only)."""
        return ["English"]
|
| 53 |
+
|
| 54 |
+
class STTModel(BaseModel):
    """Abstract base class for Speech-to-Text models."""

    @abstractmethod
    def transcribe(self, audio_path, **kwargs):
        """Transcribe an audio file to text.

        Args:
            audio_path (str): Path to the audio file.
            **kwargs: Model-specific transcription options.

        Returns:
            str: Transcribed text.
        """
        ...

    def supports_multilingual(self):
        """True if the model can transcribe more than one language."""
        return False

    def get_supported_languages(self):
        """List of supported language display names (default: English only)."""
        return ["English"]
|
src/models/factory.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .tts.chatterbox_model import ChatterboxTTSModel
|
| 2 |
+
from .tts.kitten_model import KittenTTSModel
|
| 3 |
+
from .tts.piper_model import PiperTTSModel
|
| 4 |
+
from .tts.kokoro_model import KokoroTTSModel
|
| 5 |
+
from .tts.dia_model import DiaTTSModel
|
| 6 |
+
from .stt.whisper_model import FasterWhisperSTTModel
|
| 7 |
+
|
| 8 |
+
class ModelFactory:
    """Factory for constructing TTS and STT model instances.

    Every call returns fresh instances; the model classes themselves load
    their weights lazily via initialize(), so construction stays cheap.
    """

    @staticmethod
    def get_tts_models():
        """Return a mapping of registry name -> TTS model instance."""
        return {
            "ResembleAI/chatterbox": ChatterboxTTSModel(),
            "KittenML/KittenTTS": KittenTTSModel(),
            "piper-tts": PiperTTSModel(),
            "hexgrad/kokoro": KokoroTTSModel(),
            "nari-labs/Dia-1.6B": DiaTTSModel(),
        }

    @staticmethod
    def get_stt_models():
        """Return a mapping of registry name -> STT model instance."""
        return {
            "SYSTRAN/faster-whisper": FasterWhisperSTTModel(),
        }

    @staticmethod
    def get_tts_model(model_name):
        """Look up a single TTS model by name; None if unknown."""
        return ModelFactory.get_tts_models().get(model_name)

    @staticmethod
    def get_stt_model(model_name):
        """Look up a single STT model by name; None if unknown."""
        return ModelFactory.get_stt_models().get(model_name)

    @staticmethod
    def get_model_descriptions():
        """Return {model_name: description} for every registered model."""
        combined = {}
        combined.update(ModelFactory.get_tts_models())
        combined.update(ModelFactory.get_stt_models())
        return {name: model.description for name, model in combined.items()}
|
src/models/stt/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .whisper_model import FasterWhisperSTTModel
|
| 2 |
+
|
| 3 |
+
__all__ = ['FasterWhisperSTTModel']
|
src/models/stt/whisper_model.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from faster_whisper import WhisperModel
|
| 3 |
+
from ..base import STTModel
|
| 4 |
+
|
| 5 |
+
class FasterWhisperSTTModel(STTModel):
    """Faster Whisper STT model implementation (CTranslate2 backend)."""

    # faster-whisper's `language` argument expects ISO codes such as "en",
    # but the UI passes display names such as "English"; this maps between
    # the two (see get_supported_languages for the advertised set).
    _LANGUAGE_CODES = {
        "english": "en", "spanish": "es", "french": "fr", "german": "de",
        "chinese": "zh", "japanese": "ja", "russian": "ru",
        "portuguese": "pt", "italian": "it", "dutch": "nl",
        "arabic": "ar", "korean": "ko",
    }

    def __init__(self):
        self._model = None          # WhisperModel instance, created lazily
        self._initialized = False
        self._model_size = "large-v3"

    @property
    def name(self):
        return "SYSTRAN/faster-whisper"

    @property
    def description(self):
        return "Faster Whisper transcription with CTranslate2, up to 4x faster than OpenAI Whisper"

    def initialize(self):
        """Load the Whisper model, choosing device/precision from availability.

        Returns:
            bool: True on success (including the small-model fallback).
        """
        if self._initialized:
            return True

        try:
            if torch.cuda.is_available():
                self._model = WhisperModel(self._model_size, device="cuda", compute_type="float16")
                print("Loaded faster-whisper on CUDA with FP16")
            elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
                # CTranslate2 has no MPS backend, so Apple Silicon runs on CPU.
                self._model = WhisperModel(self._model_size, device="cpu", compute_type="int8")
                print("Loaded faster-whisper on CPU with INT8 (MPS not directly supported)")
            else:
                self._model = WhisperModel(self._model_size, device="cpu", compute_type="int8")
                print("Loaded faster-whisper on CPU with INT8")

            self._initialized = True
            return True
        except Exception as e:
            print(f"Error initializing Faster Whisper model: {str(e)}")
            print("Falling back to small model with INT8 quantization")
            try:
                self._model = WhisperModel("small", device="cpu", compute_type="int8")
                self._initialized = True
                return True
            except Exception as e2:
                print(f"Failed to load fallback model: {str(e2)}")
                return False

    @staticmethod
    def _normalize_language(language):
        """Map a display name ('English') to a Whisper code ('en').

        None/empty input and already-valid codes pass through unchanged, so
        callers that previously supplied codes keep working.
        """
        if not language:
            return None
        return FasterWhisperSTTModel._LANGUAGE_CODES.get(language.strip().lower(), language)

    def transcribe(self, audio_path, language=None, **kwargs):
        """
        Transcribe speech to text.

        Args:
            audio_path (str): Path to the audio file.
            language (str, optional): Language display name ('English') or
                Whisper code ('en'); None auto-detects.
            **kwargs: Additional parameters forwarded to WhisperModel.transcribe.

        Returns:
            str: Transcribed text.

        Raises:
            RuntimeError: If the model cannot be initialized.
        """
        if not self._initialized:
            if not self.initialize():
                raise RuntimeError("Failed to initialize Faster Whisper model")

        # Default transcription parameters.
        # BUG FIX: the UI passes names like "English", which faster-whisper
        # rejects — normalize to a language code first.
        transcribe_kwargs = {
            "beam_size": 5,
            "language": self._normalize_language(language),
            "task": "transcribe"
        }

        # User-provided kwargs take precedence over the defaults.
        transcribe_kwargs.update(kwargs)

        # segments is a lazy generator; joining it drives the decoding.
        segments, info = self._model.transcribe(audio_path, **transcribe_kwargs)
        transcription = " ".join(segment.text for segment in segments)
        return transcription.strip()

    def supports_multilingual(self):
        return True

    def get_supported_languages(self):
        # Whisper supports many languages; this is the subset shown in the UI.
        return [
            "English", "Spanish", "French", "German", "Chinese", "Japanese",
            "Russian", "Portuguese", "Italian", "Dutch", "Arabic", "Korean"
        ]
|
src/models/tts/__init__.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .chatterbox_model import ChatterboxTTSModel
|
| 2 |
+
from .kitten_model import KittenTTSModel
|
| 3 |
+
from .piper_model import PiperTTSModel
|
| 4 |
+
from .kokoro_model import KokoroTTSModel
|
| 5 |
+
from .dia_model import DiaTTSModel
|
| 6 |
+
|
| 7 |
+
__all__ = [
|
| 8 |
+
'ChatterboxTTSModel',
|
| 9 |
+
'KittenTTSModel',
|
| 10 |
+
'PiperTTSModel',
|
| 11 |
+
'KokoroTTSModel',
|
| 12 |
+
'DiaTTSModel'
|
| 13 |
+
]
|
src/models/tts/chatterbox_model.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torchaudio as ta
|
| 3 |
+
import tempfile
|
| 4 |
+
import os
|
| 5 |
+
from chatterbox.mtl_tts import ChatterboxMultilingualTTS
|
| 6 |
+
from ..base import TTSModel
|
| 7 |
+
|
| 8 |
+
class ChatterboxTTSModel(TTSModel):
    """Chatterbox multilingual TTS model implementation."""

    def __init__(self):
        self._model = None          # ChatterboxMultilingualTTS, loaded lazily
        self._initialized = False

    @property
    def name(self):
        return "ResembleAI/chatterbox"

    @property
    def description(self):
        return "Industrial-grade TTS solution with multilingual support"

    def initialize(self):
        """Load Chatterbox weights, preferring CUDA with a CPU fallback."""
        if self._initialized:
            return True

        try:
            target_device = "cuda" if torch.cuda.is_available() else "cpu"
            self._model = ChatterboxMultilingualTTS.from_pretrained(device=target_device)
            self._initialized = True
            return True
        except RuntimeError as e:
            # Checkpoints saved on GPU fail to deserialize on CPU-only hosts;
            # retry explicitly on CPU for that specific failure.
            if "Attempting to deserialize object on a CUDA device" in str(e):
                print("CUDA model detected but CUDA is not available. Loading model on CPU...")
                self._model = ChatterboxMultilingualTTS.from_pretrained(device="cpu")
                self._initialized = True
                return True
            print(f"Error initializing Chatterbox model: {e}")
            return False

    def generate_speech(self, text, language="English", audio_prompt=None, **kwargs):
        """
        Generate speech from text using Chatterbox multilingual TTS.

        Args:
            text (str): Text to convert to speech.
            language (str): Language name ('English' or 'Chinese').
            audio_prompt (str, optional): Reference audio path for voice cloning.
            **kwargs: Additional generation parameters (override the defaults).

        Returns:
            str: Path to the generated audio file.

        Raises:
            RuntimeError: If the model cannot be initialized.
        """
        if not self._initialized and not self.initialize():
            raise RuntimeError("Failed to initialize Chatterbox model")

        # UI language name -> Chatterbox language id (unknown names -> "en").
        lang_lookup = {"English": "en", "Chinese": "zh"}
        language_id = lang_lookup.get(language, "en")

        # Defaults tuned for natural delivery; caller kwargs win.
        synthesis_params = {
            "exaggeration": 0.5,
            "temperature": 0.8,
            "cfg_weight": 0.3,
        }
        synthesis_params.update(kwargs)

        if audio_prompt and os.path.exists(audio_prompt):
            # Voice cloning path: condition generation on the reference audio.
            wav = self._model.generate(text, language_id=language_id, audio_prompt_path=audio_prompt, **synthesis_params)
        else:
            # Default voice.
            wav = self._model.generate(text, language_id=language_id, **synthesis_params)

        # Persist the waveform to a temp WAV the UI can serve.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
            ta.save(tmp_file.name, wav, self._model.sr)
            return tmp_file.name

    def supports_voice_cloning(self):
        return True

    def supports_multilingual(self):
        return True

    def get_supported_languages(self):
        return ["English", "Chinese"]
|
src/models/tts/dia_model.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import tempfile
|
| 2 |
+
import os
|
| 3 |
+
from ..base import TTSModel
|
| 4 |
+
|
| 5 |
+
class DiaTTSModel(TTSModel):
    """Dia TTS model implementation."""

    def __init__(self):
        self._model = None          # DiaTTS wrapper, created lazily
        self._initialized = False

    @property
    def name(self):
        return "nari-labs/Dia-1.6B"

    @property
    def description(self):
        return "Ultra-realistic dialogue generation with support for voice cloning and non-verbal expressions"

    def initialize(self):
        """Construct the DiaTTS wrapper; returns True on success."""
        if self._initialized:
            return True

        try:
            # Imported lazily to avoid a circular import at module load time.
            from src.dia_tts import DiaTTS
            self._model = DiaTTS()
        except Exception as exc:
            print(f"Error initializing Dia model: {exc}")
            return False
        self._initialized = True
        return True

    def generate_speech(self, text, audio_prompt=None, **kwargs):
        """
        Generate speech from text using Dia TTS.

        Args:
            text (str): Text to convert to speech.
            audio_prompt (str, optional): Reference audio path for voice cloning.
            **kwargs: Additional parameters forwarded to DiaTTS.generate.

        Returns:
            str: Path to the generated audio file.

        Raises:
            RuntimeError: If the model cannot be initialized.
        """
        if not self._initialized and not self.initialize():
            raise RuntimeError("Failed to initialize Dia model")

        return self._model.generate(text, reference_audio=audio_prompt, **kwargs)

    def supports_voice_cloning(self):
        return True
|
src/models/tts/kitten_model.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import tempfile
|
| 2 |
+
import os
|
| 3 |
+
import soundfile as sf
|
| 4 |
+
import numpy as np
|
| 5 |
+
from kittentts import KittenTTS
|
| 6 |
+
from ..base import TTSModel
|
| 7 |
+
|
| 8 |
+
class KittenTTSModel(TTSModel):
    """KittenTTS model implementation."""

    def __init__(self):
        self._model = None          # KittenTTS instance, loaded lazily
        self._initialized = False
        self._model_path = "KittenML/kitten-tts-nano-0.2"

    @property
    def name(self):
        return "KittenML/KittenTTS"

    @property
    def description(self):
        return "High-quality TTS with voice cloning capabilities using reference audio"

    def initialize(self):
        """Load the KittenTTS weights; no-op after the first success."""
        if self._initialized:
            return True

        try:
            self._model = KittenTTS(self._model_path)
        except Exception as exc:
            print(f"Error initializing KittenTTS model: {exc}")
            return False
        self._initialized = True
        return True

    def generate_speech(self, text, audio_prompt=None, **kwargs):
        """
        Generate speech from text using KittenTTS.

        Args:
            text (str): Text to convert to speech.
            audio_prompt (str, optional): Reference audio path for voice cloning.
            **kwargs: Reserved for model-specific options (currently unused).

        Returns:
            str: Path to the generated audio file.

        Raises:
            RuntimeError: If the model cannot be initialized.
        """
        if not self._initialized and not self.initialize():
            raise RuntimeError("Failed to initialize KittenTTS model")

        # Clone the reference voice when a valid prompt file was supplied;
        # otherwise synthesize with the default voice.
        if audio_prompt and os.path.exists(audio_prompt):
            samples = self._model.generate_with_voice(text, audio_prompt)
        else:
            samples = self._model.generate(text)

        # Write the samples to a temp WAV the UI can serve.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
            sf.write(tmp_file.name, samples, self._model.sample_rate)
            return tmp_file.name

    def supports_voice_cloning(self):
        return True
|
src/models/tts/kokoro_model.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import tempfile
|
| 2 |
+
import os
|
| 3 |
+
from kokoro import KPipeline
|
| 4 |
+
from ..base import TTSModel
|
| 5 |
+
|
| 6 |
+
class KokoroTTSModel(TTSModel):
    """Kokoro TTS model implementation."""

    def __init__(self):
        self._model = None          # KPipeline, bound to one language code
        self._initialized = False
        self._lang_code = 'a'       # default: American English

    @property
    def name(self):
        return "hexgrad/kokoro"

    @property
    def description(self):
        return "Lightweight TTS model with 82M parameters, Apache-licensed for production and personal use"

    def initialize(self):
        """Create the KPipeline for the current language code."""
        if self._initialized:
            return True

        try:
            self._model = KPipeline(lang_code=self._lang_code)
            self._initialized = True
            return True
        except Exception as e:
            print(f"Error initializing Kokoro model: {e}")
            return False

    def generate_speech(self, text, lang_code=None, voice="af_heart", **kwargs):
        """
        Generate speech from text using Kokoro TTS.

        Args:
            text (str): Text to convert to speech.
            lang_code (str, optional): Kokoro language code ('a' American
                English, 'b' British English, ...).
            voice (str): Kokoro voice name (e.g. 'af_heart'). The default is
                an American-English voice; pass a voice matching lang_code.
            **kwargs: Additional parameters forwarded to the pipeline call.

        Returns:
            str: Path to the generated audio file.

        Raises:
            RuntimeError: If the model cannot be initialized or yields no audio.
        """
        # A KPipeline is bound to one language; rebuild it when that changes.
        if lang_code and lang_code != self._lang_code:
            self._lang_code = lang_code
            self._initialized = False

        if not self._initialized:
            if not self.initialize():
                raise RuntimeError("Failed to initialize Kokoro model")

        # Local imports keep module import light (only needed at synthesis time).
        import numpy as np
        import soundfile as sf

        # BUG FIX: KPipeline has no tts_to_file() method. The documented API
        # is calling the pipeline, which yields (graphemes, phonemes, audio)
        # chunks; concatenate them into one waveform.
        chunks = [np.asarray(audio) for _, _, audio in self._model(text, voice=voice, **kwargs)]
        if not chunks:
            raise RuntimeError("Kokoro produced no audio")
        waveform = np.concatenate(chunks)

        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
            sf.write(tmp_file.name, waveform, 24000)  # Kokoro outputs 24 kHz audio
            return tmp_file.name

    def get_supported_languages(self):
        return ["American English", "British English"]

    def get_language_codes(self):
        """Get mapping of language names to language codes."""
        return {
            "American English": "a",
            "British English": "b"
        }
|
src/models/tts/piper_model.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
import tempfile
import wave

from piper import PiperVoice

from ..base import TTSModel
|
| 5 |
+
|
| 6 |
+
class PiperTTSModel(TTSModel):
    """Piper TTS model implementation.

    Discovers .onnx voice files under ``src/voices/piper_voices`` and
    synthesizes speech locally with the ``piper`` package.
    """

    def __init__(self):
        # Mapping of language name -> {voice label: path to .onnx model},
        # populated by initialize().
        self._voices_by_lang = None
        # Cache of loaded PiperVoice instances keyed by model path so each
        # voice is only loaded from disk once per process.
        self._voice_cache = {}
        self._initialized = False

    @property
    def name(self):
        return "piper-tts"

    @property
    def description(self):
        return "Local on-device TTS with dynamic English and Chinese voice selection from Piper models"

    def initialize(self):
        """Initialize the Piper model by scanning available voices.

        Idempotent: returns immediately when already initialized.

        Returns:
            bool: True if the voice scan succeeded, False otherwise.
        """
        if self._initialized:
            return True

        try:
            self._voices_by_lang = self._scan_piper_voices()
            self._initialized = True
            return True
        except Exception as e:
            print(f"Error initializing Piper model: {e}")
            return False

    def _scan_piper_voices(self):
        """Scan the on-disk voice directory and group voices by language.

        Expected English layout:
        ``en/<locale>/<voice>/<quality>/<locale>-<voice>-<quality>.onnx``

        Returns:
            dict: {'English': {label: onnx_path}, 'Chinese': {label: onnx_path}}
        """
        voices_dir = "src/voices/piper_voices"
        voices_by_lang = {'English': {}, 'Chinese': {}}

        # Chinese: only huayan medium is shipped with this Space.
        chinese_path = os.path.join(voices_dir, "zh", "zh_CN", "huayan", "medium", "zh_CN-huayan-medium.onnx")
        if os.path.exists(chinese_path):
            voices_by_lang['Chinese']['huayan (zh_CN)'] = chinese_path

        # English voices: walk the tree and pick up quality-leaf directories.
        en_dir = os.path.join(voices_dir, "en")
        for root, dirs, files in os.walk(en_dir):
            parts = root.split(os.sep)
            # Only quality-leaf directories (deep enough, named medium/high).
            if len(parts) < 5 or parts[-1] not in ('medium', 'high'):
                continue
            locale = parts[-3]      # e.g. en_GB or en_US
            voice_name = parts[-2]  # e.g. alan
            quality = parts[-1]     # medium or high

            for file in files:
                if file.endswith('.onnx') and f"{locale}-{voice_name}-{quality}" in file:
                    label = f"{voice_name} ({locale})"
                    # Prefer medium over high when both qualities exist.
                    if quality == 'medium' or label not in voices_by_lang['English']:
                        voices_by_lang['English'][label] = os.path.join(root, file)
                    break  # Assume one .onnx per dir

        return voices_by_lang

    def _load_voice(self, model_path):
        """Load (and cache) a PiperVoice for the given .onnx model path."""
        if model_path not in self._voice_cache:
            # PiperVoice instances are created via the load() classmethod;
            # the dataclass constructor does not accept a model_path kwarg.
            self._voice_cache[model_path] = PiperVoice.load(model_path)
        return self._voice_cache[model_path]

    def generate_speech(self, text, language="English", voice=None, **kwargs):
        """
        Generate speech from text using Piper TTS.

        Args:
            text (str): Text to convert to speech
            language (str): Language name ('English' or 'Chinese')
            voice (str, optional): Voice label to use; falls back to the
                first available voice when missing or unknown
            **kwargs: Additional parameters for generation (currently unused)

        Returns:
            str: Path to the generated .wav audio file (created with
            delete=False; caller is responsible for cleanup)

        Raises:
            RuntimeError: If the voice scan fails.
            ValueError: If no voices exist for the requested language.
        """
        if not self._initialized:
            if not self.initialize():
                raise RuntimeError("Failed to initialize Piper model")

        # Get available voices for the selected language
        available_voices = self._voices_by_lang.get(language, {})
        if not available_voices:
            raise ValueError(f"No voices available for language: {language}")

        # If voice not specified or not available, use the first available voice
        if not voice or voice not in available_voices:
            voice = next(iter(available_voices))

        piper_voice = self._load_voice(available_voices[voice])

        # Piper's synthesize() writes into an open wave handle, so open the
        # temp path with the wave module rather than passing a filename.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
            out_path = tmp_file.name
        with wave.open(out_path, "wb") as wav_file:
            piper_voice.synthesize(text, wav_file)
        return out_path

    def supports_multilingual(self):
        return True

    def get_supported_languages(self):
        """Return the language names for which at least scanning was attempted."""
        if not self._initialized:
            self.initialize()
        return list(self._voices_by_lang.keys())

    def get_available_voices(self, language="English"):
        """Get available voice labels for a specific language."""
        if not self._initialized:
            self.initialize()
        return list(self._voices_by_lang.get(language, {}).keys())
|