# tts_engine.py - TTS engine wrapper for Nari DIA
import logging
import os
import tempfile
from typing import Optional

import numpy as np
import soundfile as sf
import torch  # Model inference / device placement

# Import the actual Nari DIA model; degrade gracefully when it is missing so
# the rest of the application can still import this module.
try:
    from dia.model import Dia
except ImportError:
    logging.error("Nari DIA library not found. Please ensure 'git+https://github.com/nari-labs/dia.git' is in your requirements.txt and installed.")
    Dia = None  # Sentinel: engine methods check for this and become no-ops.

logger = logging.getLogger(__name__)


class NariDIAEngine:
    """Thin wrapper around the Nari DIA 1.6B text-to-speech model.

    The model is loaded once at construction time. All failures (missing
    library, failed download, synthesis error) are logged and surfaced as
    ``None`` return values rather than raised, so callers can skip bad
    segments without crashing a batch job.
    """

    def __init__(self) -> None:
        # Dia handles its own tokenization/processing internally, so the only
        # state held here is the model itself (None when unavailable).
        self.model = None
        self._initialize_model()

    def _initialize_model(self) -> None:
        """Load the Nari DIA 1.6B model; leave ``self.model`` as None on failure."""
        if Dia is None:
            logger.error("Nari DIA library is not available. Cannot initialize model.")
            return
        try:
            logger.info("Initializing Nari DIA 1.6B model from nari-labs/Dia-1.6B...")
            # float16 halves memory/compute on GPU; expect ~10GB VRAM.
            self.model = Dia.from_pretrained("nari-labs/Dia-1.6B", compute_dtype="float16")
            # Move model to GPU if available.
            if torch.cuda.is_available():
                self.model.to("cuda")
                logger.info("Nari DIA model moved to GPU (CUDA).")
            else:
                logger.warning("CUDA not available. Nari DIA model will run on CPU, which is not officially supported and will be very slow.")
            logger.info("Nari DIA model initialized successfully.")
        except Exception as e:
            # Broad catch is deliberate: model download/load can fail in many
            # ways (network, disk, VRAM) and the engine must stay usable as a
            # no-op rather than take the whole process down.
            logger.error(f"Failed to initialize Nari DIA model: {e}", exc_info=True)
            self.model = None

    def synthesize_segment(
        self,
        text: str,
        speaker: str,  # 'S1' or 'S2' from the segmenter
        output_path: str,
    ) -> Optional[str]:
        """Synthesize speech for a text segment using Nari DIA.

        Args:
            text: Text to synthesize.
            speaker: Speaker identifier ('S1' or 'S2' expected from segmenter).
            output_path: Path to save the generated WAV file.

        Returns:
            ``output_path`` on success, or None if synthesis failed.
        """
        if self.model is None:  # explicit None check, not truthiness
            logger.error("Nari DIA model not initialized. Cannot synthesize speech.")
            return None
        if not text or not text.strip():
            # Robustness: skip empty segments instead of feeding the model
            # a bare speaker tag with no content.
            logger.warning("Empty text segment; skipping synthesis.")
            return None
        try:
            # Nari DIA expects [S1]/[S2] tags at the beginning of the segment;
            # the segmenter emits bare "S1"/"S2", so wrap them in brackets.
            if speaker in ("S1", "S2"):
                dia_speaker_tag = f"[{speaker}]"
            else:
                # Fallback in case the segmenter outputs something unexpected.
                logger.warning(f"Unexpected speaker tag '{speaker}' from segmenter. Defaulting to [S1].")
                dia_speaker_tag = "[S1]"
            full_text_input = f"{dia_speaker_tag} {text}"

            logger.info(f"Synthesizing with Nari DIA: {full_text_input[:100]}...")  # log beginning of text
            # Dia handles internal processing/tokenization; pass text directly.
            with torch.no_grad():
                output = self.model.generate(full_text_input)

            # BUGFIX: depending on the Dia version, generate() may return a
            # torch.Tensor or a numpy array; the original unconditionally
            # called .cpu().numpy(), which crashes on an ndarray. Handle both.
            if isinstance(output, torch.Tensor):
                audio_waveform = output.detach().cpu().numpy()
            else:
                audio_waveform = np.asarray(output)
            audio_waveform = audio_waveform.squeeze()

            # Prefer the model's own sampling rate when it exposes one.
            # NOTE(review): 22050 mirrors the original default, but Dia's
            # DAC-based decoder typically emits 44.1 kHz audio — confirm
            # against the installed dia version before relying on this.
            sampling_rate = getattr(self.model, "sampling_rate", 22050)

            # Robustness: ensure the target directory exists before writing.
            out_dir = os.path.dirname(output_path)
            if out_dir:
                os.makedirs(out_dir, exist_ok=True)
            # Save as WAV file.
            sf.write(output_path, audio_waveform, sampling_rate)

            logger.info(f"Generated audio for {speaker} ({dia_speaker_tag}): {len(text)} characters to {output_path}")
            return output_path
        except Exception as e:
            # exc_info=True prints the full traceback for debugging.
            logger.error(f"Failed to synthesize segment with Nari DIA: {e}", exc_info=True)
            return None