Spaces:
Sleeping
Sleeping
| # tts_engine.py - TTS engine wrapper for Nari DIA | |
| import logging | |
| import os | |
| from typing import Optional | |
| import tempfile | |
| import numpy as np | |
| import soundfile as sf | |
| import torch # Import torch for model operations | |
# Import the actual Nari DIA model. The import is optional so this module can
# still be loaded (and report a clear error) when the package is missing.
logger = logging.getLogger(__name__)

try:
    from dia.model import Dia
except ImportError as exc:
    # Log through the module logger rather than the root-level logging.error()
    # helper: the latter calls logging.basicConfig() when the root logger has
    # no handlers yet, reconfiguring the host application's logging as an
    # import-time side effect. The ImportError text is included because the
    # failure may come from a missing dependency of dia, not dia itself.
    logger.error(
        "Nari DIA library not found. Please ensure "
        "'git+https://github.com/nari-labs/dia.git' is in your requirements.txt "
        "and installed. (%s)",
        exc,
    )
    Dia = None  # Sentinel: NariDIAEngine checks for None before loading the model
class NariDIAEngine:
    """Wrapper that renders speaker-tagged text segments with Nari DIA.

    The model is loaded once in ``__init__``. ``synthesize_segment`` does not
    raise for synthesis problems: it logs them and returns ``None`` so callers
    can skip a failed segment.
    """

    DEFAULT_MODEL_NAME = "nari-labs/Dia-1.6B"
    # Upstream Dia examples save generated audio at 44100 Hz. Writing the
    # waveform at any other rate (the old 22050 fallback) changes playback
    # speed and pitch.
    DEFAULT_SAMPLE_RATE = 44100
    # Speaker labels Dia understands, in the unbracketed form the segmenter emits.
    VALID_SPEAKERS = ("S1", "S2")

    def __init__(
        self,
        model_name: str = DEFAULT_MODEL_NAME,
        compute_dtype: Optional[str] = None,
    ):
        """Create the engine and eagerly load the model.

        Args:
            model_name: Model id or path passed to ``Dia.from_pretrained``.
            compute_dtype: Dia compute dtype (e.g. ``"float16"``, ``"float32"``).
                ``None`` selects ``"float16"`` when CUDA is available and
                ``"float32"`` otherwise.
        """
        self.model = None
        self.model_name = model_name
        self.compute_dtype = compute_dtype
        # No separate processor object for Dia; it handles tokenization internally.
        self._initialize_model()

    def _initialize_model(self):
        """Load the Dia checkpoint into ``self.model``.

        On any failure (library missing, download/load error) ``self.model``
        stays ``None`` and the reason is logged.
        """
        if Dia is None:
            logger.error("Nari DIA library is not available. Cannot initialize model.")
            return

        cuda_available = torch.cuda.is_available()
        # float16 saves VRAM on GPU, but many CPU kernels have no float16
        # implementation, so default to float32 when CUDA is absent.
        dtype = self.compute_dtype or ("float16" if cuda_available else "float32")

        try:
            logger.info(
                "Initializing Nari DIA model from %s (compute_dtype=%s)...",
                self.model_name,
                dtype,
            )
            # The 1.6B checkpoint needs a GPU with roughly 10GB of VRAM.
            model = Dia.from_pretrained(self.model_name, compute_dtype=dtype)

            if cuda_available:
                # Upstream Dia.from_pretrained already places weights on its
                # default device (CUDA when available). The Dia wrapper is not an
                # nn.Module, so an unconditional .to() raised AttributeError and
                # threw the loaded model away; only call it when it exists.
                move_to = getattr(model, "to", None)
                if callable(move_to):
                    move_to("cuda")
                logger.info("Nari DIA model running on GPU (CUDA).")
            else:
                logger.warning(
                    "CUDA not available. Nari DIA model will run on CPU, "
                    "which is not officially supported and will be very slow."
                )

            # Publish only a fully initialized model.
            self.model = model
            logger.info("Nari DIA model initialized successfully.")
        except Exception as e:
            logger.error(f"Failed to initialize Nari DIA model: {e}", exc_info=True)
            self.model = None

    @classmethod
    def _normalize_speaker(cls, speaker) -> Optional[str]:
        """Map a speaker label such as 'S1', ' s2 ' or '[S1]' to 'S1'/'S2'.

        Returns ``None`` for anything that is not one of ``VALID_SPEAKERS``.
        """
        if speaker is None:
            return None
        label = str(speaker).strip().strip("[]").strip().upper()
        return label if label in cls.VALID_SPEAKERS else None

    @staticmethod
    def _to_waveform(audio) -> Optional[np.ndarray]:
        """Convert ``Dia.generate`` output into a squeezed numpy array.

        Accepts a numpy array, a torch tensor, or a list/tuple whose first
        element is an array/tensor (batched API, single prompt). Returns
        ``None`` when there is no audio. Half-precision data is upcast to
        float32 because soundfile cannot write float16 and numpy has no bfloat16.
        """
        if audio is None:
            return None
        if (
            isinstance(audio, (list, tuple))
            and audio
            and isinstance(audio[0], (np.ndarray, torch.Tensor))
        ):
            audio = audio[0]
        if isinstance(audio, torch.Tensor):
            audio = audio.detach().cpu()
            if audio.is_floating_point():
                audio = audio.float()
            audio = audio.numpy()
        waveform = np.asarray(audio).squeeze()
        if waveform.size == 0:
            return None
        if waveform.dtype == np.float16:
            waveform = waveform.astype(np.float32)
        return waveform

    def synthesize_segment(
        self,
        text: str,
        speaker: str,  # 'S1' or 'S2' from the segmenter
        output_path: str,
    ) -> Optional[str]:
        """Synthesize one text segment with Nari DIA and save it to ``output_path``.

        Args:
            text: Text to synthesize. Empty or whitespace-only text is skipped.
            speaker: Speaker identifier; 'S1'/'S2' are expected from the
                segmenter. Case, surrounding whitespace and brackets are
                tolerated; anything else falls back to [S1] with a warning.
            output_path: Path to save the audio file. Missing parent
                directories are created.

        Returns:
            ``output_path`` on success, or ``None`` if the model is not loaded,
            the text is empty, generation produced no audio, or any step failed.
        """
        if self.model is None:
            logger.error("Nari DIA model not initialized. Cannot synthesize speech.")
            return None
        if not text or not str(text).strip():
            logger.warning(f"Empty text for speaker '{speaker}'; skipping synthesis.")
            return None

        try:
            normalized = self._normalize_speaker(speaker)
            if normalized is None:
                # Fallback in case the segmenter outputs something unexpected.
                logger.warning(
                    f"Unexpected speaker tag '{speaker}' from segmenter. Defaulting to [S1]."
                )
                normalized = "S1"
            dia_speaker_tag = f"[{normalized}]"

            # Nari DIA expects the speaker tag at the beginning of the segment.
            full_text_input = f"{dia_speaker_tag} {text}"
            logger.info(f"Synthesizing with Nari DIA: {full_text_input[:100]}...")

            with torch.no_grad():
                raw_audio = self.model.generate(full_text_input)

            # generate() may return a numpy array (upstream), a tensor, a
            # batched list, or nothing; normalize before writing.
            audio_waveform = self._to_waveform(raw_audio)
            if audio_waveform is None:
                logger.error(
                    f"Nari DIA produced no audio for {speaker} ({dia_speaker_tag})."
                )
                return None

            # Prefer a rate the model object reports; otherwise use Dia's 44.1 kHz.
            sampling_rate = int(
                getattr(self.model, "sampling_rate", None) or self.DEFAULT_SAMPLE_RATE
            )

            os.makedirs(os.path.dirname(os.path.abspath(output_path)), exist_ok=True)
            sf.write(output_path, audio_waveform, sampling_rate)
            logger.info(
                f"Generated audio for {speaker} ({dia_speaker_tag}): "
                f"{len(text)} characters to {output_path}"
            )
            return output_path
        except Exception as e:
            # exc_info logs the full traceback; failures are reported, not raised.
            logger.error(f"Failed to synthesize segment with Nari DIA: {e}", exc_info=True)
            return None