# tts_gallery/src/dia_tts.py
# Author: Michael Hu
# refactor: remove DiaTTS integration and related UI elements (commit 3d5e706)
"""
Dia TTS model integration for TTS Gallery
Based on: https://github.com/nari-labs/dia/blob/main/hf.py
"""
import tempfile
import torch
import soundfile as sf
# from transformers import AutoProcessor, DiaForConditionalGeneration
# class DiaTTS:
# """
# Wrapper for the Dia TTS model from Nari Labs
# """
# def __init__(self, model_checkpoint="nari-labs/Dia-1.6B"):
# """
# Initialize the Dia TTS model
#
# Args:
# model_checkpoint (str): HuggingFace model checkpoint to use
# """
# self.model_checkpoint = model_checkpoint
# self.device = "cuda" if torch.cuda.is_available() else "cpu"
#
# # Load processor and model
# self.processor = AutoProcessor.from_pretrained(model_checkpoint)
# self.model = DiaForConditionalGeneration.from_pretrained(model_checkpoint).to(self.device)
#
# # Default generation parameters
# self.generation_params = {
# "max_new_tokens": 3072,
# "guidance_scale": 3.0,
# "temperature": 1.8,
# "top_p": 0.90,
# "top_k": 45
# }
#
# def generate(self, text, audio_prompt=None):
# """
# Generate speech from text using Dia
#
# Args:
# text (str): Text to convert to speech. Should use [S1] and [S2] tags for dialogue.
# audio_prompt (str, optional): Path to reference audio file for voice cloning
#
# Returns:
# numpy.ndarray: Generated audio as a numpy array
# int: Sample rate (44100)
# """
# # Format text with speaker tags if not already present
# if not text.startswith("[S1]") and not text.startswith("[S2]"):
# text = f"[S1] {text}"
#
# # Prepare inputs
# inputs = self.processor(text=[text], padding=True, return_tensors="pt").to(self.device)
#
# # Generate audio
# outputs = self.model.generate(**inputs, **self.generation_params)
#
# # Decode outputs
# audio_data = self.processor.batch_decode(outputs)
#
# # Return audio data (assuming it's a numpy array) and sample rate
# return audio_data[0], 44100 # Dia uses 44.1kHz sample rate
#
# def generate_to_file(self, text, audio_prompt=None):
# """
# Generate speech from text and save to a temporary file
#
# Args:
# text (str): Text to convert to speech
# audio_prompt (str, optional): Path to reference audio file for voice cloning
#
# Returns:
# str: Path to the generated audio file
# """
# audio_data, sample_rate = self.generate(text, audio_prompt)
#
# # Save to a temporary file
# with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as tmp_file:
# sf.write(tmp_file.name, audio_data, sample_rate)
# return tmp_file.name