"""
Dia TTS model integration for TTS Gallery
Based on: https://github.com/nari-labs/dia/blob/main/hf.py
"""
import tempfile

import torch
import soundfile as sf

# from transformers import AutoProcessor, DiaForConditionalGeneration
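# NOTE: the wrapper below depends on Dia support in transformers
# (AutoProcessor + DiaForConditionalGeneration), available in recent releases;
# see the hf.py script linked above for the upstream reference.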
# class DiaTTS:
#     """
#     Wrapper for the Dia TTS model from Nari Labs
#     """
#     def __init__(self, model_checkpoint="nari-labs/Dia-1.6B"):
#         """
#         Initialize the Dia TTS model
#
#         Args:
#             model_checkpoint (str): HuggingFace model checkpoint to use
#         """
#         self.model_checkpoint = model_checkpoint
#         self.device = "cuda" if torch.cuda.is_available() else "cpu"
#
#         # Load processor and model
#         self.processor = AutoProcessor.from_pretrained(model_checkpoint)
#         self.model = DiaForConditionalGeneration.from_pretrained(model_checkpoint).to(self.device)
#
#         # Default generation parameters
#         self.generation_params = {
#             "max_new_tokens": 3072,
#             "guidance_scale": 3.0,
#             "temperature": 1.8,
#             "top_p": 0.90,
#             "top_k": 45
#         }
#
#     def generate(self, text, audio_prompt=None):
#         """
#         Generate speech from text using Dia
#
#         Args:
#             text (str): Text to convert to speech. Should use [S1] and [S2] tags for dialogue.
#             audio_prompt (str, optional): Path to a reference audio file for voice cloning
#                 (accepted for API compatibility, but not yet used by this implementation)
#
#         Returns:
#             numpy.ndarray: Generated audio as a numpy array
#             int: Sample rate (44100)
#         """
#         # Format text with speaker tags if not already present
#         if not text.startswith("[S1]") and not text.startswith("[S2]"):
#             text = f"[S1] {text}"
#
#         # Prepare inputs
#         inputs = self.processor(text=[text], padding=True, return_tensors="pt").to(self.device)
#
#         # Generate audio
#         outputs = self.model.generate(**inputs, **self.generation_params)
#
#         # Decode outputs
#         audio_data = self.processor.batch_decode(outputs)
#
#         # Return audio data (assuming it's a numpy array) and sample rate
#         return audio_data[0], 44100  # Dia uses a 44.1 kHz sample rate
#
#     def generate_to_file(self, text, audio_prompt=None):
#         """
#         Generate speech from text and save it to a temporary file
#
#         Args:
#             text (str): Text to convert to speech
#             audio_prompt (str, optional): Path to a reference audio file for voice cloning
#                 (passed through to generate(), which does not use it yet)
#
#         Returns:
#             str: Path to the generated audio file
#         """
#         audio_data, sample_rate = self.generate(text, audio_prompt)
#
#         # Reserve a temporary file path, then write the audio with soundfile.
#         # Writing after the handle is closed avoids reopen issues on Windows, and
#         # WAV is used because soundfile supports it everywhere (MP3 needs libsndfile >= 1.1.0).
#         with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
#             output_path = tmp_file.name
#         sf.write(output_path, audio_data, sample_rate)
#         return output_path
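#
# A minimal usage sketch, kept commented out like the class above since the
# transformers-based integration is not enabled in this Space. The dialogue text
# and the printed message are illustrative examples, not part of the upstream hf.py.
#
# if __name__ == "__main__":
#     tts = DiaTTS(model_checkpoint="nari-labs/Dia-1.6B")
#
#     # Dia expects [S1]/[S2] speaker tags for multi-speaker dialogue.
#     script = "[S1] Welcome to the TTS Gallery. [S2] Dia renders both of our voices."
#
#     # Option 1: get the raw audio array and sample rate.
#     audio, sample_rate = tts.generate(script)
#
#     # Option 2: write the audio to a temporary file and reuse the path elsewhere.
#     audio_path = tts.generate_to_file(script)
#     print(f"Generated audio written to {audio_path}")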