Steveeeeeeen HF Staff committed on
Commit
676ffac
·
verified ·
1 Parent(s): e8f2ced

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +107 -72
app.py CHANGED
@@ -1,31 +1,28 @@
1
  import os
2
  import shlex
3
  import subprocess
 
 
 
 
4
 
5
- # install requirements
6
  os.system("pip install -r requirements.txt")
7
- # wget https://huggingface.co/stepfun-ai/Step-Audio-2-mini/blob/main/token2wav/campplus.onnx in token2wav folder
8
  os.system("wget https://huggingface.co/stepfun-ai/Step-Audio-2-mini/resolve/main/token2wav/campplus.onnx -P token2wav")
9
  os.system("wget https://huggingface.co/stepfun-ai/Step-Audio-2-mini/resolve/main/token2wav/flow.pt -P token2wav")
10
  os.system("wget https://huggingface.co/stepfun-ai/Step-Audio-2-mini/resolve/main/token2wav/flow.yaml -P token2wav")
11
  os.system("wget https://huggingface.co/stepfun-ai/Step-Audio-2-mini/resolve/main/token2wav/hift.pt -P token2wav")
12
 
13
-
14
- # get hf token
15
  hf_token = os.getenv("HF_TOKEN", None)
16
  os.environ["HF_TOKEN"] = hf_token
17
 
18
- import tempfile
19
- import traceback
20
- from pathlib import Path
21
  import spaces
22
  import gradio as gr
23
 
24
- def save_tmp_audio(audio, cache_dir):
25
- with tempfile.NamedTemporaryFile(
26
- dir=cache_dir, delete=False, suffix=".wav"
27
- ) as temp_audio:
28
- temp_audio.write(audio)
29
  return temp_audio.name
30
 
31
  def add_message(chatbot, history, mic, text):
@@ -37,68 +34,122 @@ def add_message(chatbot, history, mic, text):
37
  history.append({"role": "human", "content": text})
38
  elif mic and Path(mic).exists():
39
  chatbot.append({"role": "user", "content": {"path": mic}})
40
- history.append({"role": "human", "content": [{"type":"audio", "audio": mic}]})
41
 
42
- print(f"{history=}")
43
  return chatbot, history, None
44
 
45
  def reset_state(system_prompt):
46
  return [], [{"role": "system", "content": system_prompt}]
47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  @spaces.GPU
49
- def predict(chatbot, history, audio_model, token2wav, prompt_wav, cache_dir):
 
 
 
 
50
  try:
51
- history.append({"role": "assistant", "content": [{"type": "text", "text": "<tts_start>"}], "eot": False})
52
- tokens, text, audio = audio_model(history, max_new_tokens=4096, temperature=0.7, repetition_penalty=1.05, do_sample=True)
53
- print(f"predict {text=}")
54
- audio = token2wav(audio, prompt_wav)
55
- audio_path = save_tmp_audio(audio, cache_dir)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  chatbot.append({"role": "assistant", "content": {"path": audio_path}})
 
 
57
  history[-1]["content"].append({"type": "token", "token": tokens})
58
  history[-1]["eot"] = True
 
59
  except Exception:
60
  print(traceback.format_exc())
61
- gr.Warning(f"Some error happend, please try again.")
62
  return chatbot, history
63
 
64
- def _launch_demo(args, audio_model, token2wav):
65
  with gr.Blocks(delete_cache=(86400, 86400)) as demo:
66
- gr.Markdown("""<center><font size=8>Step Audio 2 Demo</center>""")
 
67
  with gr.Row():
68
  system_prompt = gr.Textbox(
69
  label="System Prompt",
70
- value="你的名字叫做小跃,是由阶跃星辰公司训练出来的语音大模型。\n你情感细腻,观察能力强,擅长分析用户的内容,并作出善解人意的回复,说话的过程中时刻注意用户的感受,富有同理心,提供多样的情绪价值。\n今天是2025年8月29日,星期五\n请用默认女声与用户交流。",
 
 
 
 
 
71
  lines=2
72
  )
73
- chatbot = gr.Chatbot(
74
- elem_id="chatbot",
75
- #avatar_images=["assets/user.png", "assets/assistant.png"],
76
- min_height=800,
77
- type="messages",
78
- )
79
  history = gr.State([{"role": "system", "content": system_prompt.value}])
 
 
80
  mic = gr.Audio(type="filepath")
81
  text = gr.Textbox(placeholder="Enter message ...")
82
 
 
 
 
 
 
 
83
  with gr.Row():
84
  clean_btn = gr.Button("🧹 Clear History (清除历史)")
85
  regen_btn = gr.Button("🤔️ Regenerate (重试)")
86
  submit_btn = gr.Button("🚀 Submit")
87
 
