"""
Dia TTS model integration for TTS Gallery.

Based on: https://github.com/nari-labs/dia/blob/main/hf.py

NOTE(review): the whole ``DiaTTS`` wrapper in this module is commented out,
so importing the module currently defines nothing beyond its imports.
``tempfile``, ``torch`` and ``soundfile`` are referenced only by the disabled
code. The reason it was disabled is not recorded here.
"""
import tempfile
import torch
import soundfile as sf

# NOTE(review): issues to resolve before re-enabling the code below.
#   * ``audio_prompt`` is accepted by ``generate`` and ``generate_to_file``, and
#     is documented as the reference audio for voice cloning. However,
#     ``generate`` never passes it to the processor or the model, so it is
#     silently ignored.
#   * ``generate`` assumes ``processor.batch_decode`` returns NumPy arrays (see
#     its own comment). If it returns torch tensors, which may be on CUDA,
#     they need ``.cpu().numpy()`` before ``sf.write``. TODO: confirm against
#     the installed transformers version.
#   * The 44100 Hz sample rate is hard-coded rather than read from the
#     processor or model config. TODO: confirm it matches the checkpoint.
#   * ``generate_to_file`` writes a ``.mp3`` file through ``soundfile``, which
#     infers the output format from the file extension. MP3 output depends on
#     the libsndfile build (presumably 1.1.0 or newer). Verify this, or write
#     ``.wav`` instead.

# from transformers import AutoProcessor, DiaForConditionalGeneration

# class DiaTTS:
#     """
#     Wrapper for the Dia TTS model from Nari Labs
#     """
#     def __init__(self, model_checkpoint="nari-labs/Dia-1.6B"):
#         """
#         Initialize the Dia TTS model
#
#         Args:
#             model_checkpoint (str): HuggingFace model checkpoint to use
#         """
#         self.model_checkpoint = model_checkpoint
#         self.device = "cuda" if torch.cuda.is_available() else "cpu"
#
#         # Load processor and model
#         self.processor = AutoProcessor.from_pretrained(model_checkpoint)
#         self.model = DiaForConditionalGeneration.from_pretrained(model_checkpoint).to(self.device)
#
#         # Default generation parameters
#         self.generation_params = {
#             "max_new_tokens": 3072,
#             "guidance_scale": 3.0,
#             "temperature": 1.8,
#             "top_p": 0.90,
#             "top_k": 45
#         }
#
#     def generate(self, text, audio_prompt=None):
#         """
#         Generate speech from text using Dia
#
#         Args:
#             text (str): Text to convert to speech. Should use [S1] and [S2] tags for dialogue.
#             audio_prompt (str, optional): Path to reference audio file for voice cloning
#
#         Returns:
#             numpy.ndarray: Generated audio as a numpy array
#             int: Sample rate (44100)
#         """
#         # Format text with speaker tags if not already present
#         if not text.startswith("[S1]") and not text.startswith("[S2]"):
#             text = f"[S1] {text}"
#
#         # Prepare inputs
#         inputs = self.processor(text=[text], padding=True, return_tensors="pt").to(self.device)
#
#         # Generate audio
#         outputs = self.model.generate(**inputs, **self.generation_params)
#
#         # Decode outputs
#         audio_data = self.processor.batch_decode(outputs)
#
#         # Return audio data (assuming it's a numpy array) and sample rate
#         return audio_data[0], 44100  # Dia uses 44.1kHz sample rate
#
#     def generate_to_file(self, text, audio_prompt=None):
#         """
#         Generate speech from text and save to a temporary file
#
#         Args:
#             text (str): Text to convert to speech
#             audio_prompt (str, optional): Path to reference audio file for voice cloning
#
#         Returns:
#             str: Path to the generated audio file
#         """
#         audio_data, sample_rate = self.generate(text, audio_prompt)
#
#         # Save to a temporary file
#         with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as tmp_file:
#             sf.write(tmp_file.name, audio_data, sample_rate)
#         return tmp_file.name