Spaces:
Sleeping
Sleeping
File size: 4,666 Bytes
6b10d56 9b03592 6b10d56 f6c76a6 6b10d56 9b03592 6b10d56 9b03592 6b10d56 9b03592 6b10d56 9b03592 6b10d56 daae1c2 6b10d56 daae1c2 6b10d56 9b03592 6b10d56 daae1c2 6b10d56 9b03592 6b10d56 9b03592 f6c76a6 daae1c2 9b03592 daae1c2 9b03592 6b10d56 f6c76a6 9b03592 f6c76a6 6b10d56 9b03592 6b10d56 9b03592 6b10d56 9b03592 6b10d56 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 |
# tts_engine.py - TTS engine wrapper for Nari DIA
import logging
import os
from typing import Optional
import tempfile
import numpy as np
import soundfile as sf
import torch # Import torch for model operations
# Module-level logger. It is created before the guarded import below so the
# import failure is reported through this module's logger rather than through
# the root-level ``logging.error`` helper, which implicitly calls
# ``logging.basicConfig()`` and would turn a later application-level
# ``basicConfig(...)`` call into a silent no-op.
logger = logging.getLogger(__name__)

# Import the actual Nari DIA model. The import is optional at module load time:
# when the package is missing, ``Dia`` is set to None and the engine refuses to
# initialize instead of crashing the importer.
try:
    from dia.model import Dia
except ImportError:
    logger.error("Nari DIA library not found. Please ensure 'git+https://github.com/nari-labs/dia.git' is in your requirements.txt and installed.")
    Dia = None  # Sentinel checked by NariDIAEngine._initialize_model
class NariDIAEngine:
    """Wrapper that loads the Nari DIA 1.6B model and renders text segments to audio files.

    The model is loaded once in the constructor. If loading fails (library
    missing, download error, ...) ``self.model`` stays ``None`` and every call
    to :meth:`synthesize_segment` returns ``None`` instead of raising.
    """

    # Fallback sample rate used when the model object exposes no
    # ``sampling_rate`` attribute. Upstream Dia examples write their output at
    # 44100 Hz; writing it at 22050 Hz would play it back at half speed.
    DEFAULT_SAMPLE_RATE = 44100

    def __init__(self):
        self.model = None
        # No separate processor object for Dia, it handles internal processing
        self._initialize_model()

    def _initialize_model(self):
        """Initialize the Nari DIA 1.6B model.

        On any failure the error is logged and ``self.model`` is left as
        ``None``; no exception propagates to the caller.
        """
        if Dia is None:
            logger.error("Nari DIA library is not available. Cannot initialize model.")
            return
        try:
            logger.info("Initializing Nari DIA 1.6B model from nari-labs/Dia-1.6B...")
            # Use compute_dtype="float16" for potentially better performance/memory on GPU.
            # Ensure you have a GPU with ~10GB VRAM for this.
            self.model = Dia.from_pretrained("nari-labs/Dia-1.6B", compute_dtype="float16")
            if torch.cuda.is_available():
                # NOTE(review): the Dia wrapper is not necessarily an nn.Module
                # and may not expose ``.to``; in that case it is expected to
                # have selected its device inside ``from_pretrained`` -- confirm
                # against the installed Dia version. Calling a missing ``.to``
                # would otherwise abort initialization on every GPU host.
                move_to = getattr(self.model, "to", None)
                if callable(move_to):
                    move_to("cuda")
                    logger.info("Nari DIA model moved to GPU (CUDA).")
                else:
                    logger.info("Nari DIA model exposes no .to(); keeping the device chosen by from_pretrained.")
            else:
                logger.warning("CUDA not available. Nari DIA model will run on CPU, which is not officially supported and will be very slow.")
            logger.info("Nari DIA model initialized successfully.")
        except Exception as e:
            logger.error("Failed to initialize Nari DIA model: %s", e, exc_info=True)
            self.model = None

    @staticmethod
    def _to_waveform(audio):
        """Convert raw model output into a NumPy waveform that can be written to disk.

        Accepts either a ``torch.Tensor`` or anything ``numpy.asarray`` can
        consume (e.g. an ``ndarray``). Singleton dimensions are squeezed away
        and half-precision data is upcast to float32.

        Raises:
            ValueError: if the model returned nothing or an empty waveform.
        """
        if audio is None:
            raise ValueError("Nari DIA returned no audio.")
        if isinstance(audio, torch.Tensor):
            audio = audio.detach().cpu()
            if audio.is_floating_point():
                # Also covers dtypes NumPy cannot represent (e.g. bfloat16).
                audio = audio.float()
            audio = audio.numpy()
        waveform = np.atleast_1d(np.asarray(audio).squeeze())
        if waveform.size == 0:
            raise ValueError("Nari DIA returned an empty waveform.")
        if waveform.dtype == np.float16:
            # soundfile cannot write float16 data.
            waveform = waveform.astype(np.float32)
        return waveform

    def synthesize_segment(
        self,
        text: str,
        speaker: str,  # This will be 'S1' or 'S2' from segmenter
        output_path: str
    ) -> Optional[str]:
        """
        Synthesize speech for a text segment using Nari DIA.

        Args:
            text: Text to synthesize
            speaker: Speaker identifier ('S1' or 'S2' expected from segmenter);
                any other value falls back to 'S1' with a warning
            output_path: Path to save the audio file

        Returns:
            Path to the generated audio file, or None if the model is not
            loaded, the text is empty, or synthesis/writing failed
        """
        if self.model is None:
            logger.error("Nari DIA model not initialized. Cannot synthesize speech.")
            return None
        if not text or not text.strip():
            # Nothing to speak; avoid sending a bare speaker tag to the model.
            logger.warning("Empty text segment for speaker '%s'; nothing to synthesize.", speaker)
            return None
        try:
            # Nari DIA expects [S1] or [S2] tags.
            # The segmenter is directly outputting "S1" or "S2", so it only
            # needs to be wrapped in brackets.
            if speaker in ("S1", "S2"):
                dia_speaker_tag = f"[{speaker}]"
            else:
                # Fallback in case segmenter outputs something unexpected
                logger.warning("Unexpected speaker tag '%s' from segmenter. Defaulting to [S1].", speaker)
                dia_speaker_tag = "[S1]"
            # Nari DIA expects the speaker tag at the beginning of the segment
            full_text_input = f"{dia_speaker_tag} {text}"
            logger.info("Synthesizing with Nari DIA: %s...", full_text_input[:100])  # Log beginning of text
            # The Dia class handles tokenization internally, so the tagged text
            # is passed straight to generate().
            with torch.no_grad():
                raw_audio = self.model.generate(full_text_input)
            # NOTE(review): upstream examples treat the result of generate() as
            # a NumPy array; a torch tensor is accepted as well so either
            # return type works -- confirm against the installed Dia version.
            audio_waveform = self._to_waveform(raw_audio)
            # Prefer a rate advertised by the model object, otherwise use the
            # class-level default.
            sampling_rate = getattr(self.model, 'sampling_rate', self.DEFAULT_SAMPLE_RATE)
            # Save as WAV file
            sf.write(output_path, audio_waveform, sampling_rate)
            logger.info(
                "Generated audio for %s (%s): %d characters to %s",
                speaker, dia_speaker_tag, len(text), output_path,
            )
            return output_path
        except Exception as e:
            # exc_info to print full traceback
            logger.error("Failed to synthesize segment with Nari DIA: %s", e, exc_info=True)
            return None
|