88
- def on_submit(chatbot, history, mic, text):
89
- chatbot, history, error = add_message(
90
- chatbot, history, mic, text
91
- )
92
  if error:
93
- gr.Warning(error) # 显示警告消息
94
- return chatbot, history, None, None
95
- else:
96
- chatbot, history = predict(chatbot, history, audio_model, token2wav, args.prompt_wav, args.cache_dir)
97
  return chatbot, history, None, None
 
 
98
 
99
  submit_btn.click(
100
  fn=on_submit,
101
- inputs=[chatbot, history, mic, text],
102
  outputs=[chatbot, history, mic, text],
103
  concurrency_limit=4,
104
  concurrency_id="gpu_queue",
@@ -108,55 +159,39 @@ def _launch_demo(args, audio_model, token2wav):
108
  fn=reset_state,
109
  inputs=[system_prompt],
110
  outputs=[chatbot, history],
111
- #show_progress=True,
112
  )
113
 
114
- def regenerate(chatbot, history):
 
115
  while chatbot and chatbot[-1]["role"] == "assistant":
116
  chatbot.pop()
117
  while history and history[-1]["role"] == "assistant":
118
- print(f"discard {history[-1]}")
119
  history.pop()
120
- return predict(chatbot, history, audio_model, token2wav, args.prompt_wav, args.cache_dir)
121
 
122
  regen_btn.click(
123
- regenerate,
124
- [chatbot, history],
125
- [chatbot, history],
126
- #show_progress=True,
127
  concurrency_id="gpu_queue",
128
  )
129
 
130
- demo.queue().launch(
131
- server_port=args.server_port,
132
- server_name=args.server_name,
133
- )
134
-
135
 
136
  if __name__ == "__main__":
137
- import os
138
  from argparse import ArgumentParser
139
 
140
- from stepaudio2 import StepAudio2
141
- from token2wav import Token2wav
142
-
143
  parser = ArgumentParser()
144
- parser.add_argument("--model-path", type=str, default='Step-Audio-2-mini', help="Model path.")
145
- parser.add_argument(
146
- "--server-port", type=int, default=7860, help="Demo server port."
147
- )
148
- parser.add_argument(
149
- "--server-name", type=str, default="0.0.0.0", help="Demo server name."
150
- )
151
- parser.add_argument(
152
- "--prompt-wav", type=str, default="assets/default_female.wav", help="Prompt wave for the assistant."
153
- )
154
- parser.add_argument(
155
- "--cache-dir", type=str, default="/tmp/stepaudio2", help="Cache directory."
156
- )
157
  args = parser.parse_args()
 
158
  os.environ["GRADIO_TEMP_DIR"] = args.cache_dir
 
159
 
160
- audio_model = StepAudio2(args.model_path)
161
- token2wav = Token2wav("token2wav")
162
- _launch_demo(args, audio_model, token2wav)
 
1
  import os
2
  import shlex
3
  import subprocess
4
+ import threading
5
+ import tempfile
6
+ import traceback
7
+ from pathlib import Path
8
 
 
9
  os.system("pip install -r requirements.txt")
 
10
  os.system("wget https://huggingface.co/stepfun-ai/Step-Audio-2-mini/resolve/main/token2wav/campplus.onnx -P token2wav")
11
  os.system("wget https://huggingface.co/stepfun-ai/Step-Audio-2-mini/resolve/main/token2wav/flow.pt -P token2wav")
12
  os.system("wget https://huggingface.co/stepfun-ai/Step-Audio-2-mini/resolve/main/token2wav/flow.yaml -P token2wav")
13
  os.system("wget https://huggingface.co/stepfun-ai/Step-Audio-2-mini/resolve/main/token2wav/hift.pt -P token2wav")
14
 
15
# HF token passthrough.
# os.environ only accepts str values, so writing back an unset token (None)
# would raise TypeError at import time; re-export it only when it is set.
hf_token = os.getenv("HF_TOKEN")
if hf_token is not None:
    os.environ["HF_TOKEN"] = hf_token
18
 
 
 
 
19
  import spaces
20
  import gradio as gr
21
 
22
def save_tmp_audio(audio_bytes, cache_dir):
    """Write ``audio_bytes`` to a new ``.wav`` file inside ``cache_dir``.

    The directory is created when it does not exist yet. The file is created
    with ``delete=False``, so it stays on disk after this function returns.

    Args:
        audio_bytes: Bytes-like payload written to the file as-is.
        cache_dir: Directory in which the temporary file is created.

    Returns:
        The path of the newly written file.
    """
    os.makedirs(cache_dir, exist_ok=True)
    wav_file = tempfile.NamedTemporaryFile(
        dir=cache_dir,
        delete=False,
        suffix=".wav",
    )
    try:
        wav_file.write(audio_bytes)
    finally:
        # Close explicitly so the data is flushed before the path is handed out.
        wav_file.close()
    return wav_file.name
