Spaces:
Running
Running
| # engine.py | |
| # Core Dia TTS model loading and generation logic | |
| import logging | |
| import time | |
| import os | |
| import torch | |
| import numpy as np | |
| from typing import Optional, Tuple | |
| from huggingface_hub import hf_hub_download # Import downloader | |
| # Import Dia model class and config | |
| try: | |
| from dia.model import Dia | |
| from dia.config import DiaConfig | |
| except ImportError as e: | |
| # Log critical error if core components are missing | |
| logging.critical( | |
| f"Failed to import Dia model components: {e}. Ensure the 'dia' package exists and is importable.", | |
| exc_info=True, | |
| ) | |
| # Define dummy classes/functions to prevent server crash on import, | |
| # but generation will fail later if these are used. | |
| class Dia: | |
| def load_model_from_files(*args, **kwargs): | |
| raise RuntimeError("Dia model package not available or failed to import.") | |
| def generate(*args, **kwargs): | |
| raise RuntimeError("Dia model package not available or failed to import.") | |
| class DiaConfig: | |
| pass | |
| # Import configuration getters from our project's config.py | |
| from config import ( | |
| get_model_repo_id, | |
| get_model_cache_path, | |
| get_reference_audio_path, | |
| get_model_config_filename, | |
| get_model_weights_filename, | |
| ) | |
| logger = logging.getLogger(__name__) # Use standard logger name | |
| # --- Global Variables --- | |
| dia_model: Optional[Dia] = None | |
| # model_config is now loaded within Dia.load_model_from_files, maybe remove global? | |
| # Let's keep it for now if needed elsewhere, but populate it after loading. | |
| model_config_instance: Optional[DiaConfig] = None | |
| model_device: Optional[torch.device] = None | |
| MODEL_LOADED = False | |
| EXPECTED_SAMPLE_RATE = 44100 # Dia model and DAC typically operate at 44.1kHz | |
| # --- Model Loading --- | |
| def get_device() -> torch.device: | |
| """Determines the optimal torch device (CUDA > MPS > CPU).""" | |
| if torch.cuda.is_available(): | |
| logger.info("CUDA is available, using GPU.") | |
| return torch.device("cuda") | |
| # Add MPS check for Apple Silicon GPUs | |
| elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): | |
| # Basic check is usually sufficient | |
| logger.info("MPS is available, using Apple Silicon GPU.") | |
| return torch.device("mps") | |
| else: | |
| logger.info("CUDA and MPS not available, using CPU.") | |
| return torch.device("cpu") | |
| def load_model(): | |
| """ | |
| Loads the Dia TTS model and associated DAC model. | |
| Downloads model files based on configuration if they don't exist locally. | |
| Handles both .pth and .safetensors formats. | |
| """ | |
| global dia_model, model_config_instance, model_device, MODEL_LOADED | |
| if MODEL_LOADED: | |
| logger.info("Dia model already loaded.") | |
| return True | |
| # Get configuration values | |
| repo_id = get_model_repo_id() | |
| config_filename = get_model_config_filename() | |
| weights_filename = get_model_weights_filename() | |
| cache_path = get_model_cache_path() # Already absolute path | |
| model_device = get_device() | |
| logger.info(f"Attempting to load Dia model:") | |
| logger.info(f" Repo ID: {repo_id}") | |
| logger.info(f" Config File: {config_filename}") | |
| logger.info(f" Weights File: {weights_filename}") | |
| logger.info(f" Cache Directory: {cache_path}") | |
| logger.info(f" Target Device: {model_device}") | |
| # Ensure cache directory exists | |
| try: | |
| os.makedirs(cache_path, exist_ok=True) | |
| except OSError as e: | |
| logger.error( | |
| f"Failed to create cache directory '{cache_path}': {e}", exc_info=True | |
| ) | |
| # Depending on severity, might want to return False here | |
| # return False | |
| pass # Continue and let hf_hub_download handle potential issues | |
| try: | |
| start_time = time.time() | |
| # --- Download Model Files --- | |
| logger.info( | |
| f"Downloading/finding configuration file '{config_filename}' from repo '{repo_id}'..." | |
| ) | |
| local_config_path = hf_hub_download( | |
| repo_id=repo_id, | |
| filename=config_filename, | |
| cache_dir=cache_path, | |
| # force_download=False, # Default: only download if missing or outdated | |
| # resume_download=True, # Default: resume interrupted downloads | |
| ) | |
| logger.info(f"Configuration file path: {local_config_path}") | |
| logger.info( | |
| f"Downloading/finding weights file '{weights_filename}' from repo '{repo_id}'..." | |
| ) | |
| local_weights_path = hf_hub_download( | |
| repo_id=repo_id, | |
| filename=weights_filename, | |
| cache_dir=cache_path, | |
| ) | |
| logger.info(f"Weights file path: {local_weights_path}") | |
| # --- Load Model using the class method --- | |
| # The Dia class method now handles config loading, instantiation, weight loading, etc. | |
| dia_model = Dia.load_model_from_files( | |
| config_path=local_config_path, | |
| weights_path=local_weights_path, | |
| device=model_device, | |
| ) | |
| # Store the config instance if needed globally (optional) | |
| model_config_instance = dia_model.config | |
| end_time = time.time() | |
| logger.info( | |
| f"Dia model loaded successfully in {end_time - start_time:.2f} seconds." | |
| ) | |
| MODEL_LOADED = True | |
| return True | |
| except FileNotFoundError as e: | |
| logger.error( | |
| f"Model loading failed: Required file not found. {e}", exc_info=True | |
| ) | |
| MODEL_LOADED = False | |
| return False | |
| except ImportError: | |
| # This catches if the 'dia' package itself is missing | |
| logger.critical( | |
| "Failed to load model: Dia package or its core dependencies not found.", | |
| exc_info=True, | |
| ) | |
| MODEL_LOADED = False | |
| return False | |
| except Exception as e: | |
| # Catch other potential errors during download or loading | |
| logger.error( | |
| f"Error loading Dia model from repo '{repo_id}': {e}", exc_info=True | |
| ) | |
| dia_model = None | |
| model_config_instance = None | |
| MODEL_LOADED = False | |
| return False | |
| # --- Speech Generation --- | |
| def generate_speech( | |
| text: str, | |
| voice_mode: str = "single_s1", | |
| clone_reference_filename: Optional[str] = None, | |
| max_tokens: Optional[int] = None, | |
| cfg_scale: float = 3.0, | |
| temperature: float = 1.3, | |
| top_p: float = 0.95, | |
| speed_factor: float = 0.94, # Keep speed factor separate from model generation params | |
| cfg_filter_top_k: int = 35, | |
| ) -> Optional[Tuple[np.ndarray, int]]: | |
| """ | |
| Generates speech using the loaded Dia model, handling voice modes and speed adjustment. | |
| Args: | |
| text: Text to synthesize. | |
| voice_mode: 'dialogue', 'single_s1', 'single_s2', 'clone'. | |
| clone_reference_filename: Filename for voice cloning (if mode is 'clone'). Located in reference audio path. | |
| max_tokens: Max generation tokens for the model's generate method. | |
| cfg_scale: CFG scale for the model's generate method. | |
| temperature: Sampling temperature for the model's generate method. | |
| top_p: Nucleus sampling p for the model's generate method. | |
| speed_factor: Factor to adjust the playback speed *after* generation (e.g., 0.9 = slower, 1.1 = faster). | |
| cfg_filter_top_k: CFG filter top K for the model's generate method. | |
| Returns: | |
| Tuple of (numpy_audio_array, sample_rate), or None on failure. | |
| """ | |
| if not MODEL_LOADED or dia_model is None: | |
| logger.error("Dia model is not loaded. Cannot generate speech.") | |
| return None | |
| logger.info(f"Generating speech with mode: {voice_mode}") | |
| logger.debug(f"Input text (start): '{text[:100]}...'") | |
| # Log model generation parameters | |
| logger.debug( | |
| f"Model Params: max_tokens={max_tokens}, cfg={cfg_scale}, temp={temperature}, top_p={top_p}, top_k={cfg_filter_top_k}" | |
| ) | |
| # Log post-processing parameters | |
| logger.debug(f"Post-processing Params: speed_factor={speed_factor}") | |
| audio_prompt_path = None | |
| processed_text = text # Start with original text | |
| # --- Handle Voice Mode --- | |
| if voice_mode == "clone": | |
| if not clone_reference_filename: | |
| logger.error("Clone mode selected but no reference filename provided.") | |
| return None | |
| ref_base_path = get_reference_audio_path() # Gets absolute path | |
| potential_path = os.path.join(ref_base_path, clone_reference_filename) | |
| if os.path.isfile(potential_path): | |
| audio_prompt_path = potential_path | |
| logger.info(f"Using audio prompt for cloning: {audio_prompt_path}") | |
| # Dia requires the transcript of the clone audio to be prepended to the target text. | |
| # The UI/API caller is responsible for constructing this combined text. | |
| logger.warning( | |
| "Clone mode active. Ensure the 'text' input includes the transcript of the reference audio for best results (e.g., '[S1] Reference transcript. [S1] Target text...')." | |
| ) | |
| processed_text = text # Use the combined text provided by the caller | |
| else: | |
| logger.error(f"Reference audio file not found: {potential_path}") | |
| return None # Fail generation if reference file is missing | |
| elif voice_mode == "dialogue": | |
| # Assume text already contains [S1]/[S2] tags as required by the model | |
| logger.info("Using dialogue mode. Expecting [S1]/[S2] tags in input text.") | |
| if "[S1]" not in text and "[S2]" not in text: | |
| logger.warning( | |
| "Dialogue mode selected, but no [S1] or [S2] tags found in the input text." | |
| ) | |
| processed_text = text # Pass directly | |
| elif voice_mode == "single_s1": | |
| logger.info("Using single voice mode (S1).") | |
| # Check if text *already* contains tags, warn if so, as it might confuse the model | |
| if "[S1]" in text or "[S2]" in text: | |
| logger.warning( | |
| "Input text contains dialogue tags ([S1]/[S2]), but 'single_s1' mode was selected. Model behavior might be unexpected." | |
| ) | |
| # Dia likely expects tags even for single speaker. Prepending [S1] might be safer. | |
| # Let's assume for now the model handles untagged text as S1, but this could be adjusted. | |
| # Consider: processed_text = f"[S1] {text}" # Option to enforce S1 tag | |
| processed_text = text # Pass directly for now | |
| elif voice_mode == "single_s2": | |
| logger.info("Using single voice mode (S2).") | |
| if "[S1]" in text or "[S2]" in text: | |
| logger.warning( | |
| "Input text contains dialogue tags ([S1]/[S2]), but 'single_s2' mode was selected." | |
| ) | |
| # Similar to S1, how to signal S2? Prepending [S2] seems logical if needed. | |
| # Consider: processed_text = f"[S2] {text}" # Option to enforce S2 tag | |
| processed_text = text # Pass directly for now | |
| else: | |
| logger.error( | |
| f"Unsupported voice_mode: {voice_mode}. Defaulting to 'single_s1'." | |
| ) | |
| processed_text = text # Fallback | |
| # --- Call Dia Generate --- | |
| try: | |
| start_time = time.time() | |
| logger.info("Calling Dia model generate method...") | |
| # Call the model's generate method with appropriate parameters | |
| generated_audio_np = dia_model.generate( | |
| text=processed_text, | |
| audio_prompt_path=audio_prompt_path, | |
| max_tokens=max_tokens, # Pass None if not specified, Dia uses its default | |
| cfg_scale=cfg_scale, | |
| temperature=temperature, | |
| top_p=top_p, | |
| use_cfg_filter=True, # Default from Dia's app.py, seems reasonable | |
| cfg_filter_top_k=cfg_filter_top_k, | |
| use_torch_compile=False, # Keep False for stability unless specifically tested/enabled | |
| ) | |
| gen_end_time = time.time() | |
| logger.info( | |
| f"Dia model generation finished in {gen_end_time - start_time:.2f} seconds." | |
| ) | |
| if generated_audio_np is None or generated_audio_np.size == 0: | |
| logger.warning("Dia model returned None or empty audio array.") | |
| return None | |
| # --- Apply Speed Factor (Post-processing) --- | |
| # This mimics the logic in Dia's original app.py | |
| if speed_factor != 1.0: | |
| logger.info(f"Applying speed factor: {speed_factor}") | |
| original_len = len(generated_audio_np) | |
| # Ensure speed_factor is within a reasonable range to avoid extreme distortion | |
| # Adjust range based on observed quality (e.g., 0.5 to 2.0) | |
| speed_factor = max(0.5, min(speed_factor, 2.0)) | |
| target_len = int(original_len / speed_factor) | |
| if target_len > 0 and target_len != original_len: | |
| logger.debug( | |
| f"Resampling audio from {original_len} to {target_len} samples." | |
| ) | |
| # Create time axes for original and resampled audio | |
| x_original = np.linspace(0, original_len - 1, original_len) | |
| x_resampled = np.linspace(0, original_len - 1, target_len) | |
| # Interpolate using numpy | |
| resampled_audio_np = np.interp( | |
| x_resampled, x_original, generated_audio_np | |
| ) | |
| final_audio_np = resampled_audio_np.astype(np.float32) # Ensure float32 | |
| logger.info(f"Audio resampled for {speed_factor:.2f}x speed.") | |
| else: | |
| logger.warning( | |
| f"Skipping speed adjustment (factor: {speed_factor:.2f}). Target length invalid ({target_len}) or no change needed." | |
| ) | |
| final_audio_np = generated_audio_np # Use original audio | |
| else: | |
| logger.info("Speed factor is 1.0, no speed adjustment needed.") | |
| final_audio_np = generated_audio_np # No speed change needed | |
| # Ensure output is float32 (DAC output should be, but good practice) | |
| if final_audio_np.dtype != np.float32: | |
| logger.warning( | |
| f"Generated audio was not float32 ({final_audio_np.dtype}), converting." | |
| ) | |
| final_audio_np = final_audio_np.astype(np.float32) | |
| logger.info( | |
| f"Final audio ready. Shape: {final_audio_np.shape}, dtype: {final_audio_np.dtype}" | |
| ) | |
| # Return the processed audio and the expected sample rate | |
| return final_audio_np, EXPECTED_SAMPLE_RATE | |
| except Exception as e: | |
| logger.error( | |
| f"Error during Dia generation or post-processing: {e}", exc_info=True | |
| ) | |
| return None # Return None on failure | |