Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,3 +1,5 @@
|
|
|
|
|
|
|
|
| 1 |
import os
|
| 2 |
import shlex
|
| 3 |
import subprocess
|
|
@@ -22,6 +24,10 @@ if hf_token is not None:
|
|
| 22 |
import spaces
|
| 23 |
import gradio as gr
|
| 24 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
def save_tmp_audio(audio_bytes: bytes, cache_dir: str) -> str:
|
| 26 |
"""Save raw wav bytes to a temporary file and return path."""
|
| 27 |
os.makedirs(cache_dir, exist_ok=True)
|
|
@@ -51,6 +57,9 @@ def reset_state(system_prompt: str):
|
|
| 51 |
return [], [{"role": "system", "content": system_prompt}]
|
| 52 |
|
| 53 |
|
|
|
|
|
|
|
|
|
|
| 54 |
_MODEL = None
|
| 55 |
_TOK2WAV = None
|
| 56 |
|
|
@@ -61,14 +70,19 @@ def _get_models(model_path: str):
|
|
| 61 |
"""
|
| 62 |
global _MODEL, _TOK2WAV
|
| 63 |
if _MODEL is None or _TOK2WAV is None:
|
|
|
|
| 64 |
from stepaudio2 import StepAudio2
|
| 65 |
from token2wav import Token2wav
|
| 66 |
_MODEL = StepAudio2(model_path)
|
| 67 |
_TOK2WAV = Token2wav("token2wav")
|
| 68 |
return _MODEL, _TOK2WAV
|
| 69 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
@spaces.GPU
|
| 71 |
-
def predict(chatbot, history, cache_dir, model_path="Step-Audio-2-mini"):
|
| 72 |
"""
|
| 73 |
Run generation on GPU worker. All args must be picklable (strings, lists, dicts).
|
| 74 |
Heavy models are created via _get_models() inside this process.
|
|
@@ -92,7 +106,7 @@ def predict(chatbot, history, cache_dir, model_path="Step-Audio-2-mini"):
|
|
| 92 |
print(f"predict text={text!r}")
|
| 93 |
|
| 94 |
# Convert tokens -> waveform bytes using token2wav
|
| 95 |
-
audio_bytes = token2wav(audio_tokens)
|
| 96 |
|
| 97 |
# Persist to temp .wav for the UI
|
| 98 |
audio_path = save_tmp_audio(audio_bytes, cache_dir)
|
|
@@ -108,6 +122,10 @@ def predict(chatbot, history, cache_dir, model_path="Step-Audio-2-mini"):
|
|
| 108 |
|
| 109 |
return chatbot, history
|
| 110 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 111 |
def _launch_demo(args):
|
| 112 |
with gr.Blocks(delete_cache=(86400, 86400)) as demo:
|
| 113 |
gr.Markdown("""<center><font size=8>Step Audio 2 Demo</font></center>""")
|
|
@@ -131,6 +149,7 @@ def _launch_demo(args):
|
|
| 131 |
type="messages",
|
| 132 |
)
|
| 133 |
|
|
|
|
| 134 |
history = gr.State([{"role": "system", "content": system_prompt.value}])
|
| 135 |
|
| 136 |
mic = gr.Audio(type="filepath", label="🎤 Speak (optional)")
|
|
@@ -146,9 +165,10 @@ def _launch_demo(args):
|
|
| 146 |
if error:
|
| 147 |
gr.Warning(error)
|
| 148 |
return chatbot2, history2, None, None
|
|
|
|
| 149 |
chatbot2, history2 = predict(
|
| 150 |
chatbot2, history2,
|
| 151 |
-
args.cache_dir,
|
| 152 |
model_path=args.model_path
|
| 153 |
)
|
| 154 |
return chatbot2, history2, None, None
|
|
@@ -171,6 +191,7 @@ def _launch_demo(args):
|
|
| 171 |
)
|
| 172 |
|
| 173 |
def on_regenerate(chatbot_val, history_val):
|
|
|
|
| 174 |
while chatbot_val and chatbot_val[-1]["role"] == "assistant":
|
| 175 |
chatbot_val.pop()
|
| 176 |
while history_val and history_val[-1]["role"] == "assistant":
|
|
@@ -178,7 +199,7 @@ def _launch_demo(args):
|
|
| 178 |
history_val.pop()
|
| 179 |
return predict(
|
| 180 |
chatbot_val, history_val,
|
| 181 |
-
args.cache_dir,
|
| 182 |
model_path=args.model_path
|
| 183 |
)
|
| 184 |
|
|
@@ -194,6 +215,10 @@ def _launch_demo(args):
|
|
| 194 |
server_name=args.server_name,
|
| 195 |
)
|
| 196 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 197 |
if __name__ == "__main__":
|
| 198 |
from argparse import ArgumentParser
|
| 199 |
|
|
@@ -201,6 +226,7 @@ if __name__ == "__main__":
|
|
| 201 |
parser.add_argument("--model-path", type=str, default="Step-Audio-2-mini", help="Model path.")
|
| 202 |
parser.add_argument("--server-port", type=int, default=7860, help="Demo server port.")
|
| 203 |
parser.add_argument("--server-name", type=str, default="0.0.0.0", help="Demo server name.")
|
|
|
|
| 204 |
parser.add_argument("--cache-dir", type=str, default="/tmp/stepaudio2", help="Cache directory.")
|
| 205 |
args = parser.parse_args()
|
| 206 |
|
|
|
|
| 1 |
+
# app.py
|
| 2 |
+
|
| 3 |
import os
|
| 4 |
import shlex
|
| 5 |
import subprocess
|
|
|
|
| 24 |
import spaces
|
| 25 |
import gradio as gr
|
| 26 |
|
| 27 |
+
|
| 28 |
+
# -----------------------
|
| 29 |
+
# Utility helpers
|
| 30 |
+
# -----------------------
|
| 31 |
def save_tmp_audio(audio_bytes: bytes, cache_dir: str) -> str:
|
| 32 |
"""Save raw wav bytes to a temporary file and return path."""
|
| 33 |
os.makedirs(cache_dir, exist_ok=True)
|
|
|
|
| 57 |
return [], [{"role": "system", "content": system_prompt}]
|
| 58 |
|
| 59 |
|
| 60 |
+
# -----------------------
|
| 61 |
+
# Lazy model loading inside the GPU worker
|
| 62 |
+
# -----------------------
|
| 63 |
_MODEL = None
|
| 64 |
_TOK2WAV = None
|
| 65 |
|
|
|
|
| 70 |
"""
|
| 71 |
global _MODEL, _TOK2WAV
|
| 72 |
if _MODEL is None or _TOK2WAV is None:
|
| 73 |
+
# Import here so the objects are constructed in the worker
|
| 74 |
from stepaudio2 import StepAudio2
|
| 75 |
from token2wav import Token2wav
|
| 76 |
_MODEL = StepAudio2(model_path)
|
| 77 |
_TOK2WAV = Token2wav("token2wav")
|
| 78 |
return _MODEL, _TOK2WAV
|
| 79 |
|
| 80 |
+
|
| 81 |
+
# -----------------------
|
| 82 |
+
# Inference
|
| 83 |
+
# -----------------------
|
| 84 |
@spaces.GPU
|
| 85 |
+
def predict(chatbot, history, prompt_wav, cache_dir, model_path="Step-Audio-2-mini"):
|
| 86 |
"""
|
| 87 |
Run generation on GPU worker. All args must be picklable (strings, lists, dicts).
|
| 88 |
Heavy models are created via _get_models() inside this process.
|
|
|
|
| 106 |
print(f"predict text={text!r}")
|
| 107 |
|
| 108 |
# Convert tokens -> waveform bytes using token2wav
|
| 109 |
+
audio_bytes = token2wav(audio_tokens, prompt_wav)
|
| 110 |
|
| 111 |
# Persist to temp .wav for the UI
|
| 112 |
audio_path = save_tmp_audio(audio_bytes, cache_dir)
|
|
|
|
| 122 |
|
| 123 |
return chatbot, history
|
| 124 |
|
| 125 |
+
|
| 126 |
+
# -----------------------
|
| 127 |
+
# UI
|
| 128 |
+
# -----------------------
|
| 129 |
def _launch_demo(args):
|
| 130 |
with gr.Blocks(delete_cache=(86400, 86400)) as demo:
|
| 131 |
gr.Markdown("""<center><font size=8>Step Audio 2 Demo</font></center>""")
|
|
|
|
| 149 |
type="messages",
|
| 150 |
)
|
| 151 |
|
| 152 |
+
# Initialize history with current system prompt value
|
| 153 |
history = gr.State([{"role": "system", "content": system_prompt.value}])
|
| 154 |
|
| 155 |
mic = gr.Audio(type="filepath", label="🎤 Speak (optional)")
|
|
|
|
| 165 |
if error:
|
| 166 |
gr.Warning(error)
|
| 167 |
return chatbot2, history2, None, None
|
| 168 |
+
# Run GPU inference with only picklable args
|
| 169 |
chatbot2, history2 = predict(
|
| 170 |
chatbot2, history2,
|
| 171 |
+
args.prompt_wav, args.cache_dir,
|
| 172 |
model_path=args.model_path
|
| 173 |
)
|
| 174 |
return chatbot2, history2, None, None
|
|
|
|
| 191 |
)
|
| 192 |
|
| 193 |
def on_regenerate(chatbot_val, history_val):
|
| 194 |
+
# Drop last assistant turn(s) to regenerate
|
| 195 |
while chatbot_val and chatbot_val[-1]["role"] == "assistant":
|
| 196 |
chatbot_val.pop()
|
| 197 |
while history_val and history_val[-1]["role"] == "assistant":
|
|
|
|
| 199 |
history_val.pop()
|
| 200 |
return predict(
|
| 201 |
chatbot_val, history_val,
|
| 202 |
+
args.prompt_wav, args.cache_dir,
|
| 203 |
model_path=args.model_path
|
| 204 |
)
|
| 205 |
|
|
|
|
| 215 |
server_name=args.server_name,
|
| 216 |
)
|
| 217 |
|
| 218 |
+
|
| 219 |
+
# -----------------------
|
| 220 |
+
# Entrypoint
|
| 221 |
+
# -----------------------
|
| 222 |
if __name__ == "__main__":
|
| 223 |
from argparse import ArgumentParser
|
| 224 |
|
|
|
|
| 226 |
parser.add_argument("--model-path", type=str, default="Step-Audio-2-mini", help="Model path.")
|
| 227 |
parser.add_argument("--server-port", type=int, default=7860, help="Demo server port.")
|
| 228 |
parser.add_argument("--server-name", type=str, default="0.0.0.0", help="Demo server name.")
|
| 229 |
+
parser.add_argument("--prompt-wav", type=str, default="assets/default_female.wav", help="Prompt wave for the assistant.")
|
| 230 |
parser.add_argument("--cache-dir", type=str, default="/tmp/stepaudio2", help="Cache directory.")
|
| 231 |
args = parser.parse_args()
|
| 232 |
|