Steveeeeeeen HF Staff committed on
Commit
876f2fc
·
verified ·
1 Parent(s): b23448b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -4
app.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import os
2
  import shlex
3
  import subprocess
@@ -22,6 +24,10 @@ if hf_token is not None:
22
  import spaces
23
  import gradio as gr
24
 
 
 
 
 
25
  def save_tmp_audio(audio_bytes: bytes, cache_dir: str) -> str:
26
  """Save raw wav bytes to a temporary file and return path."""
27
  os.makedirs(cache_dir, exist_ok=True)
@@ -51,6 +57,9 @@ def reset_state(system_prompt: str):
51
  return [], [{"role": "system", "content": system_prompt}]
52
 
53
 
 
 
 
54
  _MODEL = None
55
  _TOK2WAV = None
56
 
@@ -61,14 +70,19 @@ def _get_models(model_path: str):
61
  """
62
  global _MODEL, _TOK2WAV
63
  if _MODEL is None or _TOK2WAV is None:
 
64
  from stepaudio2 import StepAudio2
65
  from token2wav import Token2wav
66
  _MODEL = StepAudio2(model_path)
67
  _TOK2WAV = Token2wav("token2wav")
68
  return _MODEL, _TOK2WAV
69
 
 
 
 
 
70
  @spaces.GPU
71
- def predict(chatbot, history, cache_dir, model_path="Step-Audio-2-mini"):
72
  """
73
  Run generation on GPU worker. All args must be picklable (strings, lists, dicts).
74
  Heavy models are created via _get_models() inside this process.
@@ -92,7 +106,7 @@ def predict(chatbot, history, cache_dir, model_path="Step-Audio-2-mini"):
92
  print(f"predict text={text!r}")
93
 
94
  # Convert tokens -> waveform bytes using token2wav
95
- audio_bytes = token2wav(audio_tokens)
96
 
97
  # Persist to temp .wav for the UI
98
  audio_path = save_tmp_audio(audio_bytes, cache_dir)
@@ -108,6 +122,10 @@ def predict(chatbot, history, cache_dir, model_path="Step-Audio-2-mini"):
108
 
109
  return chatbot, history
110
 
 
 
 
 
111
  def _launch_demo(args):
112
  with gr.Blocks(delete_cache=(86400, 86400)) as demo:
113
  gr.Markdown("""<center><font size=8>Step Audio 2 Demo</font></center>""")
@@ -131,6 +149,7 @@ def _launch_demo(args):
131
  type="messages",
132
  )
133
 
 
134
  history = gr.State([{"role": "system", "content": system_prompt.value}])
135
 
136
  mic = gr.Audio(type="filepath", label="🎤 Speak (optional)")
@@ -146,9 +165,10 @@ def _launch_demo(args):
146
  if error:
147
  gr.Warning(error)
148
  return chatbot2, history2, None, None
 
149
  chatbot2, history2 = predict(
150
  chatbot2, history2,
151
- args.cache_dir,
152
  model_path=args.model_path
153
  )
154
  return chatbot2, history2, None, None
@@ -171,6 +191,7 @@ def _launch_demo(args):
171
  )
172
 
173
  def on_regenerate(chatbot_val, history_val):
 
174
  while chatbot_val and chatbot_val[-1]["role"] == "assistant":
175
  chatbot_val.pop()
176
  while history_val and history_val[-1]["role"] == "assistant":
@@ -178,7 +199,7 @@ def _launch_demo(args):
178
  history_val.pop()
179
  return predict(
180
  chatbot_val, history_val,
181
- args.cache_dir,
182
  model_path=args.model_path
183
  )
184
 
@@ -194,6 +215,10 @@ def _launch_demo(args):
194
  server_name=args.server_name,
195
  )
196
 
 
 
 
 
197
  if __name__ == "__main__":
198
  from argparse import ArgumentParser
199
 
@@ -201,6 +226,7 @@ if __name__ == "__main__":
201
  parser.add_argument("--model-path", type=str, default="Step-Audio-2-mini", help="Model path.")
202
  parser.add_argument("--server-port", type=int, default=7860, help="Demo server port.")
203
  parser.add_argument("--server-name", type=str, default="0.0.0.0", help="Demo server name.")
 
204
  parser.add_argument("--cache-dir", type=str, default="/tmp/stepaudio2", help="Cache directory.")
