Spaces:
Sleeping
Sleeping
| import os | |
| import time | |
| import logging | |
| import numpy as np | |
| import soundfile as sf | |
| from pathlib import Path | |
| from typing import Optional | |
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Flag to track Dia availability. It is flipped to True only after
# ``dia.model.Dia`` has been imported successfully below.
DIA_AVAILABLE = False

# Probe the optional dependencies one stage at a time so that each failure is
# reported against the package that actually failed to import.
try:
    import torch
except ImportError:
    logger.warning("Torch not available, Dia TTS engine cannot be used")
else:
    try:
        # Importing Dia will in turn try to import dac
        from dia.model import Dia
    except ImportError as e:
        # Catch ImportError (the parent of ModuleNotFoundError) here: a Dia
        # import that fails with a plain ImportError must not fall through to
        # the torch handler and be misreported as "Torch not available".
        if "dac" in str(e):
            logger.warning("Dia TTS engine is not available due to missing 'dac' module")
        else:
            logger.warning(f"Dia TTS engine is not available: {str(e)}")
    else:
        DIA_AVAILABLE = True
        logger.info("Dia TTS engine is available")
# Module-level defaults
DEFAULT_SAMPLE_RATE = 44100
DEFAULT_MODEL_NAME = "nari-labs/Dia-1.6B"

# Shared model instance; remains None until _get_model() loads it on first use
_model = None
def _get_model():
    """Return the shared Dia model, loading it on the first call.

    The model is created with ``Dia.from_pretrained(DEFAULT_MODEL_NAME,
    compute_dtype="float16")`` and stored in the module-level ``_model`` so
    later calls return the same instance without reloading.

    Returns:
        The cached Dia model instance.

    Raises:
        ImportError: If ``DIA_AVAILABLE`` is False, or if loading the model
            itself raises an ImportError.
        FileNotFoundError: If loading the model raises it.
        Exception: Any other error raised while loading is logged and
            re-raised unchanged.
    """
    global _model

    # Check if Dia is available before attempting to load
    if not DIA_AVAILABLE:
        logger.warning("Dia is not available, cannot load model")
        raise ImportError("Dia module is not available")

    # Fast path: the model has already been loaded
    if _model is not None:
        return _model

    logger.info("Loading Dia model...")
    try:
        # Record the runtime environment to make load failures easier to diagnose
        logger.info(f"PyTorch version: {torch.__version__}")
        logger.info(f"CUDA available: {torch.cuda.is_available()}")
        if torch.cuda.is_available():
            logger.info(f"CUDA version: {torch.version.cuda}")
            logger.info(f"GPU device: {torch.cuda.get_device_name(0)}")

        logger.info(f"Attempting to load model from: {DEFAULT_MODEL_NAME}")
        logger.info("Initializing Dia model...")
        model = Dia.from_pretrained(DEFAULT_MODEL_NAME, compute_dtype="float16")
    except ImportError as import_err:
        logger.error(f"Import error loading Dia model: {import_err}")
        logger.error("This may indicate missing dependencies")
        raise
    except FileNotFoundError as file_err:
        logger.error(f"File not found error loading Dia model: {file_err}")
        logger.error("Model path may be incorrect or inaccessible")
        raise
    except Exception as e:
        logger.error(f"Error loading Dia model: {e}", exc_info=True)
        logger.error(f"Error type: {type(e).__name__}")
        logger.error("This may indicate incompatible versions or missing CUDA support")
        raise

    # Cache only after a successful load. The diagnostics below sit outside the
    # try block so a problem while describing the model is never reported as a
    # load failure for a model that has in fact been cached.
    _model = model
    logger.info("Dia model loaded successfully")
    logger.info(f"Model type: {type(model).__name__}")

    # PyTorch modules expose parameters(); the Dia wrapper might not. The
    # None default also covers an object whose parameters() yields nothing,
    # which would otherwise raise StopIteration.
    first_param = None
    if hasattr(model, "parameters"):
        first_param = next(iter(model.parameters()), None)
    if first_param is not None:
        logger.info(f"Model device: {first_param.device}")
    else:
        logger.info("Model device: Device information not available for Dia model")

    return _model
def _dummy_speech(text: str, language: str) -> str:
    """Synthesize ``text`` with ``DummyTTSEngine`` and return its result."""
    # Imported at the point of use, as every fallback path did originally
    from utils.tts_base import DummyTTSEngine

    fallback_engine = DummyTTSEngine(language)
    return fallback_engine.generate_speech(text)


def generate_speech(text: str, language: str = "zh") -> str:
    """Legacy entry point for Dia text-to-speech generation.

    Kept for backward compatibility; new code should use the factory-pattern
    engines directly. The call is delegated to ``DiaTTSEngine``. If Dia is
    unavailable, or the Dia engine raises any exception, the request is served
    by ``DummyTTSEngine`` instead.

    Args:
        text (str): Input text to synthesize
        language (str): Language code, passed to the engine constructor

    Returns:
        str: Path to the generated audio file
    """
    logger.info(f"Legacy Dia generate_speech called with text length: {len(text)}")

    # Without Dia there is nothing to try; go straight to the dummy engine
    if not DIA_AVAILABLE:
        logger.warning("Dia is not available, falling back to dummy TTS engine")
        return _dummy_speech(text, language)

    try:
        # Import here to avoid circular imports
        from utils.tts_engines import DiaTTSEngine

        return DiaTTSEngine(language).generate_speech(text)
    except ModuleNotFoundError as missing:
        message = str(missing)
        logger.error(f"Module not found error in Dia generate_speech: {message}")
        if "dac" in message:
            logger.warning("Dia TTS engine failed due to missing 'dac' module, falling back to dummy TTS")
        return _dummy_speech(text, language)
    except Exception as err:
        logger.error(f"Error in legacy Dia generate_speech: {str(err)}", exc_info=True)
        return _dummy_speech(text, language)