# PodXplain / tts_engine.py
# Nick021402's picture
# Update tts_engine.py
# f6c76a6 verified
# tts_engine.py - TTS engine wrapper for Nari DIA
import logging
import os
from typing import Optional
import tempfile
import numpy as np
import soundfile as sf
import torch # Import torch for model operations
# Import the actual Nari DIA model
try:
from dia.model import Dia
except ImportError:
logging.error("Nari DIA library not found. Please ensure 'git+https://github.com/nari-labs/dia.git' is in your requirements.txt and installed.")
Dia = None # Set to None to prevent further errors
logger = logging.getLogger(__name__)
class NariDIAEngine:
    """Wrapper around the Nari DIA 1.6B text-to-speech model.

    Loads the model once at construction time and exposes
    ``synthesize_segment`` to render individual dialogue segments to WAV
    files. If the ``dia`` library is missing or loading fails,
    ``self.model`` stays ``None`` and synthesis calls return ``None``.
    """

    # Fallback output rate used only when the model object does not expose a
    # ``sampling_rate`` attribute. Dia-1.6B generates 44.1 kHz audio per the
    # nari-labs documentation (the previous fallback of 22050 would have
    # written pitch-shifted, half-speed audio).
    DEFAULT_SAMPLE_RATE = 44100

    def __init__(self):
        # The loaded Dia model, or None when loading failed / library missing.
        self.model = None
        # No separate processor object for Dia; it handles internal processing.
        self._initialize_model()

    def _initialize_model(self):
        """Initialize the Nari DIA 1.6B model.

        On any failure ``self.model`` is left as ``None`` and the error is
        logged; callers must check ``self.model`` before synthesizing.
        """
        if Dia is None:
            logger.error("Nari DIA library is not available. Cannot initialize model.")
            return
        try:
            logger.info("Initializing Nari DIA 1.6B model from nari-labs/Dia-1.6B...")
            # float16 halves GPU memory use; expect ~10GB VRAM for this model.
            self.model = Dia.from_pretrained("nari-labs/Dia-1.6B", compute_dtype="float16")
            # Move model to GPU if available; CPU inference is unsupported/slow.
            if torch.cuda.is_available():
                self.model.to("cuda")
                logger.info("Nari DIA model moved to GPU (CUDA).")
            else:
                logger.warning("CUDA not available. Nari DIA model will run on CPU, which is not officially supported and will be very slow.")
            logger.info("Nari DIA model initialized successfully.")
        except Exception as e:
            logger.error(f"Failed to initialize Nari DIA model: {e}", exc_info=True)
            self.model = None

    def synthesize_segment(
        self,
        text: str,
        speaker: str,  # 'S1' or 'S2' from the segmenter
        output_path: str
    ) -> Optional[str]:
        """
        Synthesize speech for a text segment using Nari DIA.
        Args:
            text: Text to synthesize
            speaker: Speaker identifier ('S1' or 'S2' expected from segmenter)
            output_path: Path to save the audio file
        Returns:
            Path to the generated audio file, or None if failed
        """
        if self.model is None:
            logger.error("Nari DIA model not initialized. Cannot synthesize speech.")
            return None
        if not text or not text.strip():
            # Nothing to synthesize; avoid invoking the model on empty input.
            logger.warning("Empty text passed to synthesize_segment; skipping.")
            return None
        try:
            # Nari DIA expects a [S1] or [S2] speaker tag at the start of the
            # segment; the segmenter emits the bare "S1"/"S2" identifier.
            if speaker in ("S1", "S2"):
                dia_speaker_tag = f"[{speaker}]"
            else:
                # Fallback in case segmenter outputs something unexpected
                logger.warning(f"Unexpected speaker tag '{speaker}' from segmenter. Defaulting to [S1].")
                dia_speaker_tag = "[S1]"
            full_text_input = f"{dia_speaker_tag} {text}"
            logger.info(f"Synthesizing with Nari DIA: {full_text_input[:100]}...")  # Log beginning of text
            # Dia's generate() handles internal processing/tokenization.
            with torch.no_grad():
                generated = self.model.generate(full_text_input)
            # generate() may return a torch.Tensor or a numpy array depending
            # on the dia version; normalize to a squeezed numpy array either
            # way (the original unconditional .cpu().numpy() crashes on the
            # numpy-returning versions).
            if isinstance(generated, torch.Tensor):
                audio_waveform = generated.detach().cpu().numpy()
            else:
                audio_waveform = np.asarray(generated)
            audio_waveform = np.squeeze(audio_waveform)
            # Prefer the model's advertised rate; fall back to Dia's
            # documented 44.1 kHz output rate.
            sampling_rate = getattr(self.model, 'sampling_rate', self.DEFAULT_SAMPLE_RATE)
            # Save as WAV file
            sf.write(output_path, audio_waveform, sampling_rate)
            logger.info(f"Generated audio for {speaker} ({dia_speaker_tag}): {len(text)} characters to {output_path}")
            return output_path
        except Exception as e:
            logger.error(f"Failed to synthesize segment with Nari DIA: {e}", exc_info=True)  # exc_info to print full traceback
            return None