Spaces:
Sleeping
Sleeping
| import os | |
| import logging | |
| import time | |
| import soundfile as sf | |
| from gradio_client import Client | |
logger = logging.getLogger(__name__)

# Flags recording which TTS backends this process may use.
# KOKORO_AVAILABLE: the local `kokoro` package imported cleanly.
# KOKORO_SPACE_AVAILABLE: the local import hit the known EspeakWrapper problem,
#     so the hosted Kokoro server should be used instead.
# DIA_AVAILABLE: nothing in this part of the module ever enables it.
KOKORO_AVAILABLE = False
KOKORO_SPACE_AVAILABLE = False
DIA_AVAILABLE = False

# Try to import Kokoro first
try:
    from kokoro import KPipeline
    KOKORO_AVAILABLE = True
    logger.info("Kokoro TTS engine is available")
except AttributeError as e:
    # Specifically catch the EspeakWrapper.set_data_path error
    message = str(e)
    if "EspeakWrapper" in message and "set_data_path" in message:
        logger.warning("Kokoro import failed due to EspeakWrapper.set_data_path issue, falling back to Kokoro FastAPI server")
        # The remote connection itself is attempted (and guarded) later, when
        # an engine is constructed, so the flag only records the intent here.
        KOKORO_SPACE_AVAILABLE = True
    else:
        # Re-raise if it's a different error
        logger.error(f"Kokoro import failed with unexpected error: {str(e)}")
        raise
except ImportError:
    logger.warning("Kokoro TTS engine is not available")
