Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
|
@@ -61,7 +61,6 @@ def _get_models(model_path: str):
|
|
| 61 |
"""
|
| 62 |
global _MODEL, _TOK2WAV
|
| 63 |
if _MODEL is None or _TOK2WAV is None:
|
| 64 |
-
# Import here so the objects are constructed in the worker
|
| 65 |
from stepaudio2 import StepAudio2
|
| 66 |
from token2wav import Token2wav
|
| 67 |
_MODEL = StepAudio2(model_path)
|
|
@@ -69,7 +68,7 @@ def _get_models(model_path: str):
|
|
| 69 |
return _MODEL, _TOK2WAV
|
| 70 |
|
| 71 |
@spaces.GPU
|
| 72 |
-
def predict(chatbot, history, prompt_wav, cache_dir, model_path="Step-Audio-2-mini"):
|
| 73 |
"""
|
| 74 |
Run generation on GPU worker. All args must be picklable (strings, lists, dicts).
|
| 75 |
Heavy models are created via _get_models() inside this process.
|
|
@@ -93,7 +92,7 @@ def predict(chatbot, history, prompt_wav, cache_dir, model_path="Step-Audio-2-mi
|
|
| 93 |
print(f"predict text={text!r}")
|
| 94 |
|
| 95 |
# Convert tokens -> waveform bytes using token2wav
|
| 96 |
-
audio_bytes = token2wav(audio_tokens, prompt_wav)
|
| 97 |
|
| 98 |
# Persist to temp .wav for the UI
|
| 99 |
audio_path = save_tmp_audio(audio_bytes, cache_dir)
|
|
@@ -132,7 +131,6 @@ def _launch_demo(args):
|
|
| 132 |
type="messages",
|
| 133 |
)
|
| 134 |
|
| 135 |
-
# Initialize history with current system prompt value
|
| 136 |
history = gr.State([{"role": "system", "content": system_prompt.value}])
|
| 137 |
|
| 138 |
mic = gr.Audio(type="filepath", label="🎤 Speak (optional)")
|
|
@@ -148,10 +146,9 @@ def _launch_demo(args):
|
|
| 148 |
if error:
|
| 149 |
gr.Warning(error)
|
| 150 |
return chatbot2, history2, None, None
|
| 151 |
-
# Run GPU inference with only picklable args
|
| 152 |
chatbot2, history2 = predict(
|
| 153 |
chatbot2, history2,
|
| 154 |
-
args.prompt_wav, args.cache_dir,
|
| 155 |
model_path=args.model_path
|
| 156 |
)
|
| 157 |
return chatbot2, history2, None, None
|
|
@@ -174,7 +171,6 @@ def _launch_demo(args):
|
|
| 174 |
)
|
| 175 |
|
| 176 |
def on_regenerate(chatbot_val, history_val):
|
| 177 |
-
# Drop last assistant turn(s) to regenerate
|
| 178 |
while chatbot_val and chatbot_val[-1]["role"] == "assistant":
|
| 179 |
chatbot_val.pop()
|
| 180 |
while history_val and history_val[-1]["role"] == "assistant":
|
|
@@ -182,7 +178,7 @@ def _launch_demo(args):
|
|
| 182 |
history_val.pop()
|
| 183 |
return predict(
|
| 184 |
chatbot_val, history_val,
|
| 185 |
-
args.prompt_wav, args.cache_dir,
|
| 186 |
model_path=args.model_path
|
| 187 |
)
|
| 188 |
|
|
@@ -205,7 +201,6 @@ if __name__ == "__main__":
|
|
| 205 |
parser.add_argument("--model-path", type=str, default="Step-Audio-2-mini", help="Model path.")
|
| 206 |
parser.add_argument("--server-port", type=int, default=7860, help="Demo server port.")
|
| 207 |
parser.add_argument("--server-name", type=str, default="0.0.0.0", help="Demo server name.")
|
| 208 |
-
parser.add_argument("--prompt-wav", type=str, default="assets/default_female.wav", help="Prompt wave for the assistant.")
|
| 209 |
parser.add_argument("--cache-dir", type=str, default="/tmp/stepaudio2", help="Cache directory.")
|
| 210 |
args = parser.parse_args()
|
| 211 |
|
|
|
|
| 61 |
"""
|
| 62 |
global _MODEL, _TOK2WAV
|
| 63 |
if _MODEL is None or _TOK2WAV is None:
|
|
|
|
| 64 |
from stepaudio2 import StepAudio2
|
| 65 |
from token2wav import Token2wav
|
| 66 |
_MODEL = StepAudio2(model_path)
|
|
|
|
| 68 |
return _MODEL, _TOK2WAV
|
| 69 |
|
| 70 |
@spaces.GPU
|
| 71 |
+
def predict(chatbot, history, cache_dir, model_path="Step-Audio-2-mini"):
|
| 72 |
"""
|
| 73 |
Run generation on GPU worker. All args must be picklable (strings, lists, dicts).
|
| 74 |
Heavy models are created via _get_models() inside this process.
|
|
|
|
| 92 |
print(f"predict text={text!r}")
|
| 93 |
|
| 94 |
# Convert tokens -> waveform bytes using token2wav
|
| 95 |
+
audio_bytes = token2wav(audio_tokens)
|
| 96 |
|
| 97 |
# Persist to temp .wav for the UI
|
| 98 |
audio_path = save_tmp_audio(audio_bytes, cache_dir)
|
|
|
|
| 131 |
type="messages",
|
| 132 |
)
|
| 133 |
|
|
|
|
| 134 |
history = gr.State([{"role": "system", "content": system_prompt.value}])
|
| 135 |
|
| 136 |
mic = gr.Audio(type="filepath", label="🎤 Speak (optional)")
|
|
|
|
| 146 |
if error:
|
| 147 |
gr.Warning(error)
|
| 148 |
return chatbot2, history2, None, None
|
|
|
|
| 149 |
chatbot2, history2 = predict(
|
| 150 |
chatbot2, history2,
|
| 151 |
+
args.cache_dir,
|
| 152 |
model_path=args.model_path
|
| 153 |
)
|
| 154 |
return chatbot2, history2, None, None
|
|
|
|
| 171 |
)
|
| 172 |
|
| 173 |
def on_regenerate(chatbot_val, history_val):
|
|
|
|
| 174 |
while chatbot_val and chatbot_val[-1]["role"] == "assistant":
|
| 175 |
chatbot_val.pop()
|
| 176 |
while history_val and history_val[-1]["role"] == "assistant":
|
|
|
|
| 178 |
history_val.pop()
|
| 179 |
return predict(
|
| 180 |
chatbot_val, history_val,
|
| 181 |
+
args.cache_dir,
|
| 182 |
model_path=args.model_path
|
| 183 |
)
|
| 184 |
|
|
|
|
| 201 |
parser.add_argument("--model-path", type=str, default="Step-Audio-2-mini", help="Model path.")
|
| 202 |
parser.add_argument("--server-port", type=int, default=7860, help="Demo server port.")
|
| 203 |
parser.add_argument("--server-name", type=str, default="0.0.0.0", help="Demo server name.")
|
|
|
|
| 204 |
parser.add_argument("--cache-dir", type=str, default="/tmp/stepaudio2", help="Cache directory.")
|
| 205 |
args = parser.parse_args()
|
| 206 |
|