Michael Hu commited on
Commit
3d5e706
·
1 Parent(s): 51e5e89

refactor: remove DiaTTS integration and related UI elements

Browse files
Files changed (2) hide show
  1. app.py +47 -47
  2. src/dia_tts.py +73 -73
app.py CHANGED
@@ -16,7 +16,7 @@ import wave
16
  import os
17
  from faster_whisper import WhisperModel
18
  from kokoro import KPipeline
19
- from src.dia_tts import DiaTTS
20
 
21
  # Model descriptions for better understanding
22
  MODEL_DESCRIPTIONS = {
@@ -97,16 +97,16 @@ voices_by_lang = scan_piper_voices()
97
  # No global piper_voice, load dynamically
98
 
99
  # Initialize Dia model
100
- dia_model = None
101
- def initialize_dia():
102
- global dia_model
103
- try:
104
- dia_model = DiaTTS()
105
- print("Loaded Dia-1.6B model")
106
- return dia_model
107
- except Exception as e:
108
- print(f"Error loading Dia model: {e}")
109
- return None
110
 
111
  # Initialize Kokoro
112
  def initialize_kokoro():
@@ -250,24 +250,24 @@ def generate_kokoro_speech(text, language_code, voice_name):
250
  except Exception as e:
251
  return None, f"Error synthesizing speech: {str(e)}"
252
 
253
- def generate_dia_speech(text, audio_prompt=None):
254
- """
255
- Generate speech from text using Dia TTS with optional audio prompt
256
-
257
- Args:
258
- text (str): Text to convert to speech
259
- audio_prompt (str, optional): Path to reference audio file for voice cloning
260
-
261
- Returns:
262
- str: Path to the generated audio file
263
- """
264
- # Initialize Dia model if not already initialized
265
- global dia_model
266
- if dia_model is None:
267
- dia_model = initialize_dia()
268
-
269
- # Generate speech using Dia
270
- return dia_model.generate_to_file(text, audio_prompt)
271
 
272
  def generate_piper_speech(text, lang, voice):
273
  """
@@ -445,19 +445,19 @@ with gr.Blocks(css=custom_css, title="🎙️ TTS Model Gallery", theme=gr.theme
445
  piper_audio_output = gr.Audio(label="Generated Speech", type="filepath")
446
  piper_status = gr.Textbox(label="Status", interactive=False)
447
 
448
- # Dia TTS UI
449
- dia_model_info = gr.HTML(create_model_card("nari-labs/Dia-1.6B"))
450
 
451
- with gr.Row():
452
- with gr.Column():
453
- dia_text_format = gr.Markdown("""
454
- **Tip:** For dialogue, use [S1] and [S2] tags. For non-verbal expressions, use (laughs), (sighs), etc.
455
- Example: [S1] Hello there! (laughs) [S2] Hi, how are you doing today?
456
- """)
457
- dia_generate_btn = gr.Button("Generate Speech with Dia")
458
-
459
- with gr.Column():
460
- dia_audio_output = gr.Audio(label="Generated Speech", type="filepath")
461
 
462
  # Faster Whisper section
463
  whisper_model_info = gr.HTML(create_model_card("SYSTRAN/faster-whisper"))
@@ -552,12 +552,12 @@ with gr.Blocks(css=custom_css, title="🎙️ TTS Model Gallery", theme=gr.theme
552
  outputs=kittentts_audio_output
553
  )
554
 
555
- # Connect the Dia TTS generate button to the function
556
- dia_generate_btn.click(
557
- fn=generate_dia_speech,
558
- inputs=[text_input, audio_prompt],
559
- outputs=dia_audio_output
560
- )
561
 
562
  # Connect the Piper generate button to the function
563
  piper_generate_btn.click(
 
16
  import os
17
  from faster_whisper import WhisperModel
18
  from kokoro import KPipeline
19
+ # from src.dia_tts import DiaTTS
20
 
21
  # Model descriptions for better understanding
22
  MODEL_DESCRIPTIONS = {
 
97
  # No global piper_voice, load dynamically
98
 
99
  # Initialize Dia model
100
+ # dia_model = None
101
+ # def initialize_dia():
102
+ # global dia_model
103
+ # try:
104
+ # dia_model = DiaTTS()
105
+ # print("Loaded Dia-1.6B model")
106
+ # return dia_model
107
+ # except Exception as e:
108
+ # print(f"Error loading Dia model: {e}")
109
+ # return None
110
 
111
  # Initialize Kokoro
112
  def initialize_kokoro():
 
250
  except Exception as e:
251
  return None, f"Error synthesizing speech: {str(e)}"
252
 
253
+ # def generate_dia_speech(text, audio_prompt=None):
254
+ # """
255
+ # Generate speech from text using Dia TTS with optional audio prompt
256
+ #
257
+ # Args:
258
+ # text (str): Text to convert to speech
259
+ # audio_prompt (str, optional): Path to reference audio file for voice cloning
260
+ #
261
+ # Returns:
262
+ # str: Path to the generated audio file
263
+ # """
264
+ # # Initialize Dia model if not already initialized
265
+ # global dia_model
266
+ # if dia_model is None:
267
+ # dia_model = initialize_dia()
268
+ #
269
+ # # Generate speech using Dia
270
+ # return dia_model.generate_to_file(text, audio_prompt)
271
 
272
  def generate_piper_speech(text, lang, voice):
273
  """
 
445
  piper_audio_output = gr.Audio(label="Generated Speech", type="filepath")
446
  piper_status = gr.Textbox(label="Status", interactive=False)
447
 
448
+ # Dia TTS UI (commented out for now)
449
+ # dia_model_info = gr.HTML(create_model_card("nari-labs/Dia-1.6B"))
450
 
451
+ # with gr.Row():
452
+ # with gr.Column():
453
+ # dia_text_format = gr.Markdown("""
454
+ # **Tip:** For dialogue, use [S1] and [S2] tags. For non-verbal expressions, use (laughs), (sighs), etc.
455
+ # Example: [S1] Hello there! (laughs) [S2] Hi, how are you doing today?
456
+ # """)
457
+ # dia_generate_btn = gr.Button("Generate Speech with Dia")
458
+ #
459
+ # with gr.Column():
460
+ # dia_audio_output = gr.Audio(label="Generated Speech", type="filepath")
461
 
462
  # Faster Whisper section
463
  whisper_model_info = gr.HTML(create_model_card("SYSTRAN/faster-whisper"))
 
552
  outputs=kittentts_audio_output
553
  )
