Steveeeeeeen HF Staff committed on
Commit
3cf0e6f
·
1 Parent(s): 4d768c3

add step audio demo

Browse files
Files changed (1) hide show
  1. app.py +142 -4
app.py CHANGED
@@ -1,7 +1,145 @@
 
 
 
 
1
  import gradio as gr
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- demo.launch()
 
 
1
+ import tempfile
2
+ import traceback
3
+ from pathlib import Path
4
+
5
  import gradio as gr
6
 
7
def save_tmp_audio(audio, cache_dir):
    """Persist raw WAV bytes to a uniquely named file under *cache_dir*.

    Args:
        audio: Audio payload as bytes (assumed already WAV-encoded — the
            ``.wav`` suffix is applied unconditionally).
        cache_dir: Directory to create the file in.

    Returns:
        str: Absolute path of the written file. The file is NOT deleted on
        close (``delete=False``); cleanup is left to Gradio's cache expiry.
    """
    # NamedTemporaryFile does not create missing directories; the default
    # cache dir (/tmp/stepaudio2) may not exist on first run.
    Path(cache_dir).mkdir(parents=True, exist_ok=True)
    with tempfile.NamedTemporaryFile(
        dir=cache_dir, delete=False, suffix=".wav"
    ) as temp_audio:
        temp_audio.write(audio)
    return temp_audio.name
13
+
14
def add_message(chatbot, history, mic, text):
    """Append a user turn (text or recorded audio) to both conversation views.

    ``chatbot`` holds Gradio ``type="messages"`` entries for display;
    ``history`` holds the model-facing transcript (user role is "human").
    Text takes precedence when both inputs are supplied.

    Args:
        chatbot: Gradio chatbot message list (mutated in place).
        history: Model conversation history (mutated in place).
        mic: Filepath of the recorded audio, or None/empty.
        text: Typed message, or None/empty.

    Returns:
        tuple: ``(chatbot, history, error)`` where ``error`` is ``None`` on
        success or a user-facing message for ``gr.Warning``.
    """
    if not mic and not text:
        return chatbot, history, "Input is empty"

    if text:
        chatbot.append({"role": "user", "content": text})
        history.append({"role": "human", "content": text})
    elif Path(mic).exists():
        chatbot.append({"role": "user", "content": {"path": mic}})
        history.append({"role": "human", "content": [{"type": "audio", "audio": mic}]})
    else:
        # Previously a missing audio file fell through silently with no turn
        # added and no error reported; surface it like the empty-input case.
        return chatbot, history, "Audio file not found"

    print(f"{history=}")
    return chatbot, history, None
27
+
28
def reset_state(system_prompt):
    """Clear the visible chat and reseed the model history.

    Args:
        system_prompt: Current contents of the System Prompt textbox.

    Returns:
        tuple: An empty chatbot list and a history containing only the
        system turn built from *system_prompt*.
    """
    fresh_history = [{"role": "system", "content": system_prompt}]
    return [], fresh_history
30
+
31
# NOTE(review): `spaces` is used by this decorator but never imported in this
# file — add `import spaces` (Hugging Face Spaces SDK) or module import fails.
@spaces.GPU
def predict(chatbot, history, audio_model, token2wav, prompt_wav, cache_dir):
    """Run one assistant turn: generate text+audio tokens, decode to a WAV.

    Args:
        chatbot: Gradio message list; an audio reply is appended on success.
        history: Model transcript; an assistant turn is appended and, on
            success, completed with the generated tokens and ``eot=True``.
        audio_model: Callable returning ``(tokens, text, audio)`` for the
            given history.
        token2wav: Callable decoding audio tokens to WAV bytes using
            *prompt_wav* as the voice prompt.
        prompt_wav: Path of the assistant voice prompt wave.
        cache_dir: Directory the reply WAV is written to.

    Returns:
        tuple: ``(chatbot, history)`` — returned unchanged past the partial
        assistant entry if generation fails (a warning is shown instead).
    """
    try:
        # Open an in-progress assistant turn; eot=False until tokens land.
        history.append({"role": "assistant", "content": [{"type": "text", "text": "<tts_start>"}], "eot": False})
        tokens, text, audio = audio_model(history, max_new_tokens=4096, temperature=0.7, repetition_penalty=1.05, do_sample=True)
        print(f"predict {text=}")
        audio = token2wav(audio, prompt_wav)
        audio_path = save_tmp_audio(audio, cache_dir)
        chatbot.append({"role": "assistant", "content": {"path": audio_path}})
        history[-1]["content"].append({"type": "token", "token": tokens})
        history[-1]["eot"] = True
    except Exception:
        # Best-effort UI: log the full traceback, tell the user, keep serving.
        print(traceback.format_exc())
        gr.Warning("Some error happened, please try again.")
    return chatbot, history
