reach-vb committed on
Commit
d17cd30
·
verified ·
1 Parent(s): d668318
Files changed (1) hide show
  1. app.py +151 -0
app.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import tempfile
3
+ import traceback
4
+ from pathlib import Path
5
+
6
+ import gradio as gr
7
+ import spaces # required for ZeroGPU
8
+
9
+ # ---- Your model libs (ensure these are available in the repo or pip) ----
10
+ from stepaudio2 import StepAudio2
11
+ from token2wav import Token2wav
12
+
13
# ------------------------- constants -------------------------
MODEL_PATH = "Step-Audio-2-mini"  # checkpoint directory, relative to the repo root
PROMPT_WAV = "assets/default_female.wav"  # reference voice passed to the vocoder
CACHE_DIR = "/tmp/stepaudio2"  # writable scratch space for generated WAVs

# Ensure Gradio uses a writable temp dir on Spaces
os.environ["GRADIO_TEMP_DIR"] = CACHE_DIR
Path(CACHE_DIR).mkdir(parents=True, exist_ok=True)
21
+
22
# ------------------------- helpers -------------------------
def save_tmp_audio(audio_bytes: bytes, cache_dir: str) -> str:
    """Persist raw WAV bytes to a unique ``.wav`` file and return its path.

    The file is created under *cache_dir* (created if missing) and is NOT
    deleted automatically -- Gradio serves it back to the client later.
    """
    Path(cache_dir).mkdir(parents=True, exist_ok=True)
    fd, wav_path = tempfile.mkstemp(dir=cache_dir, suffix=".wav")
    with os.fdopen(fd, "wb") as sink:
        sink.write(audio_bytes)
    return wav_path
28
+
29
def add_message(chatbot, history, mic, text):
    """Append the user's turn to both the UI and model message lists.

    Text input takes precedence over the microphone recording when both are
    provided.  Returns ``(chatbot, history, error)`` where ``error`` is
    ``None`` on success or a human-readable message for ``gr.Warning``.
    """
    if not mic and not text:
        return chatbot, history, "Input is empty"

    if text:
        chatbot.append({"role": "user", "content": text})
        history.append({"role": "human", "content": text})
    elif mic and Path(mic).exists():
        # Chatbot renders the audio file; the model gets a typed audio part.
        chatbot.append({"role": "user", "content": {"path": mic}})
        history.append({"role": "human", "content": [{"type": "audio", "audio": mic}]})
    else:
        # BUGFIX: previously a mic path pointing at a missing file fell
        # through and reported success with nothing appended, so the submit
        # handler ran inference on stale history.  Surface it as an error.
        return chatbot, history, "Audio file not found"
    return chatbot, history, None
40
+
41
def reset_state(system_prompt):
    """Clear the visible chat and re-seed model history with the system prompt."""
    fresh_history = [{"role": "system", "content": system_prompt}]
    return [], fresh_history
43
+
44
# ------------------------- globals -------------------------
# Both models are instantiated once at import time on CPU; gpu_predict moves
# them to CUDA only while a ZeroGPU slot is attached to the process.
AUDIO_MODEL = StepAudio2(MODEL_PATH)  # load on CPU
TOKEN2WAV = Token2wav(f"{MODEL_PATH}/token2wav")  # load on CPU
47
+
48
@spaces.GPU(duration=120)  # GPU only during this call; no-ops outside ZeroGPU
def gpu_predict(chatbot, history):
    """Run one assistant turn: generate tokens, vocode to WAV, update state.

    Parameters
    ----------
    chatbot : list
        Gradio-facing message list; the reply is appended as an audio message.
    history : list
        Model-facing message list; an assistant entry is appended and only
        marked complete (``eot=True``) once generation succeeds.

    Returns
    -------
    tuple
        ``(chatbot, history)``, unchanged if an error occurred.
    """
    global AUDIO_MODEL, TOKEN2WAV
    try:
        # Move to CUDA only when the GPU is attached.  Best-effort on purpose:
        # the wrappers may not expose .to(), and CPU-only runs should proceed.
        try:
            if hasattr(AUDIO_MODEL, "to"):
                AUDIO_MODEL.to("cuda")
            if hasattr(TOKEN2WAV, "to"):
                TOKEN2WAV.to("cuda")
        except Exception:
            pass

        # Open an in-progress assistant turn; eot=False marks it incomplete.
        history.append({"role": "assistant", "content": [{"type": "text", "text": "<tts_start>"}], "eot": False})

        try:
            tokens, text, audio_tokens = AUDIO_MODEL(
                history,
                max_new_tokens=4096,
                temperature=0.7,
                repetition_penalty=1.05,
                do_sample=True,
            )

            audio_bytes = TOKEN2WAV(audio_tokens, PROMPT_WAV)
            audio_path = save_tmp_audio(audio_bytes, CACHE_DIR)

            chatbot.append({"role": "assistant", "content": {"path": audio_path}})
            history[-1]["content"].append({"type": "token", "token": tokens})
            history[-1]["eot"] = True
        except Exception:
            # BUGFIX: drop the half-finished assistant entry so a retry does
            # not feed the model a dangling <tts_start> turn left over from
            # this failed call.
            history.pop()
            raise

    except Exception:
        # Top-level boundary: log the traceback server-side and show a
        # friendly warning in the UI; state is returned so chat can continue.
        print(traceback.format_exc())
        gr.Warning("Some error happened, please try again.")
    return chatbot, history