class TTSEngine:
    """Text-to-speech engine that wraps whichever backend is usable.

    Backends are tried in this order: the local Kokoro pipeline, the hosted
    Kokoro server (through the gradio client), Dia, and finally a dummy
    implementation that emits sine tones so callers always get audio back.

    Attributes set by ``__init__``:
        engine_type: One of "kokoro", "kokoro_space", "dia" or "dummy".
        pipeline: The Kokoro pipeline, or None unless engine_type is "kokoro".
        client: The gradio client, or None unless engine_type is "kokoro_space".
    """

    # Sample rate (Hz) used for Kokoro output and for the dummy sine tones.
    _SAMPLE_RATE = 24000

    def __init__(self, lang_code='z'):
        """Initialize TTS Engine with Kokoro or Dia as fallback

        Args:
            lang_code (str): Language code ('a' for US English, 'b' for British English,
                'j' for Japanese, 'z' for Mandarin Chinese)
                Note: lang_code is only used for Kokoro, not for Dia
        """
        logger.info("Initializing TTS Engine")
        logger.info(f"Available engines - Kokoro: {KOKORO_AVAILABLE}, Dia: {DIA_AVAILABLE}")

        # Define every backend handle up front so that no later code path can
        # hit an AttributeError on a handle its engine never created.
        self.pipeline = None
        self.client = None
        self.engine_type = None

        if KOKORO_AVAILABLE:
            logger.info(f"Using Kokoro as primary TTS engine with language code: {lang_code}")
            try:
                self.pipeline = KPipeline(lang_code=lang_code)
                self.engine_type = "kokoro"
                logger.info("TTS engine successfully initialized with Kokoro")
            except Exception as kokoro_err:
                logger.error(f"Failed to initialize Kokoro pipeline: {str(kokoro_err)}")
                logger.error(f"Error type: {type(kokoro_err).__name__}")
                logger.info("Will try to fall back to Dia TTS engine")

        if self.engine_type is None and KOKORO_SPACE_AVAILABLE:
            logger.info(f"Using Kokoro FastAPI server as primary TTS engine with language code: {lang_code}")
            try:
                self.client = Client("Remsky/Kokoro-TTS-Zero")
                self.engine_type = "kokoro_space"
                logger.info("TTS engine successfully initialized with Kokoro FastAPI server")
            except Exception as kokoro_err:
                logger.error(f"Failed to initialize Kokoro space: {str(kokoro_err)}")
                logger.error(f"Error type: {type(kokoro_err).__name__}")
                logger.info("Will try to fall back to Dia TTS engine")

        # Try Dia if Kokoro is not available or failed to initialize
        if self.engine_type is None and DIA_AVAILABLE:
            logger.info("Using Dia as fallback TTS engine")
            # For Dia, we don't need to initialize anything here
            # The model will be lazy-loaded when needed
            self.engine_type = "dia"
            logger.info("TTS engine initialized with Dia (lazy loading)")

        # Use dummy if no TTS engines are available
        if self.engine_type is None:
            logger.warning("Using dummy TTS implementation as no TTS engines are available")
            logger.warning("Check logs above for specific errors that prevented Kokoro or Dia initialization")
            self.engine_type = "dummy"

    @staticmethod
    def _extract_audio_path(result):
        """Return the audio file path carried by a remote prediction result.

        Accepts either a bare path or a tuple/list whose first element is the
        path; anything else yields None so the caller can fall back.
        NOTE(review): assumes the endpoint puts the audio file first when it
        returns several outputs - confirm against the Space's API page.

        Args:
            result: Value returned by the gradio client's ``predict`` call.
        Returns:
            str | None: The path as a string, or None when none was found.
        """
        candidate = result
        if isinstance(result, (tuple, list)):
            candidate = result[0] if result else None
        if isinstance(candidate, (str, os.PathLike)) and candidate:
            return os.fspath(candidate)
        return None

    def generate_speech(self, text: str, voice: str = 'af_heart', speed: float = 1.0) -> str:
        """Generate speech from text using available TTS engine

        Args:
            text (str): Input text to synthesize
            voice (str): Voice ID to use (e.g., 'af_heart', 'af_bella', etc.)
                Note: voice parameter is only used for Kokoro, not for Dia
            speed (float): Speech speed multiplier (0.5 to 2.0)
                Note: speed parameter is only used for Kokoro, not for Dia
        Returns:
            str: Path to the generated audio file. The remote Kokoro and Dia
                backends return the path they produced themselves; the other
                paths write a new file under ``temp/outputs``.
        Raises:
            Exception: Any unexpected error is logged and re-raised.
        """
        logger.info(f"Generating speech for text length: {len(text)}")
        try:
            # Create output directory if it doesn't exist
            os.makedirs("temp/outputs", exist_ok=True)
            # Nanosecond resolution keeps two calls within the same second
            # from overwriting each other's output file.
            output_path = f"temp/outputs/output_{time.time_ns()}.wav"

            # Use the appropriate TTS engine based on availability
            if self.engine_type == "kokoro":
                import numpy as np

                # The pipeline yields one audio segment per text chunk; join
                # them all so that longer inputs are not cut off after the
                # first segment.
                segments = [
                    np.asarray(audio)
                    for _, _, audio in self.pipeline(text, voice=voice, speed=speed)
                    if audio is not None
                ]
                if not segments:
                    logger.warning("Kokoro produced no audio segments")
                    logger.info("Falling back to dummy audio generation")
                    return self._generate_dummy_audio(output_path)
                logger.info(f"Saving Kokoro audio to {output_path}")
                sf.write(output_path, np.concatenate(segments), self._SAMPLE_RATE)
            elif self.engine_type == "kokoro_space":
                # Use Kokoro FastAPI server for TTS generation
                logger.info("Generating speech using Kokoro FastAPI server")
                logger.info(f"text to generate speech on is: {text}")
                try:
                    # NOTE(review): the remote voice is fixed to 'af_nova' and
                    # the `voice` argument is not forwarded - confirm intent.
                    result = self.client.predict(
                        text=text,
                        voice_names='af_nova',
                        speed=speed,
                        api_name="/generate_speech_from_ui"
                    )
                    logger.info(f"Received audio from Kokoro FastAPI server: {result}")
                except Exception as e:
                    logger.error(f"Failed to generate speech from Kokoro FastAPI server: {str(e)}")
                    logger.error(f"Error type: {type(e).__name__}")
                    logger.info("Falling back to dummy audio generation")
                    return self._generate_dummy_audio(output_path)
                remote_path = self._extract_audio_path(result)
                if remote_path is None:
                    logger.warning("Kokoro FastAPI server returned no audio file path")
                    logger.info("Falling back to dummy audio generation")
                    return self._generate_dummy_audio(output_path)
                # Hand back the file the client produced, as the Dia branch does.
                output_path = remote_path
            elif self.engine_type == "dia":
                # Use Dia for TTS generation
                try:
                    logger.info("Attempting to use Dia TTS for speech generation")
                    # Import here to avoid circular imports
                    try:
                        logger.info("Importing Dia speech generation module")
                        from utils.tts_dia import generate_speech as dia_generate_speech
                        logger.info("Successfully imported Dia speech generation function")
                    except ImportError as import_err:
                        logger.error(f"Failed to import Dia speech generation function: {str(import_err)}")
                        logger.error(f"Import path: {import_err.__traceback__.tb_frame.f_globals.get('__name__', 'unknown')}")
                        raise
                    # Call Dia's generate_speech function
                    logger.info("Calling Dia's generate_speech function")
                    output_path = dia_generate_speech(text)
                    logger.info(f"Generated audio with Dia: {output_path}")
                except ImportError as import_err:
                    logger.error(f"Dia TTS generation failed due to import error: {str(import_err)}")
                    logger.error("Falling back to dummy audio generation")
                    return self._generate_dummy_audio(output_path)
                except Exception as dia_error:
                    logger.error(f"Dia TTS generation failed: {str(dia_error)}", exc_info=True)
                    logger.error(f"Error type: {type(dia_error).__name__}")
                    logger.error("Falling back to dummy audio generation")
                    # Fall back to dummy audio if Dia fails
                    return self._generate_dummy_audio(output_path)
            else:
                # Generate dummy audio as fallback
                return self._generate_dummy_audio(output_path)

            logger.info(f"Audio generation complete: {output_path}")
            return output_path
        except Exception as e:
            logger.error(f"TTS generation failed: {str(e)}", exc_info=True)
            raise

    def _generate_dummy_audio(self, output_path):
        """Generate a dummy audio file with a simple sine wave

        Args:
            output_path (str): Path to save the dummy audio file
        Returns:
            str: Path to the generated dummy audio file
        """
        import numpy as np

        duration = 3.0  # seconds
        t = np.linspace(0, duration, int(self._SAMPLE_RATE * duration), False)
        tone = np.sin(2 * np.pi * 440 * t) * 0.3
        logger.info(f"Saving dummy audio to {output_path}")
        sf.write(output_path, tone, self._SAMPLE_RATE)
        logger.info(f"Dummy audio generation complete: {output_path}")
        return output_path

    def generate_speech_stream(self, text: str, voice: str = 'af_heart', speed: float = 1.0):
        """Generate speech from text and yield each segment

        Args:
            text (str): Input text to synthesize
            voice (str): Voice ID to use (e.g., 'af_heart', 'af_bella', etc.)
            speed (float): Speech speed multiplier (0.5 to 2.0)
        Yields:
            tuple: (sample_rate, audio_data) pairs for each segment
        Raises:
            Exception: Any unexpected error is logged and re-raised.
        """
        try:
            # Use the appropriate TTS engine based on availability
            if self.engine_type == "kokoro":
                # Use Kokoro for streaming TTS
                for _, _, audio in self.pipeline(text, voice=voice, speed=speed):
                    yield self._SAMPLE_RATE, audio
            elif self.engine_type == "kokoro_space":
                # No streaming call is used for the remote server here, so
                # synthesize the whole file and yield it as a single chunk.
                try:
                    audio_path = self.generate_speech(text, voice, speed)
                    audio_data, sample_rate = sf.read(audio_path)
                except Exception as space_error:
                    logger.error(f"Kokoro FastAPI server streaming failed: {str(space_error)}", exc_info=True)
                    logger.error("Falling back to dummy audio stream")
                    yield from self._generate_dummy_audio_stream()
                else:
                    yield sample_rate, audio_data
            elif self.engine_type == "dia":
                # Dia doesn't support streaming natively, so we generate the full audio
                # and then yield it as a single chunk
                try:
                    logger.info("Attempting to use Dia TTS for speech streaming")
                    # Import here to avoid circular imports. The two imports
                    # are guarded separately so that a failing Dia import is
                    # not also reported as a PyTorch failure.
                    logger.info("Importing required modules for Dia streaming")
                    try:
                        import torch
                        logger.info("PyTorch successfully imported for Dia streaming")
                    except ImportError as torch_err:
                        logger.error(f"PyTorch import failed for Dia streaming: {str(torch_err)}")
                        raise
                    try:
                        from utils.tts_dia import _get_model, DEFAULT_SAMPLE_RATE
                        logger.info("Successfully imported Dia model and sample rate")
                    except ImportError as import_err:
                        logger.error(f"Failed to import Dia model for streaming: {str(import_err)}")
                        logger.error(f"Import path: {import_err.__traceback__.tb_frame.f_globals.get('__name__', 'unknown')}")
                        raise

                    # Get the Dia model
                    logger.info("Getting Dia model instance")
                    try:
                        model = _get_model()
                        logger.info("Successfully obtained Dia model instance")
                    except Exception as model_err:
                        logger.error(f"Failed to get Dia model instance: {str(model_err)}")
                        logger.error(f"Error type: {type(model_err).__name__}")
                        raise

                    # Generate audio
                    logger.info("Generating audio with Dia model")
                    with torch.inference_mode():
                        output_audio_np = model.generate(
                            text,
                            max_tokens=None,
                            cfg_scale=3.0,
                            temperature=1.3,
                            top_p=0.95,
                            cfg_filter_top_k=35,
                            use_torch_compile=False,
                            verbose=False
                        )

                    if output_audio_np is not None:
                        logger.info(f"Successfully generated audio with Dia (length: {len(output_audio_np)})")
                        yield DEFAULT_SAMPLE_RATE, output_audio_np
                    else:
                        logger.warning("Dia model returned None for audio output")
                        logger.warning("Falling back to dummy audio stream")
                        # Fall back to dummy audio if Dia fails
                        yield from self._generate_dummy_audio_stream()
                except ImportError as import_err:
                    logger.error(f"Dia TTS streaming failed due to import error: {str(import_err)}")
                    logger.error("Falling back to dummy audio stream")
                    # Fall back to dummy audio if Dia fails
                    yield from self._generate_dummy_audio_stream()
                except Exception as dia_error:
                    logger.error(f"Dia TTS streaming failed: {str(dia_error)}", exc_info=True)
                    logger.error(f"Error type: {type(dia_error).__name__}")
                    logger.error("Falling back to dummy audio stream")
                    # Fall back to dummy audio if Dia fails
                    yield from self._generate_dummy_audio_stream()
            else:
                # Generate dummy audio chunks as fallback
                yield from self._generate_dummy_audio_stream()
        except Exception as e:
            logger.error(f"TTS streaming failed: {str(e)}", exc_info=True)
            raise

    def _generate_dummy_audio_stream(self):
        """Generate dummy audio chunks with simple sine waves

        Yields:
            tuple: (sample_rate, audio_data) pairs for each dummy segment
        """
        import numpy as np

        duration = 1.0  # seconds per chunk
        # Create 3 chunks of dummy audio
        for i in range(3):
            t = np.linspace(0, duration, int(self._SAMPLE_RATE * duration), False)
            freq = 440 + (i * 220)  # Different frequency for each chunk
            tone = np.sin(2 * np.pi * freq * t) * 0.3
            yield self._SAMPLE_RATE, tone
