Michael Hu commited on
Commit
b68dcac
·
1 Parent(s): 8829e6c

feat: add support for Dia-1.6B TTS model in TTS Gallery

Browse files
Files changed (3) hide show
  1. app.py +55 -0
  2. requirements.txt +4 -2
  3. src/dia_tts.py +82 -0
app.py CHANGED
@@ -16,6 +16,7 @@ import wave
16
  import os
17
  from faster_whisper import WhisperModel
18
  from kokoro import KPipeline
 
19
 
20
  # Model descriptions for better understanding
21
  MODEL_DESCRIPTIONS = {
@@ -24,6 +25,7 @@ MODEL_DESCRIPTIONS = {
24
  "piper-tts": "Local on-device TTS with dynamic English and Chinese voice selection from Piper models",
25
  "SYSTRAN/faster-whisper": "Faster Whisper transcription with CTranslate2, up to 4x faster than OpenAI Whisper",
26
  "hexgrad/kokoro": "Lightweight TTS model with 82M parameters, Apache-licensed for production and personal use",
 
27
  }
28
 
29
  # Models dictionary
@@ -33,6 +35,7 @@ MODELS = {
33
  "piper-tts": "Piper (no voice cloning)",
34
  "SYSTRAN/faster-whisper": "Faster Whisper",
35
  "hexgrad/kokoro": "Kokoro-82M",
 
36
  }
37
 
38
  original_torch_load = torch.load
@@ -93,6 +96,18 @@ voices_by_lang = scan_piper_voices()
93
 
94
  # No global piper_voice, load dynamically
95
 
 
 
 
 
 
 
 
 
 
 
 
 
96
  # Initialize Kokoro
97
  def initialize_kokoro():
98
  try:
@@ -235,6 +250,25 @@ def generate_kokoro_speech(text, language_code, voice_name):
235
  except Exception as e:
236
  return None, f"Error synthesizing speech: {str(e)}"
237
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
238
  def generate_piper_speech(text, lang, voice):
239
  """
240
  Generate speech from text using Piper TTS with selected voice
@@ -398,6 +432,20 @@ with gr.Blocks(css=custom_css, title="🎙️ TTS Model Gallery", theme=gr.theme
398
  piper_language_selection = gr.Radio(
399
  choices=["English", "Chinese"],
400
  value="English",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
401
  label="Language"
402
  )
403
  piper_voice_selection = gr.Dropdown(
@@ -504,6 +552,13 @@ with gr.Blocks(css=custom_css, title="🎙️ TTS Model Gallery", theme=gr.theme
504
  outputs=kittentts_audio_output
505
  )
506
 
 
 
 
 
 
 
 
507
  # Connect the Piper generate button to the function
508
  piper_generate_btn.click(
509
  fn=generate_piper_speech,
 
16
  import os
17
  from faster_whisper import WhisperModel
18
  from kokoro import KPipeline
19
+ from src.dia_tts import DiaTTS
20
 
21
  # Model descriptions for better understanding
22
  MODEL_DESCRIPTIONS = {
 
25
  "piper-tts": "Local on-device TTS with dynamic English and Chinese voice selection from Piper models",
26
  "SYSTRAN/faster-whisper": "Faster Whisper transcription with CTranslate2, up to 4x faster than OpenAI Whisper",
27
  "hexgrad/kokoro": "Lightweight TTS model with 82M parameters, Apache-licensed for production and personal use",
28
+ "nari-labs/Dia-1.6B": "Ultra-realistic dialogue generation with support for voice cloning and non-verbal expressions",
29
  }
30
 
31
  # Models dictionary
 
35
  "piper-tts": "Piper (no voice cloning)",
36
  "SYSTRAN/faster-whisper": "Faster Whisper",
37
  "hexgrad/kokoro": "Kokoro-82M",
38
+ "nari-labs/Dia-1.6B": "Dia TTS",
39
  }
40
 
41
  original_torch_load = torch.load
 
96
 
97
  # No global piper_voice, load dynamically
98
 
99
+ # Initialize Dia model
100
+ dia_model = None
101
+ def initialize_dia():
102
+ global dia_model
103
+ try:
104
+ dia_model = DiaTTS()
105
+ print("Loaded Dia-1.6B model")
106
+ return dia_model
107
+ except Exception as e:
108
+ print(f"Error loading Dia model: {e}")
109
+ return None
110
+
111
  # Initialize Kokoro
112
  def initialize_kokoro():
113
  try:
 
250
  except Exception as e:
251
  return None, f"Error synthesizing speech: {str(e)}"
252
 
253
+ def generate_dia_speech(text, audio_prompt=None):
254
+ """
255
+ Generate speech from text using Dia TTS with optional audio prompt
256
+
257
+ Args:
258
+ text (str): Text to convert to speech
259
+ audio_prompt (str, optional): Path to reference audio file for voice cloning
260
+
261
+ Returns:
262
+ str: Path to the generated audio file
263
+ """
264
+ # Initialize Dia model if not already initialized
265
+ global dia_model
266
+ if dia_model is None:
267
+ dia_model = initialize_dia()
268
+
269
+ # Generate speech using Dia
270
+ return dia_model.generate_to_file(text, audio_prompt)
271
+
272
  def generate_piper_speech(text, lang, voice):
273
  """
274
  Generate speech from text using Piper TTS with selected voice
 
432
  piper_language_selection = gr.Radio(
433
  choices=["English", "Chinese"],
434
  value="English",
435
+
436
+ # Dia TTS UI
437
+ dia_model_info = gr.HTML(create_model_card("nari-labs/Dia-1.6B"))
438
+
439
+ with gr.Row():
440
+ with gr.Column():
441
+ dia_text_format = gr.Markdown("""
442
+ **Tip:** For dialogue, use [S1] and [S2] tags. For non-verbal expressions, use (laughs), (sighs), etc.
443
+ Example: [S1] Hello there! (laughs) [S2] Hi, how are you doing today?
444
+ """)
445
+ dia_generate_btn = gr.Button("Generate Speech with Dia")
446
+
447
+ with gr.Column():
448
+ dia_audio_output = gr.Audio(label="Generated Speech", type="filepath")
449
  label="Language"
450
  )
451
  piper_voice_selection = gr.Dropdown(
 
552
  outputs=kittentts_audio_output
553
  )
554
 
555
+ # Connect the Dia TTS generate button to the function
556
+ dia_generate_btn.click(
557
+ fn=generate_dia_speech,
558
+ inputs=[text_input, audio_prompt],
559
+ outputs=dia_audio_output
560
+ )
561
+
562
  # Connect the Piper generate button to the function
563
  piper_generate_btn.click(
564
  fn=generate_piper_speech,
requirements.txt CHANGED
@@ -5,8 +5,10 @@ torch
5
  soundfile
6
  https://github.com/KittenML/KittenTTS/releases/download/0.1/kittentts-0.1.0-py3-none-any.whl
7
  piper-tts
8
- transformers
9
  accelerate
10
  faster-whisper
11
  librosa
12
- kokoro==0.7.16
 
 
 
5
  soundfile
6
  https://github.com/KittenML/KittenTTS/releases/download/0.1/kittentts-0.1.0-py3-none-any.whl
7
  piper-tts
8
+ transformers>=4.38.0
9
  accelerate
10
  faster-whisper
11
  librosa
12
+ kokoro==0.7.16
13
+ # For Dia TTS model
14
+ git+https://github.com/huggingface/transformers.git
src/dia_tts.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Dia TTS model integration for TTS Gallery
3
+ Based on: https://github.com/nari-labs/dia/blob/main/hf.py
4
+ """
5
+
6
+ import tempfile
7
+ import torch
8
+ import soundfile as sf
9
+ from transformers import AutoProcessor, DiaForConditionalGeneration
10
+
11
+ class DiaTTS:
12
+ """
13
+ Wrapper for the Dia TTS model from Nari Labs
14
+ """
15
+ def __init__(self, model_checkpoint="nari-labs/Dia-1.6B"):
16
+ """
17
+ Initialize the Dia TTS model
18
+
19
+ Args:
20
+ model_checkpoint (str): HuggingFace model checkpoint to use
21
+ """
22
+ self.model_checkpoint = model_checkpoint
23
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
24
+
25
+ # Load processor and model
26
+ self.processor = AutoProcessor.from_pretrained(model_checkpoint)
27
+ self.model = DiaForConditionalGeneration.from_pretrained(model_checkpoint).to(self.device)
28
+
29
+ # Default generation parameters
30
+ self.generation_params = {
31
+ "max_new_tokens": 3072,
32
+ "guidance_scale": 3.0,
33
+ "temperature": 1.8,
34
+ "top_p": 0.90,
35
+ "top_k": 45
36
+ }
37
+
38
+ def generate(self, text, audio_prompt=None):
39
+ """
40
+ Generate speech from text using Dia
41
+
42
+ Args:
43
+ text (str): Text to convert to speech. Should use [S1] and [S2] tags for dialogue.
44
+ audio_prompt (str, optional): Path to reference audio file for voice cloning
45
+
46
+ Returns:
47
+ numpy.ndarray: Generated audio as a numpy array
48
+ int: Sample rate (44100)
49
+ """
50
+ # Format text with speaker tags if not already present
51
+ if not text.startswith("[S1]") and not text.startswith("[S2]"):
52
+ text = f"[S1] {text}"
53
+
54
+ # Prepare inputs
55
+ inputs = self.processor(text=[text], padding=True, return_tensors="pt").to(self.device)
56
+
57
+ # Generate audio
58
+ outputs = self.model.generate(**inputs, **self.generation_params)
59
+
60
+ # Decode outputs
61
+ audio_data = self.processor.batch_decode(outputs)
62
+
63
+ # Return audio data (assuming it's a numpy array) and sample rate
64
+ return audio_data[0], 44100 # Dia uses 44.1kHz sample rate
65
+
66
+ def generate_to_file(self, text, audio_prompt=None):
67
+ """
68
+ Generate speech from text and save to a temporary file
69
+
70
+ Args:
71
+ text (str): Text to convert to speech
72
+ audio_prompt (str, optional): Path to reference audio file for voice cloning
73
+
74
+ Returns:
75
+ str: Path to the generated audio file
76
+ """
77
+ audio_data, sample_rate = self.generate(text, audio_prompt)
78
+
79
+ # Save to a temporary file
80
+ with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as tmp_file:
81
+ sf.write(tmp_file.name, audio_data, sample_rate)
82
+ return tmp_file.name