malak2025 committed
Commit 4652370 · verified · Parent: 10e578f

Upload 2 files

Add app.py and requirements.txt.

Files changed (2)
  1. app.py +316 -0
  2. requirements.txt +37 -0
app.py ADDED
@@ -0,0 +1,316 @@
import os
import queue
from huggingface_hub import snapshot_download
import numpy as np
import wave
import io
import gc
from typing import Callable

# Download model checkpoints if they are not already present
os.makedirs("checkpoints", exist_ok=True)
snapshot_download(repo_id="fishaudio/openaudio-s1-mini", local_dir="./checkpoints/openaudio-s1-mini")

print("All checkpoints downloaded")

import html
import os
from argparse import ArgumentParser
from pathlib import Path

import gradio as gr
import torch
import torchaudio

torchaudio.set_audio_backend("soundfile")

from loguru import logger
from fish_speech.i18n import i18n
from fish_speech.inference_engine import TTSInferenceEngine
from fish_speech.models.dac.inference import load_model as load_decoder_model
from fish_speech.models.text2semantic.inference import launch_thread_safe_queue
from tools.webui.inference import get_inference_wrapper
from fish_speech.utils.schema import ServeTTSRequest

# Make einx happy
os.environ["EINX_FILTER_TRACEBACK"] = "false"


HEADER_MD = """# OpenAudio S1

## The demo in this Space is OpenAudio S1; please check [Fish Audio](https://fish.audio) for the best model.
## 该 Demo 为 OpenAudio S1 版本, 请在 [Fish Audio](https://fish.audio) 体验最新 DEMO.

A text-to-speech model based on DAC & Qwen3 developed by [Fish Audio](https://fish.audio).
由 [Fish Audio](https://fish.audio) 研发的 DAC & Qwen3 多语种语音合成.

You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/openaudio-s1-mini).
你可以在 [这里](https://github.com/fishaudio/fish-speech) 找到源代码和 [这里](https://huggingface.co/fishaudio/openaudio-s1-mini) 找到模型.

Related code and weights are released under the CC BY-NC-SA 4.0 License.
相关代码,权重使用 CC BY-NC-SA 4.0 许可证发布.

We are not responsible for any misuse of the model; please consider your local laws and regulations before using it.
我们不对模型的任何滥用负责,请在使用之前考虑您当地的法律法规.

The model running in this WebUI is OpenAudio S1 Mini.
在此 WebUI 中运行的模型是 OpenAudio S1 Mini.
"""

TEXTBOX_PLACEHOLDER = """Put your text here. 在此处输入文本."""

# Use the Hugging Face Spaces GPU decorator when available; fall back to a no-op wrapper locally.
try:
    import spaces

    GPU_DECORATOR = spaces.GPU
except ImportError:

    def GPU_DECORATOR(func):
        def wrapper(*args, **kwargs):
            return func(*args, **kwargs)

        return wrapper


def build_html_error_message(error):
    return f"""
    <div style="color: red;
    font-weight: bold;">
        {html.escape(str(error))}
    </div>
    """


# Build a standalone WAV header so raw PCM chunks can later be streamed as a WAV file.
def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
    buffer = io.BytesIO()

    with wave.open(buffer, "wb") as wav_file:
        wav_file.setnchannels(channels)
        wav_file.setsampwidth(bit_depth // 8)
        wav_file.setframerate(sample_rate)

    wav_header_bytes = buffer.getvalue()
    buffer.close()
    return wav_header_bytes


def build_app(inference_fct: Callable, theme: str = "light") -> gr.Blocks:
    with gr.Blocks(theme=gr.themes.Base()) as app:
        gr.Markdown(HEADER_MD)

        # Use light theme by default
        app.load(
            None,
            None,
            js="() => {const params = new URLSearchParams(window.location.search);if (!params.has('__theme')) {params.set('__theme', '%s');window.location.search = params.toString();}}"
            % theme,
        )

        # Inference
        with gr.Row():
            with gr.Column(scale=3):
                text = gr.Textbox(
                    label=i18n("Input Text"), placeholder=TEXTBOX_PLACEHOLDER, lines=10
                )

                with gr.Row():
                    with gr.Column():
                        with gr.Tab(label=i18n("Advanced Config")):
                            with gr.Row():
                                chunk_length = gr.Slider(
                                    label=i18n("Iterative Prompt Length, 0 means off"),
                                    minimum=0,
                                    maximum=500,
                                    value=0,
                                    step=8,
                                )

                                max_new_tokens = gr.Slider(
                                    label=i18n(
                                        "Maximum tokens per batch, 0 means no limit"
                                    ),
                                    minimum=0,
                                    maximum=2048,
                                    value=0,
                                    step=8,
                                )

                            with gr.Row():
                                top_p = gr.Slider(
                                    label="Top-P",
                                    minimum=0.7,
                                    maximum=0.95,
                                    value=0.9,
                                    step=0.01,
                                )

                                repetition_penalty = gr.Slider(
                                    label=i18n("Repetition Penalty"),
                                    minimum=1,
                                    maximum=1.2,
                                    value=1.1,
                                    step=0.01,
                                )

                            with gr.Row():
                                temperature = gr.Slider(
                                    label="Temperature",
                                    minimum=0.7,
                                    maximum=1.0,
                                    value=0.9,
                                    step=0.01,
                                )
                                seed = gr.Number(
                                    label="Seed",
                                    info="0 means randomized inference, otherwise deterministic",
                                    value=0,
                                )

                        with gr.Tab(label=i18n("Reference Audio")):
                            with gr.Row():
                                gr.Markdown(
                                    i18n(
                                        "5 to 10 seconds of reference audio, useful for specifying speaker."
                                    )
                                )
                            with gr.Row():
                                reference_id = gr.Textbox(
                                    label=i18n("Reference ID"),
                                    placeholder="Leave empty to use uploaded references",
                                )

                            with gr.Row():
                                use_memory_cache = gr.Radio(
                                    label=i18n("Use Memory Cache"),
                                    choices=["on", "off"],
                                    value="on",
                                )

                            with gr.Row():
                                reference_audio = gr.Audio(
                                    label=i18n("Reference Audio"),
                                    type="filepath",
                                )
                            with gr.Row():
                                reference_text = gr.Textbox(
                                    label=i18n("Reference Text"),
                                    lines=1,
                                    placeholder="在一无所知中,梦里的一天结束了,一个新的「轮回」便会开始。",
                                    value="",
                                )

            with gr.Column(scale=3):
                with gr.Row():
                    error = gr.HTML(
                        label=i18n("Error Message"),
                        visible=True,
                    )
                with gr.Row():
                    audio = gr.Audio(
                        label=i18n("Generated Audio"),
                        type="numpy",
                        interactive=False,
                        visible=True,
                    )

                with gr.Row():
                    with gr.Column(scale=3):
                        generate = gr.Button(
                            value="\U0001f3a7 " + i18n("Generate"),
                            variant="primary",
                        )

        # Submit
        generate.click(
            inference_fct,
            [
                text,
                reference_id,
                reference_audio,
                reference_text,
                max_new_tokens,
                chunk_length,
                top_p,
                repetition_penalty,
                temperature,
                seed,
                use_memory_cache,
            ],
            [audio, error],
            concurrency_limit=1,
        )

    return app


def parse_args():
    parser = ArgumentParser()
    parser.add_argument(
        "--llama-checkpoint-path",
        type=Path,
        default="checkpoints/openaudio-s1-mini",
    )
    parser.add_argument(
        "--decoder-checkpoint-path",
        type=Path,
        default="checkpoints/openaudio-s1-mini/codec.pth",
    )
    parser.add_argument("--decoder-config-name", type=str, default="modded_dac_vq")
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--half", action="store_true")
    parser.add_argument("--compile", action="store_true", default=True)
    parser.add_argument("--max-gradio-length", type=int, default=0)
    parser.add_argument("--theme", type=str, default="dark")

    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    args.precision = torch.half if args.half else torch.bfloat16

    logger.info("Loading Llama model...")
    llama_queue = launch_thread_safe_queue(
        checkpoint_path=args.llama_checkpoint_path,
        device=args.device,
        precision=args.precision,
        compile=args.compile,
    )
    logger.info("Llama model loaded, loading VQ-GAN model...")

    decoder_model = load_decoder_model(
        config_name=args.decoder_config_name,
        checkpoint_path=args.decoder_checkpoint_path,
        device=args.device,
    )

    logger.info("Decoder model loaded, warming up...")

    # Create the inference engine
    inference_engine = TTSInferenceEngine(
        llama_queue=llama_queue,
        decoder_model=decoder_model,
        compile=args.compile,
        precision=args.precision,
    )

    # Dry run to check that the model is loaded correctly and to avoid first-time latency
    list(
        inference_engine.inference(
            ServeTTSRequest(
                text="Hello world.",
                references=[],
                reference_id=None,
                max_new_tokens=1024,
                chunk_length=200,
                top_p=0.7,
                repetition_penalty=1.5,
                temperature=0.7,
                format="wav",
            )
        )
    )

    logger.info("Warming up done, launching the web UI...")

    inference_fct = get_inference_wrapper(inference_engine)

    app = build_app(inference_fct, args.theme)
    app.queue(api_open=True).launch(show_error=True, show_api=True)
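
For reference, the engine can also be driven outside the Gradio UI by constructing a ServeTTSRequest directly, mirroring the warm-up call above. The snippet below is a minimal sketch, not part of the commit: the request values are copied from the dry run, and the helper name synthesize is hypothetical.

# Minimal sketch (assumes fish_speech is installed and an engine instance is available);
# it reuses the exact request shape from the startup dry run in app.py.
from fish_speech.utils.schema import ServeTTSRequest


def synthesize(engine, text, reference_id=None):
    request = ServeTTSRequest(
        text=text,
        references=[],
        reference_id=reference_id,
        max_new_tokens=1024,
        chunk_length=200,
        top_p=0.7,
        repetition_penalty=1.5,
        temperature=0.7,
        format="wav",
    )
    # The engine streams results; collect them all, just as the dry run does.
    return list(engine.inference(request))

For example, synthesize(inference_engine, "Hello world.") reproduces the warm-up request performed at startup.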
requirements.txt ADDED
@@ -0,0 +1,37 @@
torch
torchaudio
transformers>=4.35.2
datasets>=2.14.5
lightning>=2.1.0
hydra-core>=1.3.2
tensorboard>=2.14.1
natsort>=8.4.0
einops>=0.7.0
librosa>=0.10.1
rich>=13.5.3
gradio>=4.0.0
wandb>=0.15.11
grpcio>=1.58.0
kui>=1.6.0
zibai-server>=0.9.0
loguru>=0.6.0
loralib>=0.1.2
natsort>=8.4.0
pyrootutils>=1.0.4
descript-audiotools
vector_quantize_pytorch==1.14.24
resampy>=0.4.3
spaces>=0.26.1
einx[torch]==0.2.2
opencc
faster-whisper
ormsgpack
ffmpeg
soundfile
cachetools
funasr
silero-vad
tiktoken
numpy
huggingface_hub
git+https://github.com/descriptinc/descript-audio-codec