# Engines are expensive to build (model load or remote connection), so one
# instance per language code is kept for the lifetime of the process.
_TTS_ENGINE_CACHE = {}


def get_tts_engine(lang_code='a'):
    """Get or create TTS engine instance

    A working engine is created once per language code and then reused on
    later calls. An engine that ended up in "dummy" mode is returned but not
    stored, so a later call can try the real backends again.

    Args:
        lang_code (str): Language code for the pipeline
    Returns:
        TTSEngine: Initialized TTS engine instance
    """
    logger.info(f"Requesting TTS engine with language code: {lang_code}")

    cached = _TTS_ENGINE_CACHE.get(lang_code)
    if cached is not None:
        logger.info(f"Retrieved TTS engine from cache with type: {cached.engine_type}")
        return cached

    logger.info("Creating cached TTS engine instance")
    engine = TTSEngine(lang_code)
    if engine.engine_type != "dummy":
        _TTS_ENGINE_CACHE[lang_code] = engine
        logger.info(f"Cached TTS engine created with type: {engine.engine_type}")
    else:
        logger.info(f"Direct TTS engine created with type: {engine.engine_type}")
    return engine
def generate_speech(text: str, voice: str = 'af_heart', speed: float = 1.0) -> str:
    """Synthesize text to an audio file through the module's TTS engine.

    Obtains an engine from ``get_tts_engine()`` (default language code) and
    forwards the request to its ``generate_speech`` method. Any failure is
    logged, including the file and line of the innermost traceback frame,
    and then re-raised unchanged.

    Args:
        text (str): Input text to synthesize
        voice (str): Voice ID to use
        speed (float): Speech speed multiplier
    Returns:
        str: Path to generated audio file
    """
    logger.info(f"Public generate_speech called with text length: {len(text)}, voice: {voice}, speed: {speed}")
    try:
        # Resolve the engine first so its type shows up in the log.
        logger.info("Getting TTS engine instance")
        tts = get_tts_engine()
        logger.info(f"Using TTS engine type: {tts.engine_type}")

        # Delegate the actual synthesis.
        logger.info("Calling engine.generate_speech")
        audio_file = tts.generate_speech(text, voice, speed)
        logger.info(f"Speech generation complete, output path: {audio_file}")
        return audio_file
    except Exception as err:
        logger.error(f"Error in public generate_speech function: {str(err)}", exc_info=True)
        logger.error(f"Error type: {type(err).__name__}")
        if hasattr(err, '__traceback__'):
            # Walk to the deepest frame: that is where the error was raised.
            innermost = err.__traceback__
            while innermost.tb_next:
                innermost = innermost.tb_next
            logger.error(f"Error occurred in file: {innermost.tb_frame.f_code.co_filename}, line {innermost.tb_lineno}")
        raise