554
 
555
+ # Connect the Dia TTS generate button to the function (commented out for now)
556
+ # dia_generate_btn.click(
557
+ # fn=generate_dia_speech,
558
+ # inputs=[text_input, audio_prompt],
559
+ # outputs=dia_audio_output
560
+ # )
561
 
562
  # Connect the Piper generate button to the function
563
  piper_generate_btn.click(
src/dia_tts.py CHANGED
@@ -6,77 +6,77 @@ Based on: https://github.com/nari-labs/dia/blob/main/hf.py
6
  import tempfile
7
  import torch
8
  import soundfile as sf
9
- from transformers import AutoProcessor, DiaForConditionalGeneration
10
 
11
- class DiaTTS:
12
- """
13
- Wrapper for the Dia TTS model from Nari Labs
14
- """
15
- def __init__(self, model_checkpoint="nari-labs/Dia-1.6B"):
16
- """
17
- Initialize the Dia TTS model
18
-
19
- Args:
20
- model_checkpoint (str): HuggingFace model checkpoint to use
21
- """
22
- self.model_checkpoint = model_checkpoint
23
- self.device = "cuda" if torch.cuda.is_available() else "cpu"
24
-
25
- # Load processor and model
26
- self.processor = AutoProcessor.from_pretrained(model_checkpoint)
27
- self.model = DiaForConditionalGeneration.from_pretrained(model_checkpoint).to(self.device)
28
-
29
- # Default generation parameters
30
- self.generation_params = {
31
- "max_new_tokens": 3072,
32
- "guidance_scale": 3.0,
33
- "temperature": 1.8,
34
- "top_p": 0.90,
35
- "top_k": 45
36
- }
37
-
38
- def generate(self, text, audio_prompt=None):
39
- """
40
- Generate speech from text using Dia
41
-
42
- Args:
43
- text (str): Text to convert to speech. Should use [S1] and [S2] tags for dialogue.
44
- audio_prompt (str, optional): Path to reference audio file for voice cloning
45
-
46
- Returns:
47
- numpy.ndarray: Generated audio as a numpy array
48
- int: Sample rate (44100)
49
- """
50
- # Format text with speaker tags if not already present
51
- if not text.startswith("[S1]") and not text.startswith("[S2]"):
52
- text = f"[S1] {text}"
53
-
54
- # Prepare inputs
55
- inputs = self.processor(text=[text], padding=True, return_tensors="pt").to(self.device)
56
-
57
- # Generate audio
58
- outputs = self.model.generate(**inputs, **self.generation_params)
59
-
60
- # Decode outputs
61
- audio_data = self.processor.batch_decode(outputs)
62
-
63
- # Return audio data (assuming it's a numpy array) and sample rate
64
- return audio_data[0], 44100 # Dia uses 44.1kHz sample rate
65
-
66
- def generate_to_file(self, text, audio_prompt=None):
67
- """
68
- Generate speech from text and save to a temporary file
69
-
70
- Args:
71
- text (str): Text to convert to speech
72
- audio_prompt (str, optional): Path to reference audio file for voice cloning
73
-
74
- Returns:
75
- str: Path to the generated audio file
76
- """
77
- audio_data, sample_rate = self.generate(text, audio_prompt)
78
-
79
- # Save to a temporary file
80
- with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as tmp_file:
81
- sf.write(tmp_file.name, audio_data, sample_rate)
82
- return tmp_file.name
 
6
  import tempfile
7
  import torch
8
  import soundfile as sf
9
+ # from transformers import AutoProcessor, DiaForConditionalGeneration
10
 
11
+ # class DiaTTS:
12
+ # """
13
+ # Wrapper for the Dia TTS model from Nari Labs
14
+ # """
15
+ # def __init__(self, model_checkpoint="nari-labs/Dia-1.6B"):
16
+ # """
17
+ # Initialize the Dia TTS model
18
+ #
19
+ # Args:
20
+ # model_checkpoint (str): HuggingFace model checkpoint to use
21
+ # """
22
+ # self.model_checkpoint = model_checkpoint
23
+ # self.device = "cuda" if torch.cuda.is_available() else "cpu"
24
+ #
25
+ # # Load processor and model
26
+ # self.processor = AutoProcessor.from_pretrained(model_checkpoint)
27
+ # self.model = DiaForConditionalGeneration.from_pretrained(model_checkpoint).to(self.device)
28
+ #
29
+ # # Default generation parameters
30
+ # self.generation_params = {
31
+ # "max_new_tokens": 3072,
32
+ # "guidance_scale": 3.0,
33
+ # "temperature": 1.8,
34
+ # "top_p": 0.90,
35
+ # "top_k": 45
36
+ # }
37
+ #
38
+ # def generate(self, text, audio_prompt=None):
39
+ # """
40
+ # Generate speech from text using Dia
41
+ #
42
+ # Args:
43
+ # text (str): Text to convert to speech. Should use [S1] and [S2] tags for dialogue.
44
+ # audio_prompt (str, optional): Path to reference audio file for voice cloning
45
+ #
46
+ # Returns:
47
+ # numpy.ndarray: Generated audio as a numpy array
48
+ # int: Sample rate (44100)
49
+ # """
50
+ # # Format text with speaker tags if not already present
51
+ # if not text.startswith("[S1]") and not text.startswith("[S2]"):
52
+ # text = f"[S1] {text}"
53
+ #
54
+ # # Prepare inputs
55
+ # inputs = self.processor(text=[text], padding=True, return_tensors="pt").to(self.device)
56
+ #
57
+ # # Generate audio
58
+ # outputs = self.model.generate(**inputs, **self.generation_params)
59
+ #
60
+ # # Decode outputs
61
+ # audio_data = self.processor.batch_decode(outputs)
62
+ #
63
+ # # Return audio data (assuming it's a numpy array) and sample rate
64
+ # return audio_data[0], 44100 # Dia uses 44.1kHz sample rate
65
+ #
66
+ # def generate_to_file(self, text, audio_prompt=None):
67
+ # """
68
+ # Generate speech from text and save to a temporary file
69
+ #
70
+ # Args:
71
+ # text (str): Text to convert to speech
72
+ # audio_prompt (str, optional): Path to reference audio file for voice cloning
73
+ #
74
+ # Returns:
75
+ # str: Path to the generated audio file
76
+ # """
77
+ # audio_data, sample_rate = self.generate(text, audio_prompt)
78
+ #
79
+ # # Save to a temporary file
80
+ # with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as tmp_file:
81
+ # sf.write(tmp_file.name, audio_data, sample_rate)
82
+ # return tmp_file.name