tts_gallery / src /models /tts /dia_model.py
Michael Hu
refactor: replace inline model definitions with ModelFactory and remove unused imports
ef4db28
import tempfile
import os
from ..base import TTSModel
class DiaTTSModel(TTSModel):
    """Dia TTS model implementation.

    Thin adapter that exposes ``src.dia_tts.DiaTTS`` through the ``TTSModel``
    interface. The backend object is created lazily, either by an explicit
    call to :meth:`initialize` or on the first :meth:`generate_speech` call.
    """

    def __init__(self):
        # Backend instance; stays None until initialize() succeeds.
        self._model = None
        self._initialized = False
        # Exception raised by the most recent failed initialize() call, kept
        # so generate_speech() can report the real cause to its caller.
        self._init_error = None

    @property
    def name(self) -> str:
        """Identifier of the model exposed by this wrapper."""
        return "nari-labs/Dia-1.6B"

    @property
    def description(self) -> str:
        """One-line, human-readable summary of the model."""
        return "Ultra-realistic dialogue generation with support for voice cloning and non-verbal expressions"

    def initialize(self) -> bool:
        """Initialize the Dia model.

        Safe to call repeatedly: once the backend has been created, later
        calls return immediately. A failed attempt is not cached, so the next
        call tries again.

        Returns:
            bool: True if the backend is ready, False if creating it failed.
            On failure the error is printed and remembered so that
            generate_speech() can chain it to the exception it raises.
        """
        if self._initialized:
            return True
        try:
            # Import here to avoid circular imports
            from src.dia_tts import DiaTTS

            self._model = DiaTTS()
        except Exception as e:
            # Best-effort contract: report the problem and signal failure via
            # the return value instead of propagating the exception.
            self._init_error = e
            print(f"Error initializing Dia model: {e}")
            return False
        self._init_error = None
        self._initialized = True
        return True

    def generate_speech(self, text, audio_prompt=None, **kwargs):
        """
        Generate speech from text using Dia TTS

        Args:
            text (str): Text to convert to speech
            audio_prompt (str, optional): Path to reference audio file for
                voice cloning; forwarded to the backend as ``reference_audio``
            **kwargs: Additional parameters for generation, passed through
                unchanged to ``DiaTTS.generate``

        Returns:
            str: Path to the generated audio file (the value returned by
            ``DiaTTS.generate``)

        Raises:
            RuntimeError: If the model could not be initialized. The original
                initialization error, when available, is attached as
                ``__cause__``.
        """
        if not self._initialized and not self.initialize():
            # Chain the stored failure so the traceback shows why the backend
            # could not be created, not just that it could not.
            raise RuntimeError("Failed to initialize Dia model") from self._init_error
        # Generate speech using Dia
        return self._model.generate(text, reference_audio=audio_prompt, **kwargs)

    def supports_voice_cloning(self) -> bool:
        """Return True: this wrapper accepts a reference audio prompt."""
        return True