Steveeeeeeen HF Staff committed on
Commit
b23448b
·
verified ·
1 Parent(s): a9a5df1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -9
app.py CHANGED
@@ -61,7 +61,6 @@ def _get_models(model_path: str):
61
  """
62
  global _MODEL, _TOK2WAV
63
  if _MODEL is None or _TOK2WAV is None:
64
- # Import here so the objects are constructed in the worker
65
  from stepaudio2 import StepAudio2
66
  from token2wav import Token2wav
67
  _MODEL = StepAudio2(model_path)
@@ -69,7 +68,7 @@ def _get_models(model_path: str):
69
  return _MODEL, _TOK2WAV
70
 
71
  @spaces.GPU
72
- def predict(chatbot, history, prompt_wav, cache_dir, model_path="Step-Audio-2-mini"):
73
  """
74
  Run generation on GPU worker. All args must be picklable (strings, lists, dicts).
75
  Heavy models are created via _get_models() inside this process.
@@ -93,7 +92,7 @@ def predict(chatbot, history, prompt_wav, cache_dir, model_path="Step-Audio-2-mi
93
  print(f"predict text={text!r}")
94
 
95
  # Convert tokens -> waveform bytes using token2wav
96
- audio_bytes = token2wav(audio_tokens, prompt_wav)
97
 
98
  # Persist to temp .wav for the UI
99
  audio_path = save_tmp_audio(audio_bytes, cache_dir)
@@ -132,7 +131,6 @@ def _launch_demo(args):
132
  type="messages",
133
  )
134
 
135
- # Initialize history with current system prompt value
136
  history = gr.State([{"role": "system", "content": system_prompt.value}])
137
 
138
  mic = gr.Audio(type="filepath", label="🎤 Speak (optional)")
@@ -148,10 +146,9 @@ def _launch_demo(args):
148
  if error:
149
  gr.Warning(error)
150
  return chatbot2, history2, None, None
151
- # Run GPU inference with only picklable args
152
  chatbot2, history2 = predict(
153
  chatbot2, history2,
154
- args.prompt_wav, args.cache_dir,
155
  model_path=args.model_path
156
  )
157
  return chatbot2, history2, None, None
@@ -174,7 +171,6 @@ def _launch_demo(args):
174
  )
175
 
176
  def on_regenerate(chatbot_val, history_val):
177
- # Drop last assistant turn(s) to regenerate
178
  while chatbot_val and chatbot_val[-1]["role"] == "assistant":
179
  chatbot_val.pop()
180
  while history_val and history_val[-1]["role"] == "assistant":
@@ -182,7 +178,7 @@ def _launch_demo(args):
182
  history_val.pop()
183
  return predict(
184
  chatbot_val, history_val,
185
- args.prompt_wav, args.cache_dir,
186
  model_path=args.model_path
187
  )
188
 
@@ -205,7 +201,6 @@ if __name__ == "__main__":
205
  parser.add_argument("--model-path", type=str, default="Step-Audio-2-mini", help="Model path.")
206
  parser.add_argument("--server-port", type=int, default=7860, help="Demo server port.")
207
  parser.add_argument("--server-name", type=str, default="0.0.0.0", help="Demo server name.")
208
- parser.add_argument("--prompt-wav", type=str, default="assets/default_female.wav", help="Prompt wave for the assistant.")
209
  parser.add_argument("--cache-dir", type=str, default="/tmp/stepaudio2", help="Cache directory.")
210
  args = parser.parse_args()
211
 
 
61
  """
62
  global _MODEL, _TOK2WAV
63
  if _MODEL is None or _TOK2WAV is None:
 
64
  from stepaudio2 import StepAudio2
65
  from token2wav import Token2wav
66
  _MODEL = StepAudio2(model_path)
 
68
  return _MODEL, _TOK2WAV
69
 
70
  @spaces.GPU
71
+ def predict(chatbot, history, cache_dir, model_path="Step-Audio-2-mini"):
72
  """
73
  Run generation on GPU worker. All args must be picklable (strings, lists, dicts).
74
  Heavy models are created via _get_models() inside this process.
 
92
  print(f"predict text={text!r}")
93
 
94
  # Convert tokens -> waveform bytes using token2wav
95
+ audio_bytes = token2wav(audio_tokens)
96
 
97
  # Persist to temp .wav for the UI
98
  audio_path = save_tmp_audio(audio_bytes, cache_dir)
 
131
  type="messages",
132
  )
133
 
 
134
  history = gr.State([{"role": "system", "content": system_prompt.value}])
135
 
136
  mic = gr.Audio(type="filepath", label="🎤 Speak (optional)")
 
146
  if error:
147
  gr.Warning(error)
148
  return chatbot2, history2, None, None
 
149
  chatbot2, history2 = predict(
150
  chatbot2, history2,
151
+ args.cache_dir,
152
  model_path=args.model_path
153
  )
154
  return chatbot2, history2, None, None
 
171
  )
172
 
173
  def on_regenerate(chatbot_val, history_val):
 
174
  while chatbot_val and chatbot_val[-1]["role"] == "assistant":
175
  chatbot_val.pop()
176
  while history_val and history_val[-1]["role"] == "assistant":
 
178
  history_val.pop()
179
  return predict(
180
  chatbot_val, history_val,
181
+ args.cache_dir,
182
  model_path=args.model_path
183
  )
184
 
 
201
  parser.add_argument("--model-path", type=str, default="Step-Audio-2-mini", help="Model path.")
202
  parser.add_argument("--server-port", type=int, default=7860, help="Demo server port.")
203
  parser.add_argument("--server-name", type=str, default="0.0.0.0", help="Demo server name.")
 
204
  parser.add_argument("--cache-dir", type=str, default="/tmp/stepaudio2", help="Cache directory.")
205
  args = parser.parse_args()
206