NariLabs commited on
Commit
14b2a6c
·
verified ·
1 Parent(s): 698b79c

Delete app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -264
app.py DELETED
@@ -1,264 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import contextlib
4
- import io
5
- import os
6
- from pathlib import Path
7
- from typing import List, Tuple
8
-
9
- import gradio as gr
10
- import torch
11
- import spaces
12
-
13
- from dia2 import Dia2, GenerationConfig, SamplingConfig
14
-
15
- DEFAULT_REPO = os.environ.get("DIA2_DEFAULT_REPO", "nari-labs/Dia2-2B")
16
- MAX_TURNS = 10
17
- INITIAL_TURNS = 2
18
-
19
- _dia: Dia2 | None = None
20
-
21
-
22
- def _get_dia() -> Dia2:
23
- global _dia
24
- if _dia is None:
25
- _dia = Dia2.from_repo(DEFAULT_REPO, device="cuda", dtype="bfloat16")
26
- return _dia
27
-
28
-
29
- def _concat_script(turn_count: int, turn_values: List[str]) -> str:
30
- lines: List[str] = []
31
- for idx in range(min(turn_count, len(turn_values))):
32
- text = (turn_values[idx] or "").strip()
33
- if not text:
34
- continue
35
- speaker = "[S1]" if idx % 2 == 0 else "[S2]"
36
- lines.append(f"{speaker} {text}")
37
- return "\n".join(lines)
38
-
39
-
40
- EXAMPLES: dict[str, dict[str, List[str] | str | None]] = {
41
- "Intro": {
42
- "turns": [
43
- "Hello Dia2 fans! Today we're unveiling the new open TTS model.",
44
- "Sounds exciting. Can you show a sample right now?",
45
- "Absolutely. (laughs) Just press generate.",
46
- ],
47
- "voice_s1": "example_prefix1.wav",
48
- "voice_s2": "example_prefix2.wav",
49
- },
50
- "Customer Support": {
51
- "turns": [
52
- "Thanks for calling. How can I help you today?",
53
- "My parcel never arrived and it's been two weeks.",
54
- "I'm sorry about that. Let me check your tracking number.",
55
- "Appreciate it. I really need that package soon.",
56
- ],
57
- "voice_s1": "example_prefix1.wav",
58
- "voice_s2": "example_prefix2.wav",
59
- },
60
- }
61
-
62
-
63
- def _apply_turn_visibility(count: int) -> List[gr.Update]:
64
- return [gr.update(visible=i < count) for i in range(MAX_TURNS)]
65
-
66
-
67
- def _add_turn(count: int):
68
- count = min(count + 1, MAX_TURNS)
69
- return (count, *_apply_turn_visibility(count))
70
-
71
-
72
- def _remove_turn(count: int):
73
- count = max(1, count - 1)
74
- return (count, *_apply_turn_visibility(count))
75
-
76
-
77
- def _load_example(name: str, count: int):
78
- data = EXAMPLES.get(name)
79
- if not data:
80
- return (count, *_apply_turn_visibility(count), None, None)
81
- turns = data.get("turns", [])
82
- voice_s1_path = data.get("voice_s1")
83
- voice_s2_path = data.get("voice_s2")
84
- new_count = min(len(turns), MAX_TURNS)
85
- updates: List[gr.Update] = []
86
- for idx in range(MAX_TURNS):
87
- if idx < new_count:
88
- updates.append(gr.update(value=turns[idx], visible=True))
89
- else:
90
- updates.append(gr.update(value="", visible=idx < INITIAL_TURNS))
91
- return (new_count, *updates, voice_s1_path, voice_s2_path)
92
-
93
-
94
- def _prepare_prefix(file_path: str | None) -> str | None:
95
- if not file_path:
96
- return None
97
- path = Path(file_path)
98
- if not path.exists():
99
- return None
100
- return str(path)
101
-
102
-
103
- @spaces.GPU(duration=100)
104
- def generate_audio(
105
- turn_count: int,
106
- *inputs,
107
- ):
108
- turn_values = list(inputs[:MAX_TURNS])
109
- voice_s1 = inputs[MAX_TURNS]
110
- voice_s2 = inputs[MAX_TURNS + 1]
111
- cfg_scale = float(inputs[MAX_TURNS + 2])
112
- text_temperature = float(inputs[MAX_TURNS + 3])
113
- audio_temperature = float(inputs[MAX_TURNS + 4])
114
- text_top_k = int(inputs[MAX_TURNS + 5])
115
- audio_top_k = int(inputs[MAX_TURNS + 6])
116
- include_prefix = bool(inputs[MAX_TURNS + 7])
117
-
118
- script = _concat_script(turn_count, turn_values)
119
- if not script.strip():
120
- raise gr.Error("Please enter at least one non-empty speaker turn.")
121
-
122
- dia = _get_dia()
123
- config = GenerationConfig(
124
- cfg_scale=cfg_scale,
125
- text=SamplingConfig(temperature=text_temperature, top_k=text_top_k),
126
- audio=SamplingConfig(temperature=audio_temperature, top_k=audio_top_k),
127
- use_cuda_graph=True,
128
- )
129
- kwargs = {
130
- "prefix_speaker_1": _prepare_prefix(voice_s1),
131
- "prefix_speaker_2": _prepare_prefix(voice_s2),
132
- "include_prefix": include_prefix,
133
- }
134
- buffer = io.StringIO()
135
- with contextlib.redirect_stdout(buffer):
136
- result = dia.generate(
137
- script,
138
- config=config,
139
- output_wav=None,
140
- verbose=True,
141
- **kwargs,
142
- )
143
- waveform = result.waveform.detach().cpu().numpy()
144
- sample_rate = result.sample_rate
145
- timestamps = result.timestamps
146
- log_text = buffer.getvalue().strip()
147
- table = [[w, round(t, 3)] for w, t in timestamps]
148
- return (sample_rate, waveform), table, log_text or "Generation finished."
149
-
150
-
151
- def build_interface() -> gr.Blocks:
152
- with gr.Blocks(
153
- title="Dia2 TTS", css=".compact-turn textarea {min-height: 60px}"
154
- ) as demo:
155
- gr.Markdown(
156
- """## Dia2 — Open TTS Model
157
- Compose dialogue, attach optional voice prompts, and generate audio (CUDA graphs enabled by default)."""
158
- )
159
- turn_state = gr.State(INITIAL_TURNS)
160
- with gr.Row(equal_height=True):
161
- example_dropdown = gr.Dropdown(
162
- choices=["(select example)"] + list(EXAMPLES.keys()),
163
- label="Examples",
164
- value="(select example)",
165
- )
166
- with gr.Row(equal_height=True):
167
- with gr.Column(scale=1):
168
- with gr.Group():
169
- gr.Markdown("### Script")
170
- controls = []
171
- for idx in range(MAX_TURNS):
172
- speaker = "[S1]" if idx % 2 == 0 else "[S2]"
173
- box = gr.Textbox(
174
- label=f"{speaker} turn {idx + 1}",
175
- lines=2,
176
- elem_classes=["compact-turn"],
177
- placeholder=f"Enter dialogue for {speaker}…",
178
- visible=idx < INITIAL_TURNS,
179
- )
180
- controls.append(box)
181
- with gr.Row():
182
- add_btn = gr.Button("Add Turn")
183
- remove_btn = gr.Button("Remove Turn")
184
- with gr.Group():
185
- gr.Markdown("### Voice Prompts")
186
- with gr.Row():
187
- voice_s1 = gr.File(
188
- label="[S1] voice (wav/mp3)", type="filepath"
189
- )
190
- voice_s2 = gr.File(
191
- label="[S2] voice (wav/mp3)", type="filepath"
192
- )
193
- with gr.Group():
194
- gr.Markdown("### Sampling")
195
- cfg_scale = gr.Slider(
196
- 1.0, 8.0, value=6.0, step=0.1, label="CFG Scale"
197
- )
198
- with gr.Group():
199
- gr.Markdown("#### Text Sampling")
200
- text_temperature = gr.Slider(
201
- 0.1, 1.5, value=0.6, step=0.05, label="Text Temperature"
202
- )
203
- text_top_k = gr.Slider(
204
- 1, 200, value=50, step=1, label="Text Top-K"
205
- )
206
- with gr.Group():
207
- gr.Markdown("#### Audio Sampling")
208
- audio_temperature = gr.Slider(
209
- 0.1, 1.5, value=0.8, step=0.05, label="Audio Temperature"
210
- )
211
- audio_top_k = gr.Slider(
212
- 1, 200, value=50, step=1, label="Audio Top-K"
213
- )
214
- include_prefix = gr.Checkbox(
215
- label="Keep prefix audio in output", value=False
216
- )
217
- generate_btn = gr.Button("Generate", variant="primary")
218
- with gr.Column(scale=1):
219
- gr.Markdown("### Output")
220
- audio_out = gr.Audio(label="Waveform", interactive=False)
221
- timestamps = gr.Dataframe(
222
- headers=["word", "seconds"], label="Timestamps"
223
- )
224
- log_box = gr.Textbox(label="Logs", lines=8)
225
-
226
- add_btn.click(
227
- lambda c: _add_turn(c),
228
- inputs=turn_state,
229
- outputs=[turn_state, *controls],
230
- )
231
- remove_btn.click(
232
- lambda c: _remove_turn(c),
233
- inputs=turn_state,
234
- outputs=[turn_state, *controls],
235
- )
236
- example_dropdown.change(
237
- lambda name, c: _load_example(name, c),
238
- inputs=[example_dropdown, turn_state],
239
- outputs=[turn_state, *controls, voice_s1, voice_s2],
240
- )
241
-
242
- generate_btn.click(
243
- generate_audio,
244
- inputs=[
245
- turn_state,
246
- *controls,
247
- voice_s1,
248
- voice_s2,
249
- cfg_scale,
250
- text_temperature,
251
- audio_temperature,
252
- text_top_k,
253
- audio_top_k,
254
- include_prefix,
255
- ],
256
- outputs=[audio_out, timestamps, log_box],
257
- )
258
- return demo
259
-
260
-
261
- if __name__ == "__main__":
262
- app = build_interface()
263
- app.queue(default_concurrency_limit=1)
264
- app.launch(share=True)