27
 
28
  def add_message(chatbot, history, mic, text):
 
34
  history.append({"role": "human", "content": text})
35
  elif mic and Path(mic).exists():
36
  chatbot.append({"role": "user", "content": {"path": mic}})
37
+ history.append({"role": "human", "content": [{"type": "audio", "audio": mic}]})
38
 
 
39
  return chatbot, history, None
40
 
41
  def reset_state(system_prompt):
42
  return [], [{"role": "system", "content": system_prompt}]
43
 
44
# Module-level slots for the two model objects; filled in by _ensure_models().
_AUDIO_MODEL = None
_TOKEN2WAV = None
# Serialises the first-time construction of the objects above.
_INIT_LOCK = threading.Lock()


def _ensure_models(model_path: str, token2wav_dir: str):
    """Return the ``(audio_model, token2wav)`` pair, building it on first use.

    The instances are kept in module globals. Once both exist, later calls
    return them unchanged and their arguments are ignored.

    Args:
        model_path: Passed to ``StepAudio2(...)`` when the model is built.
        token2wav_dir: Passed to ``Token2wav(...)`` when it is built.

    Returns:
        Tuple ``(_AUDIO_MODEL, _TOKEN2WAV)``.
    """
    global _AUDIO_MODEL, _TOKEN2WAV

    # Fast path: both objects already exist, no locking needed.
    if _AUDIO_MODEL is not None and _TOKEN2WAV is not None:
        return _AUDIO_MODEL, _TOKEN2WAV

    with _INIT_LOCK:
        # Re-check under the lock: another thread may have finished building
        # the objects while this one was waiting.
        if _AUDIO_MODEL is None or _TOKEN2WAV is None:
            # Imported lazily so these modules are only loaded where the
            # (non-picklable) instances are actually created.
            from stepaudio2 import StepAudio2
            from token2wav import Token2wav

            _AUDIO_MODEL = StepAudio2(model_path)
            _TOKEN2WAV = Token2wav(token2wav_dir)

    return _AUDIO_MODEL, _TOKEN2WAV
64
@spaces.GPU
def predict(chatbot, history, prompt_wav, cache_dir, model_path, token2wav_dir):
    """Generate the assistant's audio reply and append it to the conversation.

    IMPORTANT: all parameters are simple strings/lists (picklable). The heavy
    model objects are obtained inside via ``_ensure_models(...)``.

    Args:
        chatbot: List of chat messages shown in the UI; the reply is appended
            as ``{"role": "assistant", "content": {"path": <wav path>}}``.
        history: List of model-side messages; an assistant turn is appended.
        prompt_wav: Forwarded as the second argument to ``token2wav(...)``.
        cache_dir: Directory where the generated audio file is written.
        model_path: Forwarded to ``_ensure_models``.
        token2wav_dir: Forwarded to ``_ensure_models``.

    Returns:
        The ``(chatbot, history)`` pair. On any exception the traceback is
        printed, a Gradio warning is raised, and the unfinished assistant turn
        is removed from ``history`` again.
    """
    assistant_turn = None
    try:
        audio_model, token2wav = _ensure_models(model_path, token2wav_dir)

        # Open the assistant turn with the TTS start marker; it stays flagged
        # as unfinished ("eot": False) until generation succeeds.
        assistant_turn = {
            "role": "assistant",
            "content": [{"type": "text", "text": "<tts_start>"}],
            "eot": False,
        }
        history.append(assistant_turn)

        tokens, _text, audio_tokens = audio_model(
            history,
            max_new_tokens=4096,
            temperature=0.7,
            repetition_penalty=1.05,
            do_sample=True,
        )

        # Convert the generated audio tokens into the payload that is written
        # to disk (presumably WAV bytes -- confirm against Token2wav).
        audio_bytes = token2wav(audio_tokens, prompt_wav)

        # Save to a temp file so the Chatbot component can reference it by path.
        audio_path = save_tmp_audio(audio_bytes, cache_dir)
        chatbot.append({"role": "assistant", "content": {"path": audio_path}})

        # Finish the assistant turn.
        assistant_turn["content"].append({"type": "token", "token": tokens})
        assistant_turn["eot"] = True

    except Exception:
        print(traceback.format_exc())
        # Roll back the half-built assistant turn; otherwise the next request
        # would be sent with a dangling, unfinished "<tts_start>" message.
        if assistant_turn is not None and history and history[-1] is assistant_turn:
            history.pop()
        gr.Warning("Some error happened, please try again.")
    return chatbot, history
105
 
106
+ def _launch_demo(args):
107
  with gr.Blocks(delete_cache=(86400, 86400)) as demo:
