Spaces:
Sleeping
Sleeping
Michael Hu
committed on
Commit
·
3ed3b5a
1
Parent(s):
31708ca
refactor tts part
Browse files- utils/tts.py +86 -312
- utils/tts_base.py +152 -0
- utils/tts_dia.py +14 -115
- utils/tts_dummy.py +8 -22
- utils/tts_engines.py +322 -0
- utils/tts_factory.py +118 -0
utils/tts.py
CHANGED
|
@@ -1,85 +1,52 @@
|
|
| 1 |
-
import os
|
| 2 |
import logging
|
| 3 |
-
import time
|
| 4 |
-
import soundfile as sf
|
| 5 |
-
from gradio_client import Client
|
| 6 |
-
|
| 7 |
|
|
|
|
| 8 |
logger = logging.getLogger(__name__)
|
| 9 |
|
| 10 |
-
#
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
DIA_AVAILABLE = False
|
| 14 |
|
| 15 |
-
#
|
| 16 |
-
|
| 17 |
-
from kokoro import KPipeline
|
| 18 |
-
KOKORO_AVAILABLE = True
|
| 19 |
-
logger.info("Kokoro TTS engine is available")
|
| 20 |
-
except AttributeError as e:
|
| 21 |
-
# Specifically catch the EspeakWrapper.set_data_path error
|
| 22 |
-
if "EspeakWrapper" in str(e) and "set_data_path" in str(e):
|
| 23 |
-
logger.warning("Kokoro import failed due to EspeakWrapper.set_data_path issue, falling back to Kokoro FastAPI server")
|
| 24 |
-
else:
|
| 25 |
-
# Re-raise if it's a different error
|
| 26 |
-
logger.error(f"Kokoro import failed with unexpected error: {str(e)}")
|
| 27 |
-
raise
|
| 28 |
-
except ImportError:
|
| 29 |
-
logger.warning("Kokoro TTS engine is not available")
|
| 30 |
|
|
|
|
| 31 |
class TTSEngine:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
def __init__(self, lang_code='z'):
|
| 33 |
-
"""Initialize TTS Engine
|
| 34 |
|
| 35 |
Args:
|
| 36 |
lang_code (str): Language code ('a' for US English, 'b' for British English,
|
| 37 |
'j' for Japanese, 'z' for Mandarin Chinese)
|
| 38 |
-
Note: lang_code is only used for Kokoro, not for Dia
|
| 39 |
"""
|
| 40 |
-
logger.info("Initializing
|
| 41 |
logger.info(f"Available engines - Kokoro: {KOKORO_AVAILABLE}, Dia: {DIA_AVAILABLE}")
|
| 42 |
-
self.engine_type = None
|
| 43 |
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
try:
|
| 47 |
-
self.pipeline = KPipeline(lang_code=lang_code)
|
| 48 |
-
self.engine_type = "kokoro"
|
| 49 |
-
logger.info("TTS engine successfully initialized with Kokoro")
|
| 50 |
-
except Exception as kokoro_err:
|
| 51 |
-
logger.error(f"Failed to initialize Kokoro pipeline: {str(kokoro_err)}")
|
| 52 |
-
logger.error(f"Error type: {type(kokoro_err).__name__}")
|
| 53 |
-
logger.info("Will try to fall back to Dia TTS engine")
|
| 54 |
-
|
| 55 |
-
if KOKORO_SPACE_AVAILABLE:
|
| 56 |
-
logger.info(f"Using Kokoro FastAPI server as primary TTS engine with language code: {lang_code}")
|
| 57 |
-
try:
|
| 58 |
-
self.client = Client("Remsky/Kokoro-TTS-Zero")
|
| 59 |
-
self.engine_type = "kokoro_space"
|
| 60 |
-
logger.info("TTS engine successfully initialized with Kokoro FastAPI server")
|
| 61 |
-
except Exception as kokoro_err:
|
| 62 |
-
logger.error(f"Failed to initialize Kokoro space: {str(kokoro_err)}")
|
| 63 |
-
logger.error(f"Error type: {type(kokoro_err).__name__}")
|
| 64 |
-
logger.info("Will try to fall back to Dia TTS engine")
|
| 65 |
-
|
| 66 |
-
# Try Dia if Kokoro is not available or failed to initialize
|
| 67 |
-
if self.engine_type is None and DIA_AVAILABLE:
|
| 68 |
-
logger.info("Using Dia as fallback TTS engine")
|
| 69 |
-
# For Dia, we don't need to initialize anything here
|
| 70 |
-
# The model will be lazy-loaded when needed
|
| 71 |
-
self.pipeline = None
|
| 72 |
-
self.client = None
|
| 73 |
-
self.engine_type = "dia"
|
| 74 |
-
logger.info("TTS engine initialized with Dia (lazy loading)")
|
| 75 |
|
| 76 |
-
#
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
self.
|
|
|
|
|
|
|
|
|
|
| 82 |
self.engine_type = "dummy"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
|
| 84 |
def generate_speech(self, text: str, voice: str = 'af_heart', speed: float = 1.0) -> str:
|
| 85 |
"""Generate speech from text using available TTS engine
|
|
@@ -87,272 +54,79 @@ class TTSEngine:
|
|
| 87 |
Args:
|
| 88 |
text (str): Input text to synthesize
|
| 89 |
voice (str): Voice ID to use (e.g., 'af_heart', 'af_bella', etc.)
|
| 90 |
-
Note: voice parameter is only used for Kokoro, not for Dia
|
| 91 |
speed (float): Speech speed multiplier (0.5 to 2.0)
|
| 92 |
-
Note: speed parameter is only used for Kokoro, not for Dia
|
| 93 |
|
| 94 |
Returns:
|
| 95 |
str: Path to the generated audio file
|
| 96 |
"""
|
| 97 |
-
logger.info(f"
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
# Create output directory if it doesn't exist
|
| 101 |
-
os.makedirs("temp/outputs", exist_ok=True)
|
| 102 |
-
|
| 103 |
-
# Generate unique output path
|
| 104 |
-
output_path = f"temp/outputs/output_{int(time.time())}.wav"
|
| 105 |
-
|
| 106 |
-
# Use the appropriate TTS engine based on availability
|
| 107 |
-
if self.engine_type == "kokoro":
|
| 108 |
-
# Use Kokoro for TTS generation
|
| 109 |
-
generator = self.pipeline(text, voice=voice, speed=speed)
|
| 110 |
-
for _, _, audio in generator:
|
| 111 |
-
logger.info(f"Saving Kokoro audio to {output_path}")
|
| 112 |
-
sf.write(output_path, audio, 24000)
|
| 113 |
-
break
|
| 114 |
-
elif self.engine_type == "kokoro_space":
|
| 115 |
-
# Use Kokoro FastAPI server for TTS generation
|
| 116 |
-
logger.info("Generating speech using Kokoro FastAPI server")
|
| 117 |
-
logger.info(f"text to generate speech on is: {text}")
|
| 118 |
-
try:
|
| 119 |
-
result = self.client.predict(
|
| 120 |
-
text=text,
|
| 121 |
-
voice_names='af_nova',
|
| 122 |
-
speed=speed,
|
| 123 |
-
api_name="/generate_speech_from_ui"
|
| 124 |
-
)
|
| 125 |
-
logger.info(f"Received audio from Kokoro FastAPI server: {result}")
|
| 126 |
-
except Exception as e:
|
| 127 |
-
logger.error(f"Failed to generate speech from Kokoro FastAPI server: {str(e)}")
|
| 128 |
-
logger.error(f"Error type: {type(e).__name__}")
|
| 129 |
-
logger.info("Falling back to dummy audio generation")
|
| 130 |
-
elif self.engine_type == "dia":
|
| 131 |
-
# Use Dia for TTS generation
|
| 132 |
-
try:
|
| 133 |
-
logger.info("Attempting to use Dia TTS for speech generation")
|
| 134 |
-
# Import here to avoid circular imports
|
| 135 |
-
try:
|
| 136 |
-
logger.info("Importing Dia speech generation module")
|
| 137 |
-
from utils.tts_dia import generate_speech as dia_generate_speech
|
| 138 |
-
logger.info("Successfully imported Dia speech generation function")
|
| 139 |
-
except ImportError as import_err:
|
| 140 |
-
logger.error(f"Failed to import Dia speech generation function: {str(import_err)}")
|
| 141 |
-
logger.error(f"Import path: {import_err.__traceback__.tb_frame.f_globals.get('__name__', 'unknown')}")
|
| 142 |
-
raise
|
| 143 |
-
|
| 144 |
-
# Call Dia's generate_speech function
|
| 145 |
-
logger.info("Calling Dia's generate_speech function")
|
| 146 |
-
output_path = dia_generate_speech(text)
|
| 147 |
-
logger.info(f"Generated audio with Dia: {output_path}")
|
| 148 |
-
except ImportError as import_err:
|
| 149 |
-
logger.error(f"Dia TTS generation failed due to import error: {str(import_err)}")
|
| 150 |
-
logger.error("Falling back to dummy audio generation")
|
| 151 |
-
return self._generate_dummy_audio(output_path)
|
| 152 |
-
except Exception as dia_error:
|
| 153 |
-
logger.error(f"Dia TTS generation failed: {str(dia_error)}", exc_info=True)
|
| 154 |
-
logger.error(f"Error type: {type(dia_error).__name__}")
|
| 155 |
-
logger.error("Falling back to dummy audio generation")
|
| 156 |
-
# Fall back to dummy audio if Dia fails
|
| 157 |
-
return self._generate_dummy_audio(output_path)
|
| 158 |
-
else:
|
| 159 |
-
# Generate dummy audio as fallback
|
| 160 |
-
return self._generate_dummy_audio(output_path)
|
| 161 |
-
|
| 162 |
-
logger.info(f"Audio generation complete: {output_path}")
|
| 163 |
-
return output_path
|
| 164 |
-
|
| 165 |
-
except Exception as e:
|
| 166 |
-
logger.error(f"TTS generation failed: {str(e)}", exc_info=True)
|
| 167 |
-
raise
|
| 168 |
-
|
| 169 |
-
def _generate_dummy_audio(self, output_path):
|
| 170 |
-
"""Generate a dummy audio file with a simple sine wave
|
| 171 |
-
|
| 172 |
-
Args:
|
| 173 |
-
output_path (str): Path to save the dummy audio file
|
| 174 |
-
|
| 175 |
-
Returns:
|
| 176 |
-
str: Path to the generated dummy audio file
|
| 177 |
-
"""
|
| 178 |
-
import numpy as np
|
| 179 |
-
sample_rate = 24000
|
| 180 |
-
duration = 3.0 # seconds
|
| 181 |
-
t = np.linspace(0, duration, int(sample_rate * duration), False)
|
| 182 |
-
tone = np.sin(2 * np.pi * 440 * t) * 0.3
|
| 183 |
-
|
| 184 |
-
logger.info(f"Saving dummy audio to {output_path}")
|
| 185 |
-
sf.write(output_path, tone, sample_rate)
|
| 186 |
-
logger.info(f"Dummy audio generation complete: {output_path}")
|
| 187 |
-
return output_path
|
| 188 |
-
|
| 189 |
def generate_speech_stream(self, text: str, voice: str = 'af_heart', speed: float = 1.0):
|
| 190 |
"""Generate speech from text and yield each segment
|
| 191 |
|
| 192 |
Args:
|
| 193 |
text (str): Input text to synthesize
|
| 194 |
-
voice (str): Voice ID to use
|
| 195 |
-
speed (float): Speech speed multiplier
|
| 196 |
|
| 197 |
Yields:
|
| 198 |
tuple: (sample_rate, audio_data) pairs for each segment
|
| 199 |
"""
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
# and then yield it as a single chunk
|
| 210 |
-
try:
|
| 211 |
-
logger.info("Attempting to use Dia TTS for speech streaming")
|
| 212 |
-
# Import here to avoid circular imports
|
| 213 |
-
try:
|
| 214 |
-
logger.info("Importing required modules for Dia streaming")
|
| 215 |
-
import torch
|
| 216 |
-
logger.info("PyTorch successfully imported for Dia streaming")
|
| 217 |
-
|
| 218 |
-
try:
|
| 219 |
-
from utils.tts_dia import _get_model, DEFAULT_SAMPLE_RATE
|
| 220 |
-
logger.info("Successfully imported Dia model and sample rate")
|
| 221 |
-
except ImportError as import_err:
|
| 222 |
-
logger.error(f"Failed to import Dia model for streaming: {str(import_err)}")
|
| 223 |
-
logger.error(f"Import path: {import_err.__traceback__.tb_frame.f_globals.get('__name__', 'unknown')}")
|
| 224 |
-
raise
|
| 225 |
-
except ImportError as torch_err:
|
| 226 |
-
logger.error(f"PyTorch import failed for Dia streaming: {str(torch_err)}")
|
| 227 |
-
raise
|
| 228 |
-
|
| 229 |
-
# Get the Dia model
|
| 230 |
-
logger.info("Getting Dia model instance")
|
| 231 |
-
try:
|
| 232 |
-
model = _get_model()
|
| 233 |
-
logger.info("Successfully obtained Dia model instance")
|
| 234 |
-
except Exception as model_err:
|
| 235 |
-
logger.error(f"Failed to get Dia model instance: {str(model_err)}")
|
| 236 |
-
logger.error(f"Error type: {type(model_err).__name__}")
|
| 237 |
-
raise
|
| 238 |
-
|
| 239 |
-
# Generate audio
|
| 240 |
-
logger.info("Generating audio with Dia model")
|
| 241 |
-
with torch.inference_mode():
|
| 242 |
-
output_audio_np = model.generate(
|
| 243 |
-
text,
|
| 244 |
-
max_tokens=None,
|
| 245 |
-
cfg_scale=3.0,
|
| 246 |
-
temperature=1.3,
|
| 247 |
-
top_p=0.95,
|
| 248 |
-
cfg_filter_top_k=35,
|
| 249 |
-
use_torch_compile=False,
|
| 250 |
-
verbose=False
|
| 251 |
-
)
|
| 252 |
-
|
| 253 |
-
if output_audio_np is not None:
|
| 254 |
-
logger.info(f"Successfully generated audio with Dia (length: {len(output_audio_np)})")
|
| 255 |
-
yield DEFAULT_SAMPLE_RATE, output_audio_np
|
| 256 |
-
else:
|
| 257 |
-
logger.warning("Dia model returned None for audio output")
|
| 258 |
-
logger.warning("Falling back to dummy audio stream")
|
| 259 |
-
# Fall back to dummy audio if Dia fails
|
| 260 |
-
yield from self._generate_dummy_audio_stream()
|
| 261 |
-
except ImportError as import_err:
|
| 262 |
-
logger.error(f"Dia TTS streaming failed due to import error: {str(import_err)}")
|
| 263 |
-
logger.error("Falling back to dummy audio stream")
|
| 264 |
-
# Fall back to dummy audio if Dia fails
|
| 265 |
-
yield from self._generate_dummy_audio_stream()
|
| 266 |
-
except Exception as dia_error:
|
| 267 |
-
logger.error(f"Dia TTS streaming failed: {str(dia_error)}", exc_info=True)
|
| 268 |
-
logger.error(f"Error type: {type(dia_error).__name__}")
|
| 269 |
-
logger.error("Falling back to dummy audio stream")
|
| 270 |
-
# Fall back to dummy audio if Dia fails
|
| 271 |
-
yield from self._generate_dummy_audio_stream()
|
| 272 |
-
else:
|
| 273 |
-
# Generate dummy audio chunks as fallback
|
| 274 |
-
yield from self._generate_dummy_audio_stream()
|
| 275 |
-
|
| 276 |
-
except Exception as e:
|
| 277 |
-
logger.error(f"TTS streaming failed: {str(e)}", exc_info=True)
|
| 278 |
-
raise
|
| 279 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 280 |
def _generate_dummy_audio_stream(self):
|
| 281 |
-
"""Generate dummy audio chunks
|
| 282 |
|
| 283 |
Yields:
|
| 284 |
tuple: (sample_rate, audio_data) pairs for each dummy segment
|
| 285 |
"""
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
# Create 3 chunks of dummy audio
|
| 291 |
-
for i in range(3):
|
| 292 |
-
t = np.linspace(0, duration, int(sample_rate * duration), False)
|
| 293 |
-
freq = 440 + (i * 220) # Different frequency for each chunk
|
| 294 |
-
tone = np.sin(2 * np.pi * freq * t) * 0.3
|
| 295 |
-
yield sample_rate, tone
|
| 296 |
|
| 297 |
-
#
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
Args:
|
| 302 |
-
lang_code (str): Language code for the pipeline
|
| 303 |
-
|
| 304 |
-
Returns:
|
| 305 |
-
TTSEngine: Initialized TTS engine instance
|
| 306 |
-
"""
|
| 307 |
-
logger.info(f"Requesting TTS engine with language code: {lang_code}")
|
| 308 |
-
try:
|
| 309 |
-
import streamlit as st
|
| 310 |
-
logger.info("Streamlit detected, using cached TTS engine")
|
| 311 |
-
@st.cache_resource
|
| 312 |
-
def _get_engine():
|
| 313 |
-
logger.info("Creating cached TTS engine instance")
|
| 314 |
-
engine = TTSEngine(lang_code)
|
| 315 |
-
logger.info(f"Cached TTS engine created with type: {engine.engine_type}")
|
| 316 |
-
return engine
|
| 317 |
-
|
| 318 |
-
engine = _get_engine()
|
| 319 |
-
logger.info(f"Retrieved TTS engine from cache with type: {engine.engine_type}")
|
| 320 |
-
return engine
|
| 321 |
-
except ImportError:
|
| 322 |
-
logger.info("Streamlit not available, creating direct TTS engine instance")
|
| 323 |
-
engine = TTSEngine(lang_code)
|
| 324 |
-
logger.info(f"Direct TTS engine created with type: {engine.engine_type}")
|
| 325 |
-
return engine
|
| 326 |
|
| 327 |
-
def
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
|
| 350 |
-
|
| 351 |
-
|
| 352 |
-
logger.error(f"Error type: {type(e).__name__}")
|
| 353 |
-
if hasattr(e, '__traceback__'):
|
| 354 |
-
tb = e.__traceback__
|
| 355 |
-
while tb.tb_next:
|
| 356 |
-
tb = tb.tb_next
|
| 357 |
-
logger.error(f"Error occurred in file: {tb.tb_frame.f_code.co_filename}, line {tb.tb_lineno}")
|
| 358 |
-
raise
|
|
|
|
|
|
|
| 1 |
import logging
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
+
# Configure logging
|
| 4 |
logger = logging.getLogger(__name__)
|
| 5 |
|
| 6 |
+
# Import from the new factory pattern implementation
|
| 7 |
+
from utils.tts_factory import get_tts_engine, generate_speech, TTSFactory
|
| 8 |
+
from utils.tts_engines import get_available_engines
|
|
|
|
| 9 |
|
| 10 |
+
# For backward compatibility
|
| 11 |
+
from utils.tts_engines import KOKORO_AVAILABLE, KOKORO_SPACE_AVAILABLE, DIA_AVAILABLE
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
+
# Backward compatibility class
|
| 14 |
class TTSEngine:
|
| 15 |
+
"""Legacy TTSEngine class for backward compatibility
|
| 16 |
+
|
| 17 |
+
This class is maintained for backward compatibility with existing code.
|
| 18 |
+
New code should use the factory pattern implementation directly.
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
def __init__(self, lang_code='z'):
|
| 22 |
+
"""Initialize TTS Engine using the factory pattern
|
| 23 |
|
| 24 |
Args:
|
| 25 |
lang_code (str): Language code ('a' for US English, 'b' for British English,
|
| 26 |
'j' for Japanese, 'z' for Mandarin Chinese)
|
|
|
|
| 27 |
"""
|
| 28 |
+
logger.info("Initializing legacy TTSEngine wrapper")
|
| 29 |
logger.info(f"Available engines - Kokoro: {KOKORO_AVAILABLE}, Dia: {DIA_AVAILABLE}")
|
|
|
|
| 30 |
|
| 31 |
+
# Create the appropriate engine using the factory
|
| 32 |
+
self._engine = TTSFactory.create_engine(lang_code=lang_code)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
|
| 34 |
+
# Set engine_type for backward compatibility
|
| 35 |
+
engine_class = self._engine.__class__.__name__
|
| 36 |
+
if 'Kokoro' in engine_class and 'Space' in engine_class:
|
| 37 |
+
self.engine_type = "kokoro_space"
|
| 38 |
+
elif 'Kokoro' in engine_class:
|
| 39 |
+
self.engine_type = "kokoro"
|
| 40 |
+
elif 'Dia' in engine_class:
|
| 41 |
+
self.engine_type = "dia"
|
| 42 |
+
else:
|
| 43 |
self.engine_type = "dummy"
|
| 44 |
+
|
| 45 |
+
# Set pipeline and client attributes for backward compatibility
|
| 46 |
+
self.pipeline = getattr(self._engine, 'pipeline', None)
|
| 47 |
+
self.client = getattr(self._engine, 'client', None)
|
| 48 |
+
|
| 49 |
+
logger.info(f"Legacy TTSEngine wrapper initialized with engine type: {self.engine_type}")
|
| 50 |
|
| 51 |
def generate_speech(self, text: str, voice: str = 'af_heart', speed: float = 1.0) -> str:
|
| 52 |
"""Generate speech from text using available TTS engine
|
|
|
|
| 54 |
Args:
|
| 55 |
text (str): Input text to synthesize
|
| 56 |
voice (str): Voice ID to use (e.g., 'af_heart', 'af_bella', etc.)
|
|
|
|
| 57 |
speed (float): Speech speed multiplier (0.5 to 2.0)
|
|
|
|
| 58 |
|
| 59 |
Returns:
|
| 60 |
str: Path to the generated audio file
|
| 61 |
"""
|
| 62 |
+
logger.info(f"Legacy TTSEngine wrapper calling generate_speech for text length: {len(text)}")
|
| 63 |
+
return self._engine.generate_speech(text, voice, speed)
|
| 64 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
def generate_speech_stream(self, text: str, voice: str = 'af_heart', speed: float = 1.0):
|
| 66 |
"""Generate speech from text and yield each segment
|
| 67 |
|
| 68 |
Args:
|
| 69 |
text (str): Input text to synthesize
|
| 70 |
+
voice (str): Voice ID to use
|
| 71 |
+
speed (float): Speech speed multiplier
|
| 72 |
|
| 73 |
Yields:
|
| 74 |
tuple: (sample_rate, audio_data) pairs for each segment
|
| 75 |
"""
|
| 76 |
+
logger.info(f"Legacy TTSEngine wrapper calling generate_speech_stream for text length: {len(text)}")
|
| 77 |
+
yield from self._engine.generate_speech_stream(text, voice, speed)
|
| 78 |
+
|
| 79 |
+
# For backward compatibility
|
| 80 |
+
def _generate_dummy_audio(self, output_path):
|
| 81 |
+
"""Generate a dummy audio file with a simple sine wave (backward compatibility)
|
| 82 |
+
|
| 83 |
+
Args:
|
| 84 |
+
output_path (str): Path to save the dummy audio file
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 85 |
|
| 86 |
+
Returns:
|
| 87 |
+
str: Path to the generated dummy audio file
|
| 88 |
+
"""
|
| 89 |
+
from utils.tts_base import DummyTTSEngine
|
| 90 |
+
dummy_engine = DummyTTSEngine()
|
| 91 |
+
return dummy_engine.generate_speech("", "", 1.0)
|
| 92 |
+
|
| 93 |
+
# For backward compatibility
|
| 94 |
def _generate_dummy_audio_stream(self):
|
| 95 |
+
"""Generate dummy audio chunks (backward compatibility)
|
| 96 |
|
| 97 |
Yields:
|
| 98 |
tuple: (sample_rate, audio_data) pairs for each dummy segment
|
| 99 |
"""
|
| 100 |
+
from utils.tts_base import DummyTTSEngine
|
| 101 |
+
dummy_engine = DummyTTSEngine()
|
| 102 |
+
yield from dummy_engine.generate_speech_stream("", "", 1.0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 103 |
|
| 104 |
+
# The new implementations live in tts_factory.py
|
| 105 |
+
# These functions are defined in tts_factory.py and imported at the top of this file
|
| 106 |
+
# They are kept here as comments for reference
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 107 |
|
| 108 |
+
# def get_tts_engine(lang_code='a'):
|
| 109 |
+
# """Get or create TTS engine instance
|
| 110 |
+
#
|
| 111 |
+
# Args:
|
| 112 |
+
# lang_code (str): Language code for the pipeline
|
| 113 |
+
#
|
| 114 |
+
# Returns:
|
| 115 |
+
# TTSEngineBase: Initialized TTS engine instance
|
| 116 |
+
# """
|
| 117 |
+
# # Implementation moved to tts_factory.py
|
| 118 |
+
# pass
|
| 119 |
+
|
| 120 |
+
# def generate_speech(text: str, voice: str = 'af_heart', speed: float = 1.0) -> str:
|
| 121 |
+
# """Public interface for TTS generation
|
| 122 |
+
#
|
| 123 |
+
# Args:
|
| 124 |
+
# text (str): Input text to synthesize
|
| 125 |
+
# voice (str): Voice ID to use
|
| 126 |
+
# speed (float): Speech speed multiplier
|
| 127 |
+
#
|
| 128 |
+
# Returns:
|
| 129 |
+
# str: Path to generated audio file
|
| 130 |
+
# "\"""
|
| 131 |
+
# # Implementation moved to tts_factory.py
|
| 132 |
+
# pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
utils/tts_base.py
ADDED
|
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import time
|
| 3 |
+
import logging
|
| 4 |
+
import soundfile as sf
|
| 5 |
+
import numpy as np
|
| 6 |
+
from abc import ABC, abstractmethod
|
| 7 |
+
from typing import Tuple, Generator, Optional
|
| 8 |
+
|
| 9 |
+
# Configure logging
|
| 10 |
+
logger = logging.getLogger(__name__)
|
| 11 |
+
|
| 12 |
+
class TTSEngineBase(ABC):
|
| 13 |
+
"""Base class for all TTS engines
|
| 14 |
+
|
| 15 |
+
This abstract class defines the interface that all TTS engines must implement.
|
| 16 |
+
It also provides common utility methods for file handling and audio generation.
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
def __init__(self, lang_code: str = 'z'):
|
| 20 |
+
"""Initialize the TTS engine
|
| 21 |
+
|
| 22 |
+
Args:
|
| 23 |
+
lang_code (str): Language code ('a' for US English, 'b' for British English,
|
| 24 |
+
'j' for Japanese, 'z' for Mandarin Chinese)
|
| 25 |
+
Note: Not all engines support all language codes
|
| 26 |
+
"""
|
| 27 |
+
self.lang_code = lang_code
|
| 28 |
+
logger.info(f"Initializing {self.__class__.__name__} with language code: {lang_code}")
|
| 29 |
+
|
| 30 |
+
@abstractmethod
|
| 31 |
+
def generate_speech(self, text: str, voice: str = 'af_heart', speed: float = 1.0) -> str:
|
| 32 |
+
"""Generate speech from text
|
| 33 |
+
|
| 34 |
+
Args:
|
| 35 |
+
text (str): Input text to synthesize
|
| 36 |
+
voice (str): Voice ID to use (e.g., 'af_heart', 'af_bella', etc.)
|
| 37 |
+
Note: Not all engines support all voices
|
| 38 |
+
speed (float): Speech speed multiplier (0.5 to 2.0)
|
| 39 |
+
Note: Not all engines support speed adjustment
|
| 40 |
+
|
| 41 |
+
Returns:
|
| 42 |
+
str: Path to the generated audio file
|
| 43 |
+
"""
|
| 44 |
+
pass
|
| 45 |
+
|
| 46 |
+
def generate_speech_stream(self, text: str, voice: str = 'af_heart', speed: float = 1.0) -> Generator[Tuple[int, np.ndarray], None, None]:
|
| 47 |
+
"""Generate speech from text and yield each segment
|
| 48 |
+
|
| 49 |
+
Args:
|
| 50 |
+
text (str): Input text to synthesize
|
| 51 |
+
voice (str): Voice ID to use
|
| 52 |
+
speed (float): Speech speed multiplier
|
| 53 |
+
|
| 54 |
+
Yields:
|
| 55 |
+
tuple: (sample_rate, audio_data) pairs for each segment
|
| 56 |
+
"""
|
| 57 |
+
# Default implementation: generate full audio and yield as a single chunk
|
| 58 |
+
output_path = self.generate_speech(text, voice, speed)
|
| 59 |
+
audio_data, sample_rate = sf.read(output_path)
|
| 60 |
+
yield sample_rate, audio_data
|
| 61 |
+
|
| 62 |
+
def _create_output_dir(self) -> str:
|
| 63 |
+
"""Create output directory for audio files
|
| 64 |
+
|
| 65 |
+
Returns:
|
| 66 |
+
str: Path to the output directory
|
| 67 |
+
"""
|
| 68 |
+
output_dir = "temp/outputs"
|
| 69 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 70 |
+
return output_dir
|
| 71 |
+
|
| 72 |
+
def _generate_output_path(self, prefix: str = "output") -> str:
|
| 73 |
+
"""Generate a unique output path for audio files
|
| 74 |
+
|
| 75 |
+
Args:
|
| 76 |
+
prefix (str): Prefix for the output filename
|
| 77 |
+
|
| 78 |
+
Returns:
|
| 79 |
+
str: Path to the output file
|
| 80 |
+
"""
|
| 81 |
+
output_dir = self._create_output_dir()
|
| 82 |
+
timestamp = int(time.time())
|
| 83 |
+
return f"{output_dir}/{prefix}_{timestamp}.wav"
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
class DummyTTSEngine(TTSEngineBase):
|
| 87 |
+
"""Dummy TTS engine that generates a simple sine wave
|
| 88 |
+
|
| 89 |
+
This engine is used as a fallback when no other engines are available.
|
| 90 |
+
"""
|
| 91 |
+
|
| 92 |
+
def __init__(self, lang_code: str = 'z'):
|
| 93 |
+
super().__init__(lang_code)
|
| 94 |
+
logger.warning("Using dummy TTS implementation as no other engines are available")
|
| 95 |
+
|
| 96 |
+
def generate_speech(self, text: str, voice: str = 'af_heart', speed: float = 1.0) -> str:
|
| 97 |
+
"""Generate a dummy audio file with a simple sine wave
|
| 98 |
+
|
| 99 |
+
Args:
|
| 100 |
+
text (str): Input text (not used)
|
| 101 |
+
voice (str): Voice ID (not used)
|
| 102 |
+
speed (float): Speed multiplier (not used)
|
| 103 |
+
|
| 104 |
+
Returns:
|
| 105 |
+
str: Path to the generated dummy audio file
|
| 106 |
+
"""
|
| 107 |
+
logger.info(f"Generating dummy speech for text length: {len(text)}")
|
| 108 |
+
|
| 109 |
+
# Generate unique output path
|
| 110 |
+
output_path = self._generate_output_path("dummy")
|
| 111 |
+
|
| 112 |
+
# Generate a simple sine wave
|
| 113 |
+
sample_rate = 24000
|
| 114 |
+
duration = 3.0 # seconds
|
| 115 |
+
t = np.linspace(0, duration, int(sample_rate * duration), False)
|
| 116 |
+
tone = np.sin(2 * np.pi * 440 * t) * 0.3
|
| 117 |
+
|
| 118 |
+
# Save the audio file
|
| 119 |
+
logger.info(f"Saving dummy audio to {output_path}")
|
| 120 |
+
sf.write(output_path, tone, sample_rate)
|
| 121 |
+
logger.info(f"Dummy audio generation complete: {output_path}")
|
| 122 |
+
|
| 123 |
+
return output_path
|
| 124 |
+
|
| 125 |
+
def generate_speech_stream(self, text: str, voice: str = 'af_heart', speed: float = 1.0) -> Generator[Tuple[int, np.ndarray], None, None]:
|
| 126 |
+
"""Generate dummy audio chunks with simple sine waves
|
| 127 |
+
|
| 128 |
+
Args:
|
| 129 |
+
text (str): Input text (not used)
|
| 130 |
+
voice (str): Voice ID (not used)
|
| 131 |
+
speed (float): Speed multiplier (not used)
|
| 132 |
+
|
| 133 |
+
Yields:
|
| 134 |
+
tuple: (sample_rate, audio_data) pairs for each dummy segment
|
| 135 |
+
"""
|
| 136 |
+
logger.info(f"Generating dummy speech stream for text length: {len(text)}")
|
| 137 |
+
|
| 138 |
+
sample_rate = 24000
|
| 139 |
+
duration = 1.0 # seconds per chunk
|
| 140 |
+
|
| 141 |
+
# Create 3 chunks of dummy audio
|
| 142 |
+
for i in range(3):
|
| 143 |
+
t = np.linspace(0, duration, int(sample_rate * duration), False)
|
| 144 |
+
freq = 440 + (i * 220) # Different frequency for each chunk
|
| 145 |
+
tone = np.sin(2 * np.pi * freq * t) * 0.3
|
| 146 |
+
yield sample_rate, tone
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
# Factory functionality moved to tts_factory.py to avoid circular imports
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
# Note: Backward compatibility functions moved to tts_factory.py
|
utils/tts_dia.py
CHANGED
|
@@ -68,6 +68,9 @@ def _get_model() -> Dia:
|
|
| 68 |
def generate_speech(text: str, language: str = "zh") -> str:
|
| 69 |
"""Public interface for TTS generation using Dia model
|
| 70 |
|
|
|
|
|
|
|
|
|
|
| 71 |
Args:
|
| 72 |
text (str): Input text to synthesize
|
| 73 |
language (str): Language code (not used in Dia model, kept for API compatibility)
|
|
@@ -75,122 +78,18 @@ def generate_speech(text: str, language: str = "zh") -> str:
|
|
| 75 |
Returns:
|
| 76 |
str: Path to the generated audio file
|
| 77 |
"""
|
| 78 |
-
logger.info(f"
|
| 79 |
-
logger.info(f"Text content (first 50 chars): {text[:50]}...")
|
| 80 |
-
|
| 81 |
-
# Create output directory if it doesn't exist
|
| 82 |
-
output_dir = "temp/outputs"
|
| 83 |
-
logger.info(f"Ensuring output directory exists: {output_dir}")
|
| 84 |
-
try:
|
| 85 |
-
os.makedirs(output_dir, exist_ok=True)
|
| 86 |
-
logger.info(f"Output directory ready: {output_dir}")
|
| 87 |
-
except PermissionError as perm_err:
|
| 88 |
-
logger.error(f"Permission error creating output directory: {perm_err}")
|
| 89 |
-
# Fall back to dummy TTS
|
| 90 |
-
logger.info("Falling back to dummy TTS due to directory creation error")
|
| 91 |
-
from utils.tts_dummy import generate_speech as dummy_generate_speech
|
| 92 |
-
return dummy_generate_speech(text, language)
|
| 93 |
-
except Exception as dir_err:
|
| 94 |
-
logger.error(f"Error creating output directory: {dir_err}")
|
| 95 |
-
# Fall back to dummy TTS
|
| 96 |
-
logger.info("Falling back to dummy TTS due to directory creation error")
|
| 97 |
-
from utils.tts_dummy import generate_speech as dummy_generate_speech
|
| 98 |
-
return dummy_generate_speech(text, language)
|
| 99 |
-
|
| 100 |
-
# Generate unique output path
|
| 101 |
-
timestamp = int(time.time())
|
| 102 |
-
output_path = f"{output_dir}/output_{timestamp}.wav"
|
| 103 |
-
logger.info(f"Output will be saved to: {output_path}")
|
| 104 |
-
|
| 105 |
-
# Get the model
|
| 106 |
-
logger.info("Retrieving Dia model instance")
|
| 107 |
-
try:
|
| 108 |
-
model = _get_model()
|
| 109 |
-
logger.info("Successfully retrieved Dia model instance")
|
| 110 |
-
except Exception as model_err:
|
| 111 |
-
logger.error(f"Failed to get Dia model: {model_err}")
|
| 112 |
-
logger.error(f"Error type: {type(model_err).__name__}")
|
| 113 |
-
# Fall back to dummy TTS
|
| 114 |
-
logger.info("Falling back to dummy TTS due to model loading error")
|
| 115 |
-
from utils.tts_dummy import generate_speech as dummy_generate_speech
|
| 116 |
-
return dummy_generate_speech(text, language)
|
| 117 |
|
| 118 |
-
#
|
| 119 |
-
|
| 120 |
-
start_time = time.time()
|
| 121 |
|
| 122 |
try:
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
cfg_scale=3.0,
|
| 129 |
-
temperature=1.3,
|
| 130 |
-
top_p=0.95,
|
| 131 |
-
cfg_filter_top_k=35,
|
| 132 |
-
use_torch_compile=False, # Keep False for stability
|
| 133 |
-
verbose=False
|
| 134 |
-
)
|
| 135 |
-
logger.info("Model.generate() completed")
|
| 136 |
-
except RuntimeError as rt_err:
|
| 137 |
-
logger.error(f"Runtime error during generation: {rt_err}")
|
| 138 |
-
if "CUDA out of memory" in str(rt_err):
|
| 139 |
-
logger.error("CUDA out of memory error - consider reducing batch size or model size")
|
| 140 |
# Fall back to dummy TTS
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
return
|
| 144 |
-
except Exception as gen_err:
|
| 145 |
-
logger.error(f"Error during audio generation: {gen_err}")
|
| 146 |
-
logger.error(f"Error type: {type(gen_err).__name__}")
|
| 147 |
-
# Fall back to dummy TTS
|
| 148 |
-
logger.info("Falling back to dummy TTS due to error during generation")
|
| 149 |
-
from utils.tts_dummy import generate_speech as dummy_generate_speech
|
| 150 |
-
return dummy_generate_speech(text, language)
|
| 151 |
-
|
| 152 |
-
end_time = time.time()
|
| 153 |
-
generation_time = end_time - start_time
|
| 154 |
-
logger.info(f"Generation finished in {generation_time:.2f} seconds")
|
| 155 |
-
|
| 156 |
-
# Process the output
|
| 157 |
-
if output_audio_np is not None:
|
| 158 |
-
logger.info(f"Generated audio array shape: {output_audio_np.shape}, dtype: {output_audio_np.dtype}")
|
| 159 |
-
logger.info(f"Audio stats - min: {output_audio_np.min():.4f}, max: {output_audio_np.max():.4f}, mean: {output_audio_np.mean():.4f}")
|
| 160 |
-
|
| 161 |
-
# Apply a slight slowdown for better quality (0.94x speed)
|
| 162 |
-
speed_factor = 0.94
|
| 163 |
-
original_len = len(output_audio_np)
|
| 164 |
-
target_len = int(original_len / speed_factor)
|
| 165 |
-
|
| 166 |
-
logger.info(f"Applying speed adjustment factor: {speed_factor}")
|
| 167 |
-
if target_len != original_len and target_len > 0:
|
| 168 |
-
try:
|
| 169 |
-
x_original = np.arange(original_len)
|
| 170 |
-
x_resampled = np.linspace(0, original_len - 1, target_len)
|
| 171 |
-
output_audio_np = np.interp(x_resampled, x_original, output_audio_np)
|
| 172 |
-
logger.info(f"Resampled audio from {original_len} to {target_len} samples for {speed_factor:.2f}x speed")
|
| 173 |
-
except Exception as resample_err:
|
| 174 |
-
logger.error(f"Error during audio resampling: {resample_err}")
|
| 175 |
-
logger.warning("Using original audio without resampling")
|
| 176 |
-
|
| 177 |
-
# Save the audio file
|
| 178 |
-
logger.info(f"Saving audio to file: {output_path}")
|
| 179 |
-
try:
|
| 180 |
-
sf.write(output_path, output_audio_np, DEFAULT_SAMPLE_RATE)
|
| 181 |
-
logger.info(f"Audio successfully saved to {output_path}")
|
| 182 |
-
except Exception as save_err:
|
| 183 |
-
logger.error(f"Error saving audio file: {save_err}")
|
| 184 |
-
logger.error(f"Error type: {type(save_err).__name__}")
|
| 185 |
-
# Fall back to dummy TTS
|
| 186 |
-
logger.info("Falling back to dummy TTS due to error saving audio file")
|
| 187 |
-
from utils.tts_dummy import generate_speech as dummy_generate_speech
|
| 188 |
-
return dummy_generate_speech(text, language)
|
| 189 |
-
|
| 190 |
-
return output_path
|
| 191 |
-
else:
|
| 192 |
-
logger.warning("Generation produced no output (None returned from model)")
|
| 193 |
-
logger.warning("This may indicate a model configuration issue or empty input text")
|
| 194 |
-
dummy_path = f"{output_dir}/dummy_{timestamp}.wav"
|
| 195 |
-
logger.warning(f"Returning dummy audio path: {dummy_path}")
|
| 196 |
-
return dummy_path
|
|
|
|
| 68 |
def generate_speech(text: str, language: str = "zh") -> str:
    """Synthesize *text* through the Dia engine and return the audio path

    Legacy entry point kept for backward compatibility; new code should
    use the factory pattern implementation directly.

    NOTE(review): this delegates to ``DiaTTSEngine.generate_speech``. That
    method must not call back into this function, otherwise the two would
    recurse until ``RecursionError``.

    Args:
        text (str): Input text to synthesize
        language (str): Language code, passed to ``DiaTTSEngine`` as its
            constructor argument

    Returns:
        str: Path to the generated audio file
    """
    logger.info(f"Legacy Dia generate_speech called with text length: {len(text)}")

    # Imported inside the function, exactly where the original imports it
    from utils.tts_engines import DiaTTSEngine

    try:
        # Build a throwaway Dia engine and let it do the work
        return DiaTTSEngine(language).generate_speech(text)
    except Exception as e:
        logger.error(f"Error in legacy Dia generate_speech: {str(e)}", exc_info=True)
        # Any engine failure ends in the dummy engine
        from utils.tts_base import DummyTTSEngine
        return DummyTTSEngine().generate_speech(text)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
utils/tts_dummy.py
CHANGED
|
@@ -1,25 +1,11 @@
|
|
| 1 |
def generate_speech(text: str, language: str = "zh") -> str:
|
| 2 |
-
"""Public interface for TTS generation
|
| 3 |
-
import os
|
| 4 |
-
import numpy as np
|
| 5 |
-
import soundfile as sf
|
| 6 |
-
import time
|
| 7 |
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
|
|
|
| 11 |
|
| 12 |
-
#
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
# Generate a simple sine wave as dummy audio
|
| 17 |
-
sample_rate = 24000
|
| 18 |
-
duration = 2.0 # seconds
|
| 19 |
-
t = np.linspace(0, duration, int(sample_rate * duration), False)
|
| 20 |
-
tone = np.sin(2 * np.pi * 440 * t) * 0.3
|
| 21 |
-
|
| 22 |
-
# Save the audio file
|
| 23 |
-
sf.write(output_path, tone, sample_rate)
|
| 24 |
-
|
| 25 |
-
return output_path
|
|
|
|
| 1 |
def generate_speech(text: str, language: str = "zh") -> str:
    """Produce placeholder audio for *text* and return its file path

    Legacy function maintained for backward compatibility; new code should
    use the factory pattern implementation directly.

    Args:
        text (str): Input text handed to the dummy engine
        language (str): Accepted for API compatibility; not read by this
            function

    Returns:
        str: Whatever ``DummyTTSEngine.generate_speech`` returns
    """
    from utils.tts_base import DummyTTSEngine

    # One-shot dummy engine, called with fixed voice and speed arguments
    return DummyTTSEngine().generate_speech(text, "af_heart", 1.0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
utils/tts_engines.py
ADDED
|
@@ -0,0 +1,322 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import time
|
| 3 |
+
import os
|
| 4 |
+
import numpy as np
|
| 5 |
+
import soundfile as sf
|
| 6 |
+
from typing import Dict, List, Optional, Tuple, Generator, Any
|
| 7 |
+
|
| 8 |
+
from utils.tts_base import TTSEngineBase, DummyTTSEngine
|
| 9 |
+
|
| 10 |
+
# Module-level logger used by everything below
logger = logging.getLogger(__name__)

# Availability flags; the two False ones are flipped by the import probes below.
# KOKORO_SPACE_AVAILABLE is a constant here: nothing in this section probes it.
KOKORO_AVAILABLE = False
KOKORO_SPACE_AVAILABLE = True
DIA_AVAILABLE = False

# Probe for the local Kokoro library
try:
    from kokoro import KPipeline
except ImportError:
    logger.warning("Kokoro TTS engine is not available")
except AttributeError as e:
    # Only the known EspeakWrapper.set_data_path failure is tolerated;
    # any other AttributeError is logged and propagated.
    details = str(e)
    if "EspeakWrapper" not in details or "set_data_path" not in details:
        logger.error(f"Kokoro import failed with unexpected error: {str(e)}")
        raise
    logger.warning("Kokoro import failed due to EspeakWrapper.set_data_path issue, falling back to Kokoro FastAPI server")
else:
    KOKORO_AVAILABLE = True
    logger.info("Kokoro TTS engine is available")

# Probe for Dia and its torch dependency
try:
    import torch
    from dia.model import Dia
except ImportError:
    logger.warning("Dia TTS engine is not available")
else:
    DIA_AVAILABLE = True
    logger.info("Dia TTS engine is available")
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class KokoroTTSEngine(TTSEngineBase):
    """Kokoro TTS engine implementation

    This engine uses the Kokoro library (``KPipeline``) for TTS generation.
    """

    # Sample rate (Hz) used when writing files and reported to stream consumers
    SAMPLE_RATE = 24000

    def __init__(self, lang_code: str = 'z'):
        """Create the Kokoro pipeline for the given language

        Args:
            lang_code (str): Language code passed to ``KPipeline``

        Raises:
            Exception: Whatever ``KPipeline`` raises; it is logged and re-raised
        """
        super().__init__(lang_code)
        try:
            self.pipeline = KPipeline(lang_code=lang_code)
            logger.info("Kokoro TTS engine successfully initialized")
        except Exception as e:
            logger.error(f"Failed to initialize Kokoro pipeline: {str(e)}")
            logger.error(f"Error type: {type(e).__name__}")
            raise

    def generate_speech(self, text: str, voice: str = 'af_heart', speed: float = 1.0) -> str:
        """Generate speech using Kokoro TTS engine

        Every segment the pipeline yields is joined into one file, so text
        that the pipeline splits into several segments is no longer cut off
        after the first one.

        Args:
            text (str): Input text to synthesize
            voice (str): Voice ID to use (e.g., 'af_heart', 'af_bella', etc.)
            speed (float): Speech speed multiplier (0.5 to 2.0)

        Returns:
            str: Path to the generated audio file (dummy audio when the
                pipeline yields no audio at all)
        """
        logger.info(f"Generating speech with Kokoro for text length: {len(text)}")

        # Generate unique output path
        output_path = self._generate_output_path()

        # Collect every segment; np.asarray normalizes whatever array-like the
        # pipeline yields (assumes 1-D sample arrays - TODO confirm).
        segments = [
            np.asarray(audio)
            for _, _, audio in self.pipeline(text, voice=voice, speed=speed)
            if audio is not None
        ]

        if not segments:
            # Nothing to write: avoid returning a path to a file that was never
            # created, and fall back the same way the sibling engines do.
            logger.warning("Kokoro produced no audio, falling back to dummy audio")
            return DummyTTSEngine().generate_speech(text, voice, speed)

        logger.info(f"Saving Kokoro audio to {output_path}")
        sf.write(output_path, np.concatenate(segments), self.SAMPLE_RATE)

        logger.info(f"Kokoro audio generation complete: {output_path}")
        return output_path

    def generate_speech_stream(self, text: str, voice: str = 'af_heart', speed: float = 1.0) -> Generator[Tuple[int, np.ndarray], None, None]:
        """Generate speech stream using Kokoro TTS engine

        Args:
            text (str): Input text to synthesize
            voice (str): Voice ID to use
            speed (float): Speech speed multiplier

        Yields:
            tuple: (sample_rate, audio_data) pairs for each segment
        """
        logger.info(f"Generating speech stream with Kokoro for text length: {len(text)}")

        # Hand each segment to the caller as soon as the pipeline produces it
        for _, _, audio in self.pipeline(text, voice=voice, speed=speed):
            yield self.SAMPLE_RATE, audio
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
class KokoroSpaceTTSEngine(TTSEngineBase):
    """Kokoro Space TTS engine implementation

    This engine uses the Kokoro FastAPI server (a hosted Space reached
    through ``gradio_client``) for TTS generation.
    """

    def __init__(self, lang_code: str = 'z'):
        """Connect to the hosted Kokoro Space

        Args:
            lang_code (str): Language code forwarded to the base class

        Raises:
            Exception: Any import or connection failure; logged and re-raised
        """
        super().__init__(lang_code)
        try:
            from gradio_client import Client
            self.client = Client("Remsky/Kokoro-TTS-Zero")
            logger.info("Kokoro Space TTS engine successfully initialized")
        except Exception as e:
            logger.error(f"Failed to initialize Kokoro Space client: {str(e)}")
            logger.error(f"Error type: {type(e).__name__}")
            raise

    def generate_speech(self, text: str, voice: str = 'af_heart', speed: float = 1.0) -> str:
        """Generate speech using Kokoro Space TTS engine

        Args:
            text (str): Input text to synthesize
            voice (str): Voice ID to use (e.g., 'af_heart', 'af_bella', etc.)
            speed (float): Speech speed multiplier (0.5 to 2.0)

        Returns:
            str: Path to the generated audio file; dummy audio when the remote
                call fails or returns something that is not an existing file
        """
        logger.info(f"Generating speech with Kokoro Space for text length: {len(text)}")
        logger.info(f"Text to generate speech on is: {text[:50]}..." if len(text) > 50 else f"Text to generate speech on is: {text}")

        # Kept from the original although the path is not used yet (see TODO
        # below); the helper is not visible here and may have side effects.
        output_path = self._generate_output_path()

        try:
            # Use af_nova as the default voice for Kokoro Space
            voice_to_use = 'af_nova' if voice == 'af_heart' else voice

            result = self.client.predict(
                text=text,
                voice_names=voice_to_use,
                speed=speed,
                api_name="/generate_speech_from_ui"
            )
            logger.info(f"Received audio from Kokoro FastAPI server: {result}")

            # An endpoint with several outputs returns them as a tuple/list;
            # assume the audio file path comes first (TODO confirm against the
            # Space's API). A bare string is handled as before.
            audio_path = result[0] if isinstance(result, (tuple, list)) and result else result

            # TODO: Process the result and save to output_path
            # For now, return the downloaded file path directly
            if isinstance(audio_path, str) and os.path.exists(audio_path):
                return audio_path

            logger.warning("Unexpected result from Kokoro Space, falling back to dummy audio")
            return DummyTTSEngine().generate_speech(text, voice, speed)

        except Exception as e:
            logger.error(f"Failed to generate speech from Kokoro FastAPI server: {str(e)}")
            logger.error(f"Error type: {type(e).__name__}")
            logger.info("Falling back to dummy audio generation")
            return DummyTTSEngine().generate_speech(text, voice, speed)
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
class DiaTTSEngine(TTSEngineBase):
    """Dia TTS engine implementation

    This engine uses the Dia model for TTS generation. The model is obtained
    from ``utils.tts_dia._get_model`` on first use.
    """

    def __init__(self, lang_code: str = 'z'):
        """Record the language code; the model itself is loaded lazily

        Args:
            lang_code (str): Language code forwarded to the base class
        """
        super().__init__(lang_code)
        # Dia doesn't need initialization here, it will be lazy-loaded when needed
        logger.info("Dia TTS engine initialized (lazy loading)")

    def _synthesize(self, text: str):
        """Run the Dia model once and return ``(sample_rate, audio)``

        ``audio`` is whatever ``model.generate`` returns and may be ``None``.
        Import and generation errors propagate to the caller.
        """
        # Imported here so the module still loads when Dia/torch are missing
        import torch
        from utils.tts_dia import _get_model, DEFAULT_SAMPLE_RATE

        model = _get_model()

        with torch.inference_mode():
            audio = model.generate(
                text,
                max_tokens=None,
                cfg_scale=3.0,
                temperature=1.3,
                top_p=0.95,
                cfg_filter_top_k=35,
                use_torch_compile=False,
                verbose=False
            )
        return DEFAULT_SAMPLE_RATE, audio

    def generate_speech(self, text: str, voice: str = 'af_heart', speed: float = 1.0) -> str:
        """Generate speech using Dia TTS engine

        The model is invoked directly rather than through
        ``utils.tts_dia.generate_speech``: that legacy wrapper builds a
        ``DiaTTSEngine`` and calls this method, so delegating to it would
        recurse without end.

        Args:
            text (str): Input text to synthesize
            voice (str): Voice ID (not used in Dia)
            speed (float): Speech speed multiplier (not used in Dia)

        Returns:
            str: Path to the generated audio file (dummy audio on any failure)
        """
        logger.info(f"Generating speech with Dia for text length: {len(text)}")

        try:
            sample_rate, audio = self._synthesize(text)

            if audio is None:
                logger.warning("Dia model returned None for audio output")
                logger.warning("Falling back to dummy audio generation")
                return DummyTTSEngine().generate_speech(text, voice, speed)

            output_path = self._generate_output_path()
            sf.write(output_path, audio, sample_rate)
            logger.info(f"Generated audio with Dia: {output_path}")
            return output_path

        except ImportError as import_err:
            logger.error(f"Dia TTS generation failed due to import error: {str(import_err)}")
            logger.error("Falling back to dummy audio generation")
            return DummyTTSEngine().generate_speech(text, voice, speed)

        except Exception as dia_error:
            logger.error(f"Dia TTS generation failed: {str(dia_error)}", exc_info=True)
            logger.error(f"Error type: {type(dia_error).__name__}")
            logger.error("Falling back to dummy audio generation")
            return DummyTTSEngine().generate_speech(text, voice, speed)

    def generate_speech_stream(self, text: str, voice: str = 'af_heart', speed: float = 1.0) -> Generator[Tuple[int, np.ndarray], None, None]:
        """Generate speech stream using Dia TTS engine

        Args:
            text (str): Input text to synthesize
            voice (str): Voice ID (not used in Dia)
            speed (float): Speech speed multiplier (not used in Dia)

        Yields:
            tuple: (sample_rate, audio_data) pairs for each segment; Dia
                produces a single segment, or the dummy stream on failure
        """
        logger.info(f"Generating speech stream with Dia for text length: {len(text)}")

        sample_rate = None
        audio = None
        try:
            sample_rate, audio = self._synthesize(text)
        except ImportError as import_err:
            logger.error(f"Dia TTS streaming failed due to import error: {str(import_err)}")
            logger.error("Falling back to dummy audio stream")
        except Exception as dia_error:
            logger.error(f"Dia TTS streaming failed: {str(dia_error)}", exc_info=True)
            logger.error(f"Error type: {type(dia_error).__name__}")
            logger.error("Falling back to dummy audio stream")
        else:
            if audio is None:
                logger.warning("Dia model returned None for audio output")
                logger.warning("Falling back to dummy audio stream")

        # Yield outside the try block so errors raised by the consumer are not
        # mistaken for Dia failures.
        if audio is not None:
            logger.info(f"Successfully generated audio with Dia (length: {len(audio)})")
            yield sample_rate, audio
        else:
            yield from DummyTTSEngine().generate_speech_stream(text, voice, speed)
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
def get_available_engines() -> List[str]:
    """List the names of the TTS engines that can currently be created

    The order is fixed: 'kokoro', 'kokoro_space', 'dia', then 'dummy'.

    Returns:
        List[str]: Engine names whose availability flag is truthy, always
            ending with 'dummy'
    """
    flags = (
        ('kokoro', KOKORO_AVAILABLE),
        ('kokoro_space', KOKORO_SPACE_AVAILABLE),
        ('dia', DIA_AVAILABLE),
        # Dummy is always available
        ('dummy', True),
    )
    return [name for name, usable in flags if usable]
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
def create_engine(engine_type: str, lang_code: str = 'z') -> TTSEngineBase:
    """Instantiate one specific TTS engine by name

    Args:
        engine_type (str): One of 'kokoro', 'kokoro_space', 'dia', 'dummy'
        lang_code (str): Language code handed to the engine constructor

    Returns:
        TTSEngineBase: An instance of the requested TTS engine

    Raises:
        ValueError: If the engine type is unknown, or known but flagged as
            unavailable
    """
    if engine_type == 'kokoro':
        if KOKORO_AVAILABLE:
            return KokoroTTSEngine(lang_code)
        raise ValueError("Kokoro TTS engine is not available")

    if engine_type == 'kokoro_space':
        if KOKORO_SPACE_AVAILABLE:
            return KokoroSpaceTTSEngine(lang_code)
        raise ValueError("Kokoro Space TTS engine is not available")

    if engine_type == 'dia':
        if DIA_AVAILABLE:
            return DiaTTSEngine(lang_code)
        raise ValueError("Dia TTS engine is not available")

    if engine_type == 'dummy':
        # No availability flag is consulted for the dummy engine
        return DummyTTSEngine(lang_code)

    raise ValueError(f"Unsupported TTS engine type: {engine_type}")
|
utils/tts_factory.py
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from typing import Optional, List
|
| 3 |
+
|
| 4 |
+
# Configure logging
|
| 5 |
+
logger = logging.getLogger(__name__)
|
| 6 |
+
|
| 7 |
+
# Import the base class
|
| 8 |
+
from utils.tts_base import TTSEngineBase, DummyTTSEngine
|
| 9 |
+
|
| 10 |
+
class TTSFactory:
    """Factory class for creating TTS engines

    This class is responsible for creating the appropriate TTS engine based on
    availability and configuration.
    """

    @staticmethod
    def create_engine(engine_type: Optional[str] = None, lang_code: str = 'z') -> TTSEngineBase:
        """Create a TTS engine instance

        An engine that is listed as available but fails while being
        constructed is logged and skipped, so the next candidate - and
        finally the dummy engine - is tried instead of the error escaping.

        Args:
            engine_type (str, optional): Type of engine to create ('kokoro', 'kokoro_space', 'dia', 'dummy')
                                        If None, the best available engine will be used
            lang_code (str): Language code for the engine

        Returns:
            TTSEngineBase: An instance of a TTS engine
        """
        from utils.tts_engines import get_available_engines, create_engine

        # Get available engines
        available_engines = get_available_engines()
        logger.info(f"Available TTS engines: {available_engines}")

        # If engine_type is specified, try to create that specific engine
        if engine_type is not None:
            if engine_type in available_engines:
                logger.info(f"Creating requested engine: {engine_type}")
                try:
                    return create_engine(engine_type, lang_code)
                except Exception as e:
                    logger.error(f"Failed to create requested engine '{engine_type}': {e}", exc_info=True)
            else:
                logger.warning(f"Requested engine '{engine_type}' is not available")

        # Try to create the best available engine
        # Priority: kokoro > kokoro_space > dia > dummy
        for engine in ('kokoro', 'kokoro_space', 'dia'):
            if engine not in available_engines or engine == engine_type:
                # Not available, or already attempted above as the explicit request
                continue
            logger.info(f"Creating best available engine: {engine}")
            try:
                return create_engine(engine, lang_code)
            except Exception as e:
                logger.error(f"Failed to create engine '{engine}': {e}", exc_info=True)

        # Fall back to dummy engine
        logger.warning("No TTS engines available, falling back to dummy engine")
        return DummyTTSEngine(lang_code)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
# Backward compatibility function
|
| 56 |
+
def get_tts_engine(lang_code: str = 'a') -> TTSEngineBase:
    """Get or create TTS engine instance (backward compatibility function)

    When Streamlit is importable, the engine is cached with
    ``st.cache_resource``. The language code is passed to the cached
    function as an argument so each language gets its own cache entry;
    reading it from the enclosing scope would make every call share the
    engine built for the first language requested.

    Args:
        lang_code (str): Language code for the pipeline

    Returns:
        TTSEngineBase: Initialized TTS engine instance
    """
    logger.info(f"Requesting TTS engine with language code: {lang_code}")

    # Only the Streamlit import is guarded, so an ImportError raised while an
    # engine is being built is not mistaken for "Streamlit not available".
    try:
        import streamlit as st
    except ImportError:
        logger.info("Streamlit not available, creating direct TTS engine instance")
        engine = TTSFactory.create_engine(lang_code=lang_code)
        logger.info(f"Direct TTS engine created with type: {engine.__class__.__name__}")
        return engine

    logger.info("Streamlit detected, using cached TTS engine")

    @st.cache_resource
    def _get_engine(code: str) -> TTSEngineBase:
        logger.info("Creating cached TTS engine instance")
        cached = TTSFactory.create_engine(lang_code=code)
        logger.info(f"Cached TTS engine created with type: {cached.__class__.__name__}")
        return cached

    engine = _get_engine(lang_code)
    logger.info(f"Retrieved TTS engine from cache with type: {engine.__class__.__name__}")
    return engine
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
# Backward compatibility function
|
| 87 |
+
def generate_speech(text: str, voice: str = 'af_heart', speed: float = 1.0) -> str:
    """Synthesize *text* with the default engine (backward compatibility function)

    Args:
        text (str): Input text to synthesize
        voice (str): Voice ID to use
        speed (float): Speech speed multiplier

    Returns:
        str: Path to generated audio file

    Raises:
        Exception: Any error from engine lookup or synthesis, re-raised after
            being logged together with the file and line of the innermost frame
    """
    logger.info(f"Public generate_speech called with text length: {len(text)}, voice: {voice}, speed: {speed}")
    try:
        logger.info("Getting TTS engine instance")
        tts = get_tts_engine()
        logger.info(f"Using TTS engine type: {tts.__class__.__name__}")

        logger.info("Calling engine.generate_speech")
        audio_file = tts.generate_speech(text, voice, speed)
        logger.info(f"Speech generation complete, output path: {audio_file}")
        return audio_file
    except Exception as e:
        # logger.exception logs at ERROR level with the traceback attached
        logger.exception(f"Error in public generate_speech function: {str(e)}")
        logger.error(f"Error type: {type(e).__name__}")
        if hasattr(e, '__traceback__'):
            # Walk to the deepest frame, i.e. where the error was raised
            innermost = e.__traceback__
            while innermost.tb_next:
                innermost = innermost.tb_next
            logger.error(f"Error occurred in file: {innermost.tb_frame.f_code.co_filename}, line {innermost.tb_lineno}")
        raise
|