zhenjiangjie committed
Commit 3f8009c
1 Parent(s): 39b1e03
Files changed (4)
  1. Dockerfile +6 -6
  2. app.py +231 -357
  3. requirements.txt +3 -1
  4. start_services.sh +81 -0
Dockerfile CHANGED
@@ -11,16 +11,16 @@ COPY --chown=user requirements.txt .
  RUN pip install --no-cache-dir --upgrade -r requirements.txt

  # 3. Then copy all application files
- COPY --chown=user start_gradio.sh app.py .
+ COPY --chown=user app.py start_services.sh ./

- # 4. Set script permissions if needed
- RUN chmod +x start_gradio.sh
+ # 4. Set script permissions
+ RUN chmod +x start_services.sh

- # Expose the Gradio port
- EXPOSE 7860
+ # Expose ports
+ EXPOSE 7860 9999

  # Set environment variables
  ENV VLLM_ALLOW_LONG_MAX_MODEL_LEN=1

  # Startup script
- CMD ["./start_gradio.sh"]
+ CMD ["./start_services.sh"]
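The updated Dockerfile packages two services in one container: the vLLM OpenAI-compatible API on port 9999 and the Gradio UI on port 7860, both launched by `start_services.sh`. A minimal sketch for checking that both respond after the container is up, assuming it was started locally with both ports published (the image name is illustrative):

```python
# Probe both services, e.g. after `docker run -p 7860:7860 -p 9999:9999 <image>`.
import httpx

for name, url in [
    ("Gradio UI", "http://localhost:7860/"),
    ("vLLM API", "http://localhost:9999/v1/models"),
]:
    try:
        status = httpx.get(url, timeout=5).status_code
        print(f"{name}: HTTP {status}")  # 200 means the service is serving
    except httpx.HTTPError as exc:
        print(f"{name}: unreachable ({exc})")
```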
app.py CHANGED
@@ -1,405 +1,279 @@
  #!/usr/bin/env python3
- # -*- coding: utf-8 -*-
  """
- Gradio multimodal chat interface: runs inference by calling vLLM.LLM directly inside app.py
  """

  import base64
  import os
- import sys
- import threading
- import time
- import traceback
- from typing import Optional, Tuple

  import gradio as gr

- # Check command-line arguments before importing vllm to decide whether to enable it;
- # this allows previewing the interface without vllm installed
- if "--no-vllm" in sys.argv:
-     os.environ["ENABLE_VLLM"] = "false"

- # Check whether vLLM mode is enabled
- ENABLE_VLLM = os.getenv("ENABLE_VLLM", "true").lower() in ("true", "1", "yes")
-
- if ENABLE_VLLM:
-     try:
-         from vllm import LLM, SamplingParams
-     except ImportError as err:
-         print("[WARNING] Failed to import vllm, falling back to interface preview mode")
-         print(f"[DETAIL] ImportError: {err}")
-         traceback.print_exc()
-         print("[INFO] To use vLLM, verify it is installed and importable in the container")
-         ENABLE_VLLM = False
-         LLM = None
-         SamplingParams = None
- else:
-     LLM = None
-     SamplingParams = None
-     print("[INFO] Running in interface preview mode, vLLM not loaded")
-
- # Default configuration, overridable via environment variables or CLI
- DEFAULT_MODEL_ID = os.getenv("MODEL_NAME", "stepfun-ai/Step-Audio-2-mini-Think")
- DEFAULT_MODEL_PATH = os.getenv("MODEL_PATH", DEFAULT_MODEL_ID)
- DEFAULT_TP = int(os.getenv("TENSOR_PARALLEL_SIZE", "4"))
- DEFAULT_MAX_MODEL_LEN = int(os.getenv("MAX_MODEL_LEN", "8192"))
- DEFAULT_GPU_UTIL = float(os.getenv("GPU_MEMORY_UTILIZATION", "0.9"))
- DEFAULT_TOKENIZER_MODE = os.getenv("TOKENIZER_MODE", "step_audio_2")
- DEFAULT_SERVED_NAME = os.getenv("SERVED_MODEL_NAME", "step-audio-2-mini-think")
-
- _llm: Optional[LLM] = None
- _llm_lock = threading.Lock()
- LLM_ARGS = {
-     "model": DEFAULT_MODEL_PATH,
-     "trust_remote_code": True,
-     "tensor_parallel_size": DEFAULT_TP,
-     "tokenizer_mode": DEFAULT_TOKENIZER_MODE,
-     "max_model_len": DEFAULT_MAX_MODEL_LEN,
-     "served_model_name": DEFAULT_SERVED_NAME,
-     "gpu_memory_utilization": DEFAULT_GPU_UTIL,
- }
-
-
- def encode_audio_to_base64(audio_path: Optional[str]) -> Optional[dict]:
-     """Encode an audio file as base64"""
-     if audio_path is None:
          return None
-
      try:
-         with open(audio_path, "rb") as audio_file:
-             audio_data = audio_file.read()
-         audio_base64 = base64.b64encode(audio_data).decode('utf-8')
-         # Try to infer the format from the file extension
-         ext = os.path.splitext(audio_path)[1].lower().lstrip('.')
-         if not ext:
-             ext = "wav"  # default format
-         return {
-             "data": audio_base64,
-             "format": ext
-         }
      except Exception as e:
-         print(f"Error encoding audio: {e}")
          return None

-
- def format_messages(
-     system_prompt: str,
-     chat_history: list,
-     user_text: str,
-     audio_file: Optional[str]
- ) -> list:
-     """Format messages into the OpenAI API format"""
      messages = []
-
-     # Add the system prompt
-     if system_prompt and system_prompt.strip():
          messages.append({
-             "role": "system",
-             "content": system_prompt.strip()
          })
-
-     # Add the conversation history
-     for human, assistant in chat_history:
-         if human:
-             messages.append({"role": "user", "content": human})
-         if assistant:
-             messages.append({"role": "assistant", "content": assistant})
-
-     # Add the current user input
-     content_parts = []
-
-     # Add text input
-     if user_text and user_text.strip():
-         content_parts.append({
-             "type": "text",
-             "text": user_text.strip()
          })
-
-     # Add audio input
-     if audio_file:
-         audio_data = encode_audio_to_base64(audio_file)
-         if audio_data:
-             content_parts.append({
-                 "type": "input_audio",
-                 "input_audio": audio_data
-             })
-
-     if content_parts:
-         # If there is only one text part, use the string directly
-         if len(content_parts) == 1 and content_parts[0]["type"] == "text":
-             messages.append({
-                 "role": "user",
-                 "content": content_parts[0]["text"]
-             })
-         else:
-             messages.append({
-                 "role": "user",
-                 "content": content_parts
-             })
-
      return messages


- def chat_predict(
-     system_prompt: str,
-     user_text: str,
-     audio_file: Optional[str],
-     chat_history: list,
-     max_tokens: int,
-     temperature: float,
-     top_p: float
- ) -> Tuple[list, str]:
-     """Run inference with the local vLLM LLM"""
      if not user_text and not audio_file:
-         return chat_history, " Please provide text or audio input"
-
-     # In preview mode, return a mock response
-     if not ENABLE_VLLM:
-         user_display = user_text if user_text else "[audio input]"
-         mock_response = f"This is a mock reply. You said: {user_text[:50] if user_text else 'audio'}"
-         chat_history.append((user_display, mock_response))
-         return chat_history, ""
-
-     messages = format_messages(system_prompt, chat_history, user_text, audio_file)
      if not messages:
-         return chat_history, " No valid input"
-
      try:
-         llm = _get_llm()
-         sampling_params = SamplingParams(
-             max_tokens=max_tokens,
-             temperature=temperature,
-             top_p=top_p,
-         )
-         start_time = time.time()
-         outputs = llm.chat(messages, sampling_params=sampling_params, use_tqdm=False)
-         latency = time.time() - start_time
-
-         if not outputs or not outputs[0].outputs:
-             return chat_history, "⚠ The model returned no output"
-
-         assistant_message = outputs[0].outputs[0].text
-         user_display = user_text if user_text else "[audio input]"
-         chat_history.append((user_display, assistant_message))
-         return chat_history, ""
-     except Exception as e:
-         import traceback
-         traceback.print_exc()
-         return chat_history, ""


- def _get_llm() -> LLM:
-     """Initialize the LLM as a singleton"""
-     if not ENABLE_VLLM:
-         raise RuntimeError("vLLM is not enabled, cannot load the model")
-
-     global _llm
-     if _llm is not None:
-         return _llm
-
-     with _llm_lock:
-         if _llm is not None:
-             return _llm
-         print(f"[LLM] Initializing with args: {LLM_ARGS}")
-         _llm = LLM(**LLM_ARGS)
-     return _llm


- def _set_llm_args(**kwargs) -> None:
-     """Update the LLM initialization arguments"""
-     global LLM_ARGS, _llm
-     LLM_ARGS = kwargs
-     _llm = None  # force a reload with the new configuration


- # Build the Gradio interface
- with gr.Blocks(title="Step Audio 2 Chat", theme=gr.themes.Soft()) as demo:
-     gr.Markdown(
-         """
-         # Step Audio R1 Demo
-         """
-     )
-
      with gr.Row():
-         # Left: parameter configuration
          with gr.Column(scale=1):
-             gr.Markdown("### Configuration")
-
-             system_prompt = gr.Textbox(
-                 label="System Prompt",
-                 placeholder="Enter a system prompt...",
-                 lines=4,
-                 value="You are an expert in audio analysis, please analyze the audio content and answer the questions accurately"
-             )
-
-             with gr.Row():
-                 max_tokens = gr.Slider(
-                     label="Max Tokens",
-                     minimum=1,
-                     maximum=16384,
-                     value=8192,
-                     step=1
-                 )
-
-             with gr.Row():
-                 temperature = gr.Slider(
-                     label="Temperature",
-                     minimum=0.0,
-                     maximum=2.0,
-                     value=0.7,
-                     step=0.1
                  )
-
-             top_p = gr.Slider(
-                 label="Top P",
-                 minimum=0.0,
-                 maximum=1.0,
-                 value=0.9,
-                 step=0.05
-             )
-
-         # Right: chat and input
-         with gr.Column(scale=1):
-             gr.Markdown("### Chat")
-             chatbot = gr.Chatbot(
-                 label="Chat History",
-                 height=400,
-                 show_copy_button=True,
-                 type="messages"
-             )
-
-             user_text = gr.Textbox(
-                 label="Text Input",
-                 placeholder="Enter your message...",
-                 lines=2
-             )
-
-             audio_file = gr.Audio(
-                 label="Audio Input",
-                 type="filepath",
-                 sources=["microphone", "upload"]
-             )
-
              with gr.Row():
-                 submit_btn = gr.Button("Submit", variant="primary", size="lg")
-                 clear_btn = gr.Button("Clear", variant="secondary")
-
-             status_text = gr.Textbox(label="Status", interactive=False, visible=False)
-
-     # Event bindings
      submit_btn.click(
-         fn=chat_predict,
-         inputs=[
-             system_prompt,
-             user_text,
-             audio_file,
-             chatbot,
-             max_tokens,
-             temperature,
-             top_p
-         ],
-         outputs=[chatbot, status_text]
      )
-
      clear_btn.click(
          fn=lambda: ([], "", None),
          outputs=[chatbot, user_text, audio_file]
      )

-
  if __name__ == "__main__":
      import argparse
-
-     parser = argparse.ArgumentParser(description="Step Audio 2 Gradio Chat Interface")
-     parser.add_argument(
-         "--host",
-         type=str,
-         default="0.0.0.0",
-         help="Server host address"
-     )
-     parser.add_argument(
-         "--port",
-         type=int,
-         default=7860,
-         help="Server port"
-     )
-     parser.add_argument(
-         "--model",
-         type=str,
-         default=DEFAULT_MODEL_PATH,
-         help="Model name or local path"
-     )
-     parser.add_argument(
-         "--tensor-parallel-size",
-         type=int,
-         default=DEFAULT_TP,
-         help="Tensor parallel size"
-     )
-     parser.add_argument(
-         "--max-model-len",
-         type=int,
-         default=DEFAULT_MAX_MODEL_LEN,
-         help="Maximum context length"
-     )
-     parser.add_argument(
-         "--gpu-memory-utilization",
-         type=float,
-         default=DEFAULT_GPU_UTIL,
-         help="GPU memory utilization"
-     )
-     parser.add_argument(
-         "--tokenizer-mode",
-         type=str,
-         default=DEFAULT_TOKENIZER_MODE,
-         help="Tokenizer mode"
-     )
-     parser.add_argument(
-         "--served-model-name",
-         type=str,
-         default=DEFAULT_SERVED_NAME,
-         help="Model name exposed to clients"
-     )
-     parser.add_argument(
-         "--no-vllm",
-         action="store_true",
-         help="Disable vLLM and launch interface preview mode only"
-     )
-
      args = parser.parse_args()
-
-     # The --no-vllm flag is handled at the top of the file; this is just a notice
-     if args.no_vllm and not ENABLE_VLLM:
-         print("[INFO] vLLM disabled, running in interface preview mode")
-
-     _set_llm_args(
-         model=args.model,
-         trust_remote_code=True,
-         tensor_parallel_size=args.tensor_parallel_size,
-         tokenizer_mode=args.tokenizer_mode,
-         max_model_len=args.max_model_len,
-         served_model_name=args.served_model_name,
-         gpu_memory_utilization=args.gpu_memory_utilization,
-     )
-
-     print("==========================================")
-     print("Step Audio 2 Gradio Chat")
-     if ENABLE_VLLM:
-         print("Mode: vLLM inference")
-         print(f"Model: {args.model}")
-         print(f"Tensor Parallel Size: {args.tensor_parallel_size}")
-         print(f"Max Model Len: {args.max_model_len}")
-         print(f"Tokenizer Mode: {args.tokenizer_mode}")
-         print(f"Served Model Name: {args.served_model_name}")
-     else:
-         print("Mode: interface preview (no vLLM)")
-     print(f"Gradio URL: http://{args.host}:{args.port}")
-     print("==========================================")
-
-     demo.queue().launch(
-         server_name=args.host,
-         server_port=args.port,
-         share=False
-     )
  #!/usr/bin/env python3
  """
+ Step Audio R1 vLLM Gradio Interface
  """

  import base64
+ import json
  import os

  import gradio as gr
+ import httpx

+ API_BASE_URL = os.getenv("API_BASE_URL", "http://localhost:9999/v1")
+ MODEL_NAME = os.getenv("MODEL_NAME", "Step-Audio-R1")


+ def encode_audio(audio_path):
+     """Encode audio as base64"""
+     if not audio_path or not os.path.exists(audio_path):
          return None
      try:
+         with open(audio_path, "rb") as f:
+             return base64.b64encode(f.read()).decode()
      except Exception as e:
+         print(f"[DEBUG] Audio error: {e}")
          return None

+ def format_messages(system, history, user_text, audio_data=None, audio_format="wav"):
+     """Format the message list"""
      messages = []
+     if system:
+         messages.append({"role": "system", "content": system})
+
+     if not history:
+         history = []
+
+     # Process the history
+     for item in history:
+         # Support list-of-dicts format
+         if isinstance(item, dict) and "role" in item and "content" in item:
+             messages.append(item)
+         # Support Gradio ChatMessage objects
+         elif hasattr(item, "role") and hasattr(item, "content"):
+             messages.append({"role": item.role, "content": item.content})
+
+     # Add the current user message
+     if user_text and audio_data:
          messages.append({
+             "role": "user",
+             "content": [
+                 {
+                     "type": "input_audio",
+                     "input_audio": {
+                         "data": audio_data,
+                         "format": audio_format
+                     }
+                 },
+                 {
+                     "type": "text",
+                     "text": user_text
+                 }
+             ]
          })
+     elif user_text:
+         messages.append({"role": "user", "content": user_text})
+     elif audio_data:
+         messages.append({
+             "role": "user",
+             "content": [
+                 {
+                     "type": "input_audio",
+                     "input_audio": {
+                         "data": audio_data,
+                         "format": audio_format
+                     }
+                 }
+             ]
          })
+
      return messages

+ def chat(system_prompt, user_text, audio_file, history, max_tokens, temperature, top_p, model_name=None):
+     """Chat function"""
+     # If no model is specified, use the global configuration
+     if model_name is None:
+         model_name = MODEL_NAME
      if not user_text and not audio_file:
+         return history or [], "Please enter text or upload audio"
+
+     # Ensure history is a list and formatted correctly
+     history = history or []
+     clean_history = []
+     for item in history:
+         if isinstance(item, dict) and 'role' in item and 'content' in item:
+             clean_history.append(item)
+         elif hasattr(item, "role") and hasattr(item, "content"):
+             # Keep ChatMessage objects
+             clean_history.append(item)
+     history = clean_history
+
+     # Process audio
+     audio_data = None
+     audio_format = "wav"
+     if audio_file:
+         audio_data = encode_audio(audio_file)
+         if audio_file.lower().endswith(".mp3"):
+             audio_format = "mp3"
+
+     messages = format_messages(system_prompt, history, user_text, audio_data, audio_format)
      if not messages:
+         return history or [], "Invalid input"
+
+     # Debug: print the message format
+     print(f"[DEBUG] Messages to API: {json.dumps(messages, ensure_ascii=False, indent=2)}")
+     print(f"[DEBUG] Messages type: {type(messages)}")
+     for i, msg in enumerate(messages):
+         print(f"[DEBUG] Message {i}: {type(msg)} - {msg}")
+
      try:
+         with httpx.Client(base_url=API_BASE_URL, timeout=120) as client:
+             response = client.post("/chat/completions", json={
+                 "model": model_name,
+                 "messages": messages,
+                 "max_tokens": max_tokens,
+                 "temperature": temperature,
+                 "top_p": top_p,
+                 "stream": True
+             })
+
+             if response.status_code != 200:
+                 error_msg = f"❌ API Error {response.status_code}"
+                 if response.status_code == 404:
+                     error_msg += " - vLLM service not ready"
+                 elif response.status_code == 400:
+                     error_msg += " - Bad request"
+                 elif response.status_code == 500:
+                     error_msg += " - Model error"
+                 return history, error_msg
+
+             # Process the streaming response
+             content_parts = []
+             for line in response.iter_lines():
+                 if not line:
+                     continue
+                 # Ensure line is a string
+                 if isinstance(line, bytes):
+                     line = line.decode('utf-8')
+                 else:
+                     line = str(line)
+
+                 if line.startswith('data: '):
+                     data_str = line[6:]
+                     if data_str.strip() == '[DONE]':
+                         break
+                     try:
+                         data = json.loads(data_str)
+                         if 'choices' in data and len(data['choices']) > 0:
+                             delta = data['choices'][0].get('delta', {})
+                             if 'content' in delta:
+                                 content_parts.append(delta['content'])
+                     except json.JSONDecodeError:
+                         continue
+
+             full_content = ''.join(content_parts)
+
+             # Update history - only append when there is no error
+             history = history or []
+
+             # Add the user message
+             if audio_file:
+                 # If audio exists, show the audio file and the text (if any).
+                 # Gradio Chatbot supports a (file_path,) tuple to show a file,
+                 # but in messages format we need to construct proper content.
+                 # The simpler way for multimodal input: append separate messages.
+
+                 # 1. Add the audio message
+                 history.append({"role": "user", "content": gr.Audio(audio_file)})
+
+                 # 2. If text exists, add a text message
+                 if user_text:
+                     history.append({"role": "user", "content": user_text})
+             else:
+                 # Text only
+                 history.append({"role": "user", "content": user_text})
+
+             # Split thinking and content
+             if "</think>" in full_content:
+                 parts = full_content.split("</think>", 1)
+                 think_content = parts[0].strip()
+                 response_content = parts[1].strip()
+
+                 # Remove a possible opening tag
+                 if think_content.startswith("<think>"):
+                     think_content = think_content[len("<think>"):].strip()
+
+                 # Add the thinking-process message (uses ChatMessage with metadata)
+                 if think_content:
+                     history.append(gr.ChatMessage(
+                         role="assistant",
+                         content=think_content,
+                         metadata={"title": "⏳ Thinking Process"}
+                     ))
+
+                 # Add the final response message
+                 if response_content:
+                     history.append({"role": "assistant", "content": response_content})
+             else:
+                 # No think tag, add the full response directly
+                 assistant_text = full_content.strip()
+                 if assistant_text:
+                     history.append({"role": "assistant", "content": assistant_text})
+
+             return history, ""
+
+     except httpx.ConnectError:
+         return history, "❌ Cannot connect to vLLM API"
+     except Exception as e:
+         return history, f"❌ Error: {str(e)}"
+
+ # Gradio interface
+ with gr.Blocks(title="Step Audio R1") as demo:
+     gr.Markdown("# Step Audio R1 Chat")

      with gr.Row():
+         # Left: configuration
          with gr.Column(scale=1):
+             with gr.Accordion("Configuration", open=True):
+                 system_prompt = gr.Textbox(
+                     label="System Prompt",
+                     lines=2,
+                     value="You are an audio analysis expert"
                  )
+                 max_tokens = gr.Slider(1, 8192, value=1024, label="Max Tokens")
+                 temperature = gr.Slider(0.0, 2.0, value=0.7, label="Temperature")
+                 top_p = gr.Slider(0.0, 1.0, value=0.9, label="Top P")
+
+             status = gr.Textbox(label="Status", interactive=False)
+
+         # Right: chat
+         with gr.Column(scale=2):
+             chatbot = gr.Chatbot(label="Chat History", height=450)
+             user_text = gr.Textbox(label="Input", lines=2, placeholder="Enter message...")
+             audio_file = gr.Audio(label="Audio", type="filepath", sources=["microphone", "upload"])
+
              with gr.Row():
+                 submit_btn = gr.Button("Send", variant="primary", scale=2)
+                 clear_btn = gr.Button("Clear", scale=1)
+
+     # Event bindings - the functions are defined at startup.
+     # Bind the chat function directly; do not pass an external `model_to_use`,
+     # chat uses the default `MODEL_NAME` via its internal parameter
      submit_btn.click(
+         fn=chat,
+         inputs=[system_prompt, user_text, audio_file, chatbot, max_tokens, temperature, top_p],
+         outputs=[chatbot, status]
      )
+
      clear_btn.click(
          fn=lambda: ([], "", None),
          outputs=[chatbot, user_text, audio_file]
      )

  if __name__ == "__main__":
      import argparse
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--host", default="0.0.0.0")
+     parser.add_argument("--port", type=int, default=7860)
+     parser.add_argument("--model", default=MODEL_NAME)
      args = parser.parse_args()
+
+     # Update the global model name
+     if args.model:
+         MODEL_NAME = args.model
+
+     print(f"Starting Gradio: http://{args.host}:{args.port}")
+     print(f"API URL: {API_BASE_URL}")
+     print(f"Model: {MODEL_NAME}")
+
+     demo.launch(server_name=args.host, server_port=args.port, share=False)
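The `chat()` function above does two things worth isolating: it accumulates `delta.content` chunks from vLLM's server-sent-event stream, then splits an optional `<think>…</think>` reasoning block from the visible answer. A self-contained sketch of both steps against a canned stream (the sample lines are illustrative, shaped like the `stream: true` output handled in `chat()`; no server required):

```python
import json

# Canned SSE lines in the shape emitted by /chat/completions with "stream": true
SAMPLE_STREAM = [
    'data: {"choices": [{"delta": {"content": "<think>Short clip, "}}]}',
    'data: {"choices": [{"delta": {"content": "spoken voice.</think>"}}]}',
    'data: {"choices": [{"delta": {"content": "The audio is a spoken sentence."}}]}',
    'data: [DONE]',
]

# Step 1: accumulate delta content, mirroring the loop in chat()
parts = []
for line in SAMPLE_STREAM:
    if not line.startswith("data: "):
        continue
    payload = line[len("data: "):]
    if payload.strip() == "[DONE]":
        break
    delta = json.loads(payload)["choices"][0].get("delta", {})
    if "content" in delta:
        parts.append(delta["content"])
full_content = "".join(parts)

# Step 2: separate the reasoning block from the final answer
if "</think>" in full_content:
    think, answer = full_content.split("</think>", 1)
    think = think.removeprefix("<think>").strip()  # Python 3.9+
else:
    think, answer = "", full_content

print("think:", think)            # Short clip, spoken voice.
print("answer:", answer.strip())  # The audio is a spoken sentence.
```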
requirements.txt CHANGED
@@ -1 +1,3 @@
  gradio>=4.0.0
+ httpx
+ huggingface_hub
start_services.sh ADDED
@@ -0,0 +1,81 @@
+ #!/bin/bash
+ set -euo pipefail
+
+ # Configuration
+ MODEL_REPO="${MODEL_REPO:-stepfun-ai/Step-Audio-R1}"
+ MODEL_DIR="${MODEL_DIR:-/tmp/models/Step-Audio-R1}"
+ API_PORT="${API_PORT:-9999}"
+ GRADIO_PORT="${GRADIO_PORT:-7860}"
+
+ echo "Starting Step Audio R1 services..."
+ echo "Model: $MODEL_REPO"
+ echo "Model Dir: $MODEL_DIR"
+ echo "API Port: $API_PORT"
+
+ # Download the model (if needed)
+ if [[ ! -d "$MODEL_DIR" ]] || [[ ! -f "$MODEL_DIR/config.json" ]]; then
+     echo "Downloading model to: $MODEL_DIR"
+     mkdir -p "$MODEL_DIR"
+
+     if command -v hf &> /dev/null; then
+         hf download "$MODEL_REPO" --local-dir "$MODEL_DIR"
+     elif command -v huggingface-cli &> /dev/null; then
+         huggingface-cli download "$MODEL_REPO" --local-dir "$MODEL_DIR" --local-dir-use-symlinks False
+     else
+         echo "Neither hf nor huggingface-cli found. Cannot download model."
+         exit 1
+     fi
+
+     echo "✓ Model downloaded"
+ else
+     echo "✓ Model already exists locally"
+ fi
+
+ # Chat template for Step-Audio-R1
+ CHAT_TEMPLATE='{%- macro render_content(content) -%}{%- if content is string -%}{{- content.replace("<audio_patch>\\n", "<audio_patch>") -}}{%- elif content is mapping -%}{{- content["'"'"'value'"'"'] if '"'"'value'"'"' in content else content["'"'"'text'"'"'] -}}{%- elif content is iterable -%}{%- for item in content -%}{%- if item.type == '"'"'text'"'"' -%}{{- item["'"'"'value'"'"'] if '"'"'value'"'"' in item else item["'"'"'text'"'"'] -}}{%- elif item.type == '"'"'audio'"'"' -%}<audio_patch>{%- endif -%}{%- endfor -%}{%- endif -%}{%- endmacro -%}{%- if tools -%}{{- '"'"'<|BOT|>system\\n'"'"' -}}{%- if messages[0]["'"'"'role'"'"'] == '"'"'system'"'"' -%}{{- render_content(messages[0]["'"'"'content'"'"']) + '"'"'<|EOT|>'"'"' -}}{%- endif -%}{{- '"'"'<|BOT|>tool_json_schemas\\n'"'"' + tools|tojson + '"'"'<|EOT|>'"'"' -}}{%- else -%}{%- if messages[0]["'"'"'role'"'"'] == '"'"'system'"'"' -%}{{- '"'"'<|BOT|>system\\n'"'"' + render_content(messages[0]["'"'"'content'"'"']) + '"'"'<|EOT|>'"'"' -}}{%- endif -%}{%- endif -%}{%- for message in messages -%}{%- if message["role"] == "user" -%}{{- '"'"'<|BOT|>human\\n'"'"' + render_content(message["content"]) + '"'"'<|EOT|>'"'"' -}}{%- elif message["role"] == "assistant" -%}{{- '"'"'<|BOT|>assistant\\n'"'"' + (render_content(message["content"]) if message["content"] else '"'"''"'"') -}}{%- set is_last_assistant = true -%}{%- for m in messages[loop.index:] -%}{%- if m["role"] == "assistant" -%}{%- set is_last_assistant = false -%}{%- endif -%}{%- endfor -%}{%- if not is_last_assistant -%}{{- '"'"'<|EOT|>'"'"' -}}{%- endif -%}{%- elif message["role"] == "function_output" -%}{%- else -%}{%- if not (loop.first and message["role"] == "system") -%}{{- '"'"'<|BOT|>'"'"' + message["role"] + '"'"'\\n'"'"' + render_content(message["content"]) + '"'"'<|EOT|>'"'"' -}}{%- endif -%}{%- endif -%}{%- endfor -%}{%- if add_generation_prompt -%}{{- '"'"'<|BOT|>assistant\\n'"'"' -}}{%- endif -%}'
+
+ # Start the vLLM API in the background
+ python3 -m vllm.entrypoints.openai.api_server \
+     --model "$MODEL_DIR" \
+     --port "$API_PORT" \
+     --host 0.0.0.0 \
+     --max-model-len 65536 \
+     --tensor-parallel-size 4 \
+     --gpu-memory-utilization 0.85 \
+     --trust-remote-code \
+     --interleave-mm-strings \
+     --chat-template "$CHAT_TEMPLATE" \
+     &
+
+ VLLM_PID=$!
+ echo "vLLM started (PID: $VLLM_PID)"
+
+ # Clean up: kill the background vLLM process when this script exits
+ trap 'kill $VLLM_PID' EXIT
+
+ # Wait for vLLM to be ready
+ echo "Waiting for vLLM to be ready..."
+ for i in {1..30}; do
+     if curl -s "http://localhost:$API_PORT/v1/models" > /dev/null 2>&1; then
+         echo "✓ vLLM is ready (checked $i/30 times)"
+         break
+     fi
+
+     if [ $i -eq 30 ]; then
+         echo "❌ vLLM startup timeout after 60 seconds"
+         echo "Checking vLLM process:"
+         ps aux | grep "vllm.entrypoints.openai.api_server" || echo "vLLM process not found"
+         echo "Port $API_PORT status:"
+         netstat -tlnp | grep ":$API_PORT " || echo "Port $API_PORT not listening"
+         exit 1
+     fi
+
+     echo "Waiting for vLLM... ($i/30)"
+     sleep 2
+ done
+
+ # Start Gradio (runs in the foreground)
+ export API_BASE_URL="http://localhost:$API_PORT/v1"
+ export MODEL_NAME="Step-Audio-R1"
+
+ python3 app.py --host 0.0.0.0 --port "$GRADIO_PORT"
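Once the readiness loop reports the API is up, the OpenAI-compatible endpoint can be smoke-tested independently of Gradio. A minimal sketch, assuming the defaults above (API on localhost:9999, httpx from requirements.txt installed); the model id is read from `/v1/models` rather than hard-coded, since vLLM registers the model under its `--model` path when no `--served-model-name` is given:

```python
# Smoke test for the endpoint started by start_services.sh.
import httpx

BASE = "http://localhost:9999/v1"  # matches the API_PORT default above

with httpx.Client(base_url=BASE, timeout=120) as client:
    # Discover the registered model id instead of guessing it
    model_id = client.get("/models").json()["data"][0]["id"]
    print("serving:", model_id)

    resp = client.post("/chat/completions", json={
        "model": model_id,
        "messages": [{"role": "user", "content": "Reply with one short sentence."}],
        "max_tokens": 64,
    })
    resp.raise_for_status()
    print(resp.json()["choices"][0]["message"]["content"])
```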