108
+ gr.Markdown("""<center><font size=8>Step Audio 2 Demo</font></center>""")
109
+
110
  with gr.Row():
111
  system_prompt = gr.Textbox(
112
  label="System Prompt",
113
+ value=(
114
+ "你的名字叫做小跃,是由阶跃星辰公司训练出来的语音大模型。\n"
115
+ "你情感细腻,观察能力强,擅长分析用户的内容,并作出善解人意的回复,说话的过程中时刻注意用户的感受,富有同理心,提供多样的情绪价值。\n"
116
+ "今天是2025年8月29日,星期五\n"
117
+ "请用默认女声与用户交流。"
118
+ ),
119
  lines=2
120
  )
121
+
122
+ chatbot = gr.Chatbot(elem_id="chatbot", min_height=800, type="messages")
123
+ # Initialize history with the *string* value of the prompt
 
 
 
124
  history = gr.State([{"role": "system", "content": system_prompt.value}])
125
+
126
+ # Inputs
127
  mic = gr.Audio(type="filepath")
128
  text = gr.Textbox(placeholder="Enter message ...")
129
 
130
+ # Serializable configuration inputs (STRINGS ONLY)
131
+ model_path = gr.Textbox(value="Step-Audio-2-mini", label="Model path")
132
+ token2wav_dir = gr.Textbox(value="token2wav", label="Token2Wav directory")
133
+ prompt_wav = gr.Textbox(value="assets/default_female.wav", label="Prompt WAV path")
134
+ cache_dir = gr.Textbox(value="/tmp/stepaudio2", label="Cache directory")
135
+
136
  with gr.Row():
137
  clean_btn = gr.Button("🧹 Clear History (清除历史)")
138
  regen_btn = gr.Button("🤔️ Regenerate (重试)")
139
  submit_btn = gr.Button("🚀 Submit")
140
 
141
+ # --- event functions (now only use serializable args) ---
142
+ def on_submit(chatbot, history, mic, text, prompt_wav, cache_dir, model_path, token2wav_dir):
143
+ chatbot, history, error = add_message(chatbot, history, mic, text)
 
144
  if error:
145
+ gr.Warning(error)
 
 
 
146
  return chatbot, history, None, None
147
+ chatbot, history = predict(chatbot, history, prompt_wav, cache_dir, model_path, token2wav_dir)
148
+ return chatbot, history, None, None
149
 
150
  submit_btn.click(
151
  fn=on_submit,
152
+ inputs=[chatbot, history, mic, text, prompt_wav, cache_dir, model_path, token2wav_dir],
153
  outputs=[chatbot, history, mic, text],
154
  concurrency_limit=4,
155
  concurrency_id="gpu_queue",
 
159
  fn=reset_state,
160
  inputs=[system_prompt],
161
  outputs=[chatbot, history],
 
162
  )
163
 
164
+ def on_regen(chatbot, history, prompt_wav, cache_dir, model_path, token2wav_dir):
165
+ # drop last assistant turn so we can re-run
166
  while chatbot and chatbot[-1]["role"] == "assistant":
167
  chatbot.pop()
168
  while history and history[-1]["role"] == "assistant":
 
169
  history.pop()
170
+ return predict(chatbot, history, prompt_wav, cache_dir, model_path, token2wav_dir)
171
 
172
  regen_btn.click(
173
+ fn=on_regen,
174
+ inputs=[chatbot, history, prompt_wav, cache_dir, model_path, token2wav_dir],
175
+ outputs=[chatbot, history],
 
176
  concurrency_id="gpu_queue",
177
  )
178
 
179
+ demo.queue().launch(server_port=args.server_port, server_name=args.server_name)
 
 
 
 
180
 
181
  if __name__ == "__main__":
 
182
  from argparse import ArgumentParser
183
 
 
 
 
184
  parser = ArgumentParser()
185
+ parser.add_argument("--model-path", type=str, default="Step-Audio-2-mini", help="Model path.")
186
+ parser.add_argument("--server-port", type=int, default=7860, help="Demo server port.")
187
+ parser.add_argument("--server-name", type=str, default="0.0.0.0", help="Demo server name.")
188
+ parser.add_argument("--prompt-wav", type=str, default="assets/default_female.wav", help="Prompt wave for the assistant.")
189
+ parser.add_argument("--cache-dir", type=str, default="/tmp/stepaudio2", help="Cache directory.")
 
 
 
 
 
 
 
 
190
  args = parser.parse_args()
191
+
192
  os.environ["GRADIO_TEMP_DIR"] = args.cache_dir
193
+ os.makedirs(args.cache_dir, exist_ok=True)
194
 
195
+ # NOTE: Do NOT instantiate heavy models here.
196
+ # They will be created lazily inside predict() via _ensure_models(...).
197
+ _launch_demo(args)