46
+
47
def _launch_demo(args, audio_model, token2wav):
    """Build the Gradio Blocks UI for the Step Audio 2 demo and serve it.

    Args:
        args: Parsed CLI namespace; reads ``server_port``, ``server_name``,
            ``prompt_wav`` and ``cache_dir``.
        audio_model: Loaded speech model, invoked by ``predict``.
        token2wav: Token-to-waveform decoder, invoked by ``predict``.
    """
    # delete_cache=(86400, 86400): every 24h, delete cached files older than 24h.
    with gr.Blocks(delete_cache=(86400, 86400)) as demo:
        gr.Markdown("""<center><font size=8>Step Audio 2 Demo</center>""")
        with gr.Row():
            system_prompt = gr.Textbox(
                label="System Prompt",
                value="你的名字叫做小跃,是由阶跃星辰公司训练出来的语音大模型。\n你情感细腻,观察能力强,擅长分析用户的内容,并作出善解人意的回复,说话的过程中时刻注意用户的感受,富有同理心,提供多样的情绪价值。\n今天是2025年8月29日,星期五\n请用默认女声与用户交流。",
                lines=2
            )
        chatbot = gr.Chatbot(
            elem_id="chatbot",
            #avatar_images=["assets/user.png", "assets/assistant.png"],
            min_height=800,
            type="messages",
        )
        # Per-session model transcript, seeded with the textbox's *initial*
        # value; later edits to the textbox only apply after "Clear History"
        # (reset_state reads the live textbox value).
        history = gr.State([{"role": "system", "content": system_prompt.value}])
        mic = gr.Audio(type="filepath")
        text = gr.Textbox(placeholder="Enter message ...")

        with gr.Row():
            clean_btn = gr.Button("🧹 Clear History (清除历史)")
            regen_btn = gr.Button("🤔️ Regenerate (重试)")
            submit_btn = gr.Button("🚀 Submit")

        def on_submit(chatbot, history, mic, text):
            # Append the user turn; on success run the model. The two trailing
            # None outputs clear the mic and text inputs either way.
            chatbot, history, error = add_message(
                chatbot, history, mic, text
            )
            if error:
                gr.Warning(error)  # show the warning message
                return chatbot, history, None, None
            else:
                chatbot, history = predict(chatbot, history, audio_model, token2wav, args.prompt_wav, args.cache_dir)
                return chatbot, history, None, None

        submit_btn.click(
            fn=on_submit,
            inputs=[chatbot, history, mic, text],
            outputs=[chatbot, history, mic, text],
            concurrency_limit=4,
            # Shares one queue with regenerate so GPU work is serialized.
            concurrency_id="gpu_queue",
        )

        clean_btn.click(
            fn=reset_state,
            inputs=[system_prompt],
            outputs=[chatbot, history],
            #show_progress=True,
        )

        def regenerate(chatbot, history):
            # Drop every trailing assistant turn from both views, then
            # re-run predict on the remaining history.
            while chatbot and chatbot[-1]["role"] == "assistant":
                chatbot.pop()
            while history and history[-1]["role"] == "assistant":
                print(f"discard {history[-1]}")
                history.pop()
            return predict(chatbot, history, audio_model, token2wav, args.prompt_wav, args.cache_dir)

        regen_btn.click(
            regenerate,
            [chatbot, history],
            [chatbot, history],
            #show_progress=True,
            concurrency_id="gpu_queue",
        )

    # Blocks on launch() until the server is shut down.
    demo.queue().launch(
        server_port=args.server_port,
        server_name=args.server_name,
    )
117
+
118
+
119
if __name__ == "__main__":
    import os
    from argparse import ArgumentParser

    from stepaudio2 import StepAudio2
    from token2wav import Token2wav

    parser = ArgumentParser()
    parser.add_argument("--model-path", type=str, default='Step-Audio-2-mini', help="Model path.")
    parser.add_argument(
        "--server-port", type=int, default=7860, help="Demo server port."
    )
    parser.add_argument(
        "--server-name", type=str, default="0.0.0.0", help="Demo server name."
    )
    parser.add_argument(
        "--prompt-wav", type=str, default="assets/default_female.wav", help="Prompt wave for the assistant."
    )
    parser.add_argument(
        "--cache-dir", type=str, default="/tmp/stepaudio2", help="Cache directory."
    )
    args = parser.parse_args()

    # The cache dir doubles as Gradio's temp dir and as the target of
    # save_tmp_audio(); create it up front so first-run writes don't fail.
    os.makedirs(args.cache_dir, exist_ok=True)
    os.environ["GRADIO_TEMP_DIR"] = args.cache_dir

    audio_model = StepAudio2(args.model_path)
    token2wav = Token2wav(f"{args.model_path}/token2wav")
    _launch_demo(args, audio_model, token2wav)