82
+
83
def build_demo():
    """Build and return the Gradio Blocks UI for the Step Audio 2 demo."""
    with gr.Blocks(delete_cache=(86400, 86400)) as demo:
        gr.Markdown("<center><font size=8>Step Audio 2 Demo</center>")

        with gr.Row():
            # NOTE(review): the default-value literals below appear mojibake
            # (UTF-8 bytes decoded as another codepage, likely originally
            # Chinese) -- confirm the intended text before editing them.
            system_prompt = gr.Textbox(
                label="System Prompt",
                value=(
                    "ไฝ ็š„ๅๅญ—ๅซๅšๅฐ่ทƒ๏ผŒๆ˜ฏ็”ฑ้˜ถ่ทƒๆ˜Ÿ่พฐๅ…ฌๅธ่ฎญ็ปƒๅ‡บๆฅ็š„่ฏญ้Ÿณๅคงๆจกๅž‹ใ€‚\n"
                    "ไฝ ๆƒ…ๆ„Ÿ็ป†่…ป๏ผŒ่ง‚ๅฏŸ่ƒฝๅŠ›ๅผบ๏ผŒๆ“…้•ฟๅˆ†ๆž็”จๆˆท็š„ๅ†…ๅฎน๏ผŒๅนถไฝœๅ‡บๅ–„่งฃไบบๆ„็š„ๅ›žๅค๏ผŒ"
                    "่ฏด่ฏ็š„่ฟ‡็จ‹ไธญๆ—ถๅˆปๆณจๆ„็”จๆˆท็š„ๆ„Ÿๅ—๏ผŒๅฏŒๆœ‰ๅŒ็†ๅฟƒ๏ผŒๆไพ›ๅคšๆ ท็š„ๆƒ…็ปชไปทๅ€ผใ€‚\n"
                    "ไปŠๅคฉๆ˜ฏ2025ๅนด8ๆœˆ29ๆ—ฅ๏ผŒๆ˜ŸๆœŸไบ”\n"
                    "่ฏท็”จ้ป˜่ฎคๅฅณๅฃฐไธŽ็”จๆˆทไบคๆตใ€‚"
                ),
                lines=2,
            )

        # Two parallel conversation states: `chatbot` is what the user sees,
        # `history` is the model-facing list seeded with the system prompt.
        chatbot = gr.Chatbot(elem_id="chatbot", min_height=800, type="messages")
        history = gr.State([{"role": "system", "content": system_prompt.value}])

        mic = gr.Audio(type="filepath", label="๐ŸŽ™๏ธ Microphone input (optional)")
        text = gr.Textbox(placeholder="Enter message ...", label="๐Ÿ’ฌ Text input")

        with gr.Row():
            clean_btn = gr.Button("๐Ÿงน Clear History (ๆธ…้™คๅކๅฒ)")
            regen_btn = gr.Button("๐Ÿค”๏ธ Regenerate (้‡่ฏ•)")
            submit_btn = gr.Button("๐Ÿš€ Submit")

        def on_submit(chatbot, history, mic, text):
            # Validate/append the user turn, then run inference; the trailing
            # (None, None) clears the mic and text inputs either way.
            chatbot, history, error = add_message(chatbot, history, mic, text)
            if error:
                gr.Warning(error)
                return chatbot, history, None, None
            chatbot, history = gpu_predict(chatbot, history)
            return chatbot, history, None, None

        submit_btn.click(
            fn=on_submit,
            inputs=[chatbot, history, mic, text],
            outputs=[chatbot, history, mic, text],
            concurrency_limit=4,
            concurrency_id="gpu_queue",  # shared queue with Regenerate
        )

        clean_btn.click(
            fn=reset_state,
            inputs=[system_prompt],
            outputs=[chatbot, history],
        )

        def regenerate(chatbot, history):
            # Drop trailing assistant turn(s) from both views, then re-run
            # inference on what remains.
            while chatbot and chatbot[-1]["role"] == "assistant":
                chatbot.pop()
            while history and history[-1]["role"] == "assistant":
                history.pop()
            return gpu_predict(chatbot, history)

        regen_btn.click(
            regenerate,
            [chatbot, history],
            [chatbot, history],
            concurrency_id="gpu_queue",
        )
    return demo
147
+
148
# Spaces runs this file; just build and launch with defaults (no ports/names).
if __name__ == "__main__":
    # queue() is required so the concurrency ids on the click handlers apply.
    build_demo().queue().launch()  # no args -- Spaces handles host/port