205
  args = parser.parse_args()
206
 
 
1
+ # app.py
2
+
3
  import os
4
  import shlex
5
  import subprocess
 
24
  import spaces
25
  import gradio as gr
26
 
27
+
28
+ # -----------------------
29
+ # Utility helpers
30
+ # -----------------------
31
  def save_tmp_audio(audio_bytes: bytes, cache_dir: str) -> str:
32
  """Save raw wav bytes to a temporary file and return path."""
33
  os.makedirs(cache_dir, exist_ok=True)
 
57
  return [], [{"role": "system", "content": system_prompt}]
58
 
59
 
60
+ # -----------------------
61
+ # Lazy model loading inside the GPU worker
62
+ # -----------------------
63
  _MODEL = None
64
  _TOK2WAV = None
65
 
 
70
  """
71
  global _MODEL, _TOK2WAV
72
  if _MODEL is None or _TOK2WAV is None:
73
+ # Import here so the objects are constructed in the worker
74
  from stepaudio2 import StepAudio2
75
  from token2wav import Token2wav
76
  _MODEL = StepAudio2(model_path)
77
  _TOK2WAV = Token2wav("token2wav")
78
  return _MODEL, _TOK2WAV
79
 
80
+
81
+ # -----------------------
82
+ # Inference
83
+ # -----------------------
84
  @spaces.GPU
85
+ def predict(chatbot, history, prompt_wav, cache_dir, model_path="Step-Audio-2-mini"):
86
  """
87
  Run generation on GPU worker. All args must be picklable (strings, lists, dicts).
88
  Heavy models are created via _get_models() inside this process.
 
106
  print(f"predict text={text!r}")
107
 
108
  # Convert tokens -> waveform bytes using token2wav
109
+ audio_bytes = token2wav(audio_tokens, prompt_wav)
110
 
111
  # Persist to temp .wav for the UI
112
  audio_path = save_tmp_audio(audio_bytes, cache_dir)
 
122
 
123
  return chatbot, history
124
 
125
+
126
+ # -----------------------
127
+ # UI
128
+ # -----------------------
129
  def _launch_demo(args):
130
  with gr.Blocks(delete_cache=(86400, 86400)) as demo:
131
  gr.Markdown("""<center><font size=8>Step Audio 2 Demo</font></center>""")
 
149
  type="messages",
150
  )
151
 
152
+ # Initialize history with current system prompt value
153
  history = gr.State([{"role": "system", "content": system_prompt.value}])
154
 
155
  mic = gr.Audio(type="filepath", label="🎤 Speak (optional)")
 
165
  if error:
166
  gr.Warning(error)
167
  return chatbot2, history2, None, None
168
+ # Run GPU inference with only picklable args
169
  chatbot2, history2 = predict(
170
  chatbot2, history2,
171
+ args.prompt_wav, args.cache_dir,
172
  model_path=args.model_path
173
  )
174
  return chatbot2, history2, None, None
 
191
  )
192
 
193
  def on_regenerate(chatbot_val, history_val):
194
+ # Drop last assistant turn(s) to regenerate
195
  while chatbot_val and chatbot_val[-1]["role"] == "assistant":
196
  chatbot_val.pop()
197
  while history_val and history_val[-1]["role"] == "assistant":
 
199
  history_val.pop()
200
  return predict(
201
  chatbot_val, history_val,
202
+ args.prompt_wav, args.cache_dir,
203
  model_path=args.model_path
204
  )
205
 
 
215
  server_name=args.server_name,
216
  )
217
 
218
+
219
+ # -----------------------
220
+ # Entrypoint
221
+ # -----------------------
222
  if __name__ == "__main__":
223
  from argparse import ArgumentParser
224
 
 
226
  parser.add_argument("--model-path", type=str, default="Step-Audio-2-mini", help="Model path.")
227
  parser.add_argument("--server-port", type=int, default=7860, help="Demo server port.")
228
  parser.add_argument("--server-name", type=str, default="0.0.0.0", help="Demo server name.")
229
+ parser.add_argument("--prompt-wav", type=str, default="assets/default_female.wav", help="Prompt wave for the assistant.")
230
  parser.add_argument("--cache-dir", type=str, default="/tmp/stepaudio2", help="Cache directory.")
231
  args = parser.parse_args()
232