File size: 11,936 Bytes
01ee73c
 
 
 
 
 
 
 
098e8e5
01ee73c
 
 
 
fc5a2be
e9abbb0
098e8e5
 
 
 
 
01ee73c
 
e9abbb0
01ee73c
098e8e5
 
 
 
 
 
 
 
01ee73c
 
 
 
fc5a2be
01ee73c
 
 
 
 
 
 
 
fc5a2be
01ee73c
 
 
 
 
 
 
 
 
 
 
fc5a2be
01ee73c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fde2dcb
01ee73c
fde2dcb
01ee73c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fde2dcb
01ee73c
 
 
fde2dcb
01ee73c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fde2dcb
01ee73c
 
 
 
467194a
fde2dcb
467194a
01ee73c
 
 
 
467194a
 
01ee73c
 
 
 
 
 
fde2dcb
 
01ee73c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
467194a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
01ee73c
467194a
 
 
fde2dcb
 
01ee73c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fc5a2be
 
 
01ee73c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
098e8e5
 
01ee73c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Gradio 多模态聊天界面:直接在 app.py 内部调用 vLLM.LLM 进行推理
"""

import base64
import os
import sys
import threading
import time
from typing import Optional, Tuple

import gradio as gr

# Decide whether vLLM is used *before* importing it, so the UI can be previewed
# on machines without vllm installed. argparse only runs later under __main__,
# so the --no-vllm flag is detected directly in sys.argv and forwarded through
# the ENABLE_VLLM environment variable.
if "--no-vllm" in sys.argv:
    os.environ["ENABLE_VLLM"] = "false"

# Accepted truthy spellings for ENABLE_VLLM (case-insensitive); default "true".
ENABLE_VLLM = os.getenv("ENABLE_VLLM", "true").lower() in ("true", "1", "yes")

if not ENABLE_VLLM:
    # Explicitly disabled: expose None placeholders for the vLLM names.
    LLM = SamplingParams = None
    print("[INFO] 运行在界面预览模式,不加载 vLLM")
else:
    try:
        from vllm import LLM, SamplingParams
    except ImportError:
        # vllm is missing: fall back to preview mode instead of crashing.
        print("[WARNING] 无法导入 vllm,自动切换到界面预览模式")
        print("[INFO] 如需使用 vLLM,请先安装: pip install vllm")
        ENABLE_VLLM = False
        LLM = SamplingParams = None

# Default configuration. Each value can be overridden by the environment
# variable named below, and again by the matching CLI flag under __main__.
DEFAULT_MODEL_ID = os.getenv("MODEL_NAME", "stepfun-ai/Step-Audio-2-mini-Think")
# MODEL_PATH may point at a local checkout; otherwise reuse the model id above.
DEFAULT_MODEL_PATH = os.getenv("MODEL_PATH", DEFAULT_MODEL_ID)
DEFAULT_TP = int(os.getenv("TENSOR_PARALLEL_SIZE", "4"))
DEFAULT_MAX_MODEL_LEN = int(os.getenv("MAX_MODEL_LEN", "8192"))
DEFAULT_GPU_UTIL = float(os.getenv("GPU_MEMORY_UTILIZATION", "0.9"))
DEFAULT_TOKENIZER_MODE = os.getenv("TOKENIZER_MODE", "step_audio_2")
DEFAULT_SERVED_NAME = os.getenv("SERVED_MODEL_NAME", "step-audio-2-mini-think")

# Lazily created engine instance (None until _get_llm builds it) and the lock
# _get_llm holds while creating it.
_llm: Optional[LLM] = None
_llm_lock = threading.Lock()

# Keyword arguments passed to LLM(...) on first use; _set_llm_args replaces the
# whole mapping.
LLM_ARGS = dict(
    model=DEFAULT_MODEL_PATH,
    trust_remote_code=True,
    tensor_parallel_size=DEFAULT_TP,
    tokenizer_mode=DEFAULT_TOKENIZER_MODE,
    max_model_len=DEFAULT_MAX_MODEL_LEN,
    served_model_name=DEFAULT_SERVED_NAME,
    gpu_memory_utilization=DEFAULT_GPU_UTIL,
)


def encode_audio_to_base64(audio_path: Optional[str]) -> Optional[dict]:
    """Read an audio file and package its bytes as base64 for an ``input_audio`` part.

    Args:
        audio_path: Path of the audio file to read, or ``None``.

    Returns:
        ``{"data": <base64 text>, "format": <extension>}`` where the format is the
        lower-cased file extension without its dot, or ``"wav"`` when the path
        has no extension. Returns ``None`` when ``audio_path`` is ``None`` or when
        reading/encoding fails (the error is printed, not raised).
    """
    if audio_path is None:
        return None

    try:
        with open(audio_path, "rb") as handle:
            raw_bytes = handle.read()
        extension = os.path.splitext(audio_path)[1].lower().lstrip('.')
        payload = base64.b64encode(raw_bytes).decode('utf-8')
        # No extension on the path: assume WAV.
        return {"data": payload, "format": extension or "wav"}
    except Exception as e:
        print(f"Error encoding audio: {e}")
        return None


def format_messages(
    system_prompt: str,
    chat_history: list,
    user_text: str,
    audio_file: Optional[str]
) -> list:
    """Build an OpenAI-style chat message list from the UI state.

    Args:
        system_prompt: System prompt; skipped when empty or whitespace-only,
            otherwise stripped.
        chat_history: Prior turns as ``(user, assistant)`` pairs; empty entries
            on either side are skipped.
        user_text: Current user text; stripped, skipped when blank.
        audio_file: Path to audio for the current turn, or ``None``.

    Returns:
        The message list. The current user turn is appended only if it has
        content: a plain string when it is text alone, otherwise a list of
        ``text`` / ``input_audio`` parts. If the audio cannot be encoded it is
        silently dropped.
    """
    messages = []

    prompt = system_prompt.strip() if system_prompt else ""
    if prompt:
        messages.append({"role": "system", "content": prompt})

    # Flatten each (user, assistant) pair into separate messages, skipping blanks.
    for turn_user, turn_assistant in chat_history:
        messages.extend(
            {"role": role, "content": content}
            for role, content in (("user", turn_user), ("assistant", turn_assistant))
            if content
        )

    parts = []
    text = user_text.strip() if user_text else ""
    if text:
        parts.append({"type": "text", "text": text})
    if audio_file:
        encoded = encode_audio_to_base64(audio_file)
        if encoded:
            parts.append({"type": "input_audio", "input_audio": encoded})

    if not parts:
        return messages

    # A lone text part is sent as a plain string rather than a one-element list.
    text_only = len(parts) == 1 and parts[0]["type"] == "text"
    messages.append({
        "role": "user",
        "content": parts[0]["text"] if text_only else parts,
    })
    return messages


def chat_predict(
    system_prompt: str,
    user_text: str,
    audio_file: Optional[str],
    chat_history: list,
    max_tokens: int,
    temperature: float,
    top_p: float
) -> Tuple[list, str]:
    """Run one chat turn through the local vLLM engine (or a mock in preview mode).

    Args:
        system_prompt: Optional system prompt; blank values are ignored.
        user_text: Text typed by the user; may be empty when audio is given.
        audio_file: Path to an uploaded/recorded audio file, or ``None``.
        chat_history: Prior turns as ``(user, assistant)`` pairs; ``None`` is
            treated as an empty history.
        max_tokens: Maximum number of tokens to generate (coerced to ``int``).
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.

    Returns:
        ``(history, status)``. ``history`` is a new list (the caller's list is
        not mutated) with the new turn appended on success. ``status`` is ``""``
        on success, otherwise a warning/error message for the UI.
    """
    # Copy so the caller's list is never mutated; tolerate a missing history.
    history = list(chat_history or [])

    # Whitespace-only text is treated as "no text": format_messages drops it.
    has_text = bool(user_text and user_text.strip())
    if not has_text and not audio_file:
        return history, "⚠ 请提供文本或音频输入"

    user_display = user_text if has_text else "[音频输入]"

    # Preview mode: return a canned response without touching vLLM.
    if not ENABLE_VLLM:
        mock_response = f"这是一个模拟回复。您说: {user_text[:50] if has_text else '音频'}"
        history.append((user_display, mock_response))
        return history, ""

    # format_messages silently drops the current turn when the audio cannot be
    # encoded and there is no text. Detect that by comparing against the
    # context-only message list. Otherwise the model would answer a conversation
    # that lacks the user's new input.
    context = format_messages(system_prompt, history, "", None)
    messages = format_messages(system_prompt, history, user_text, audio_file)
    if len(messages) == len(context):
        return history, "⚠ 无有效输入"

    try:
        llm = _get_llm()
        sampling_params = SamplingParams(
            max_tokens=int(max_tokens),
            temperature=temperature,
            top_p=top_p,
        )
        start_time = time.time()
        outputs = llm.chat(messages, sampling_params=sampling_params, use_tqdm=False)
        latency = time.time() - start_time
    except Exception as e:
        # UI boundary: log the full traceback and surface the error in the
        # status output instead of discarding it.
        import traceback
        traceback.print_exc()
        return history, f"⚠ 推理失败: {e}"

    print(f"[LLM] 推理耗时 {latency:.2f}s")

    if not outputs or not outputs[0].outputs:
        return history, "⚠ 模型未返回结果"

    history.append((user_display, outputs[0].outputs[0].text))
    return history, ""


def _get_llm() -> LLM:
    """Return the shared vLLM engine, constructing it on first use.

    The engine is built from the current ``LLM_ARGS``. ``_llm`` is checked once
    without the lock (fast path) and again while holding ``_llm_lock``, so that
    concurrent first callers construct only one engine.

    Raises:
        RuntimeError: if vLLM is disabled (preview mode or failed import).
    """
    global _llm

    if not ENABLE_VLLM:
        raise RuntimeError("vLLM 未启用,无法加载模型")

    if _llm is None:
        with _llm_lock:
            # Re-check: another thread may have finished construction while we
            # were waiting for the lock.
            if _llm is None:
                print(f"[LLM] 初始化中,参数: {LLM_ARGS}")
                _llm = LLM(**LLM_ARGS)
    return _llm


def _set_llm_args(**kwargs) -> None:
    """Replace the engine construction arguments and drop the cached engine.

    The next ``_get_llm()`` call builds a new engine from ``kwargs``.

    The swap is done while holding ``_llm_lock``, the same lock ``_get_llm``
    holds while constructing the engine. Without it, a reset that landed during
    an in-flight construction would be overwritten when that construction
    stored its result, leaving an engine built from the old arguments cached.

    Args:
        **kwargs: Keyword arguments to pass to ``LLM(...)``; they replace the
            previous ``LLM_ARGS`` entirely.
    """
    global LLM_ARGS, _llm
    with _llm_lock:
        LLM_ARGS = kwargs
        _llm = None  # Force a reload with the new configuration.





# Build the Gradio interface.
with gr.Blocks(title="Step Audio 2 Chat", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # Step Audio R1 Demo
        """
    )
    
    with gr.Row():
        # Left column: system prompt and sampling parameters.
        with gr.Column(scale=1):
            gr.Markdown("### 配置")
            
            system_prompt = gr.Textbox(
                label="System Prompt",
                placeholder="输入系统提示词...",
                lines=4,
                value="You are an expert in audio analysis, please analyze the audio content and answer the questions accurately"
            )
            
            with gr.Row():
                max_tokens = gr.Slider(
                    label="Max Tokens",
                    minimum=1,
                    maximum=16384,
                    value=8192,
                    step=1
                )
            
            with gr.Row():
                temperature = gr.Slider(
                    label="Temperature",
                    minimum=0.0,
                    maximum=2.0,
                    value=0.7,
                    step=0.1
                )
                
                top_p = gr.Slider(
                    label="Top P",
                    minimum=0.0,
                    maximum=1.0,
                    value=0.9,
                    step=0.05
                )
        
        # Right column: conversation history and user inputs.
        with gr.Column(scale=1):
            gr.Markdown("### 对话")
            chatbot = gr.Chatbot(
                label="聊天历史",
                height=400,
                show_copy_button=True
            )
            
            user_text = gr.Textbox(
                label="文本输入",
                placeholder="输入您的消息...",
                lines=2
            )
            
            audio_file = gr.Audio(
                label="音频输入",
                type="filepath",
                sources=["microphone", "upload"]
            )
            
            with gr.Row():
                submit_btn = gr.Button("提交", variant="primary", size="lg")
                clear_btn = gr.Button("清空", variant="secondary")
            
            # Must be visible: chat_predict reports empty input, missing model
            # output and inference errors through this component. When it was
            # hidden, those cases looked like a silent no-op to the user.
            status_text = gr.Textbox(label="状态", interactive=False, visible=True)
    
    # Event wiring.
    submit_btn.click(
        fn=chat_predict,
        inputs=[
            system_prompt,
            user_text,
            audio_file,
            chatbot,
            max_tokens,
            temperature,
            top_p
        ],
        outputs=[chatbot, status_text]
    )
    
    # Reset the conversation and inputs, and clear any stale status message.
    clear_btn.click(
        fn=lambda: ([], "", None, ""),
        outputs=[chatbot, user_text, audio_file, status_text]
    )


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Step Audio 2 Gradio Chat Interface")

    # Value-taking options as (flag, type, default, help), registered in --help order.
    cli_options = (
        ("--host", str, "0.0.0.0", "服务器主机地址"),
        ("--port", int, 7860, "服务器端口"),
        ("--model", str, DEFAULT_MODEL_PATH, "模型名称或本地路径"),
        ("--tensor-parallel-size", int, DEFAULT_TP, "张量并行数量"),
        ("--max-model-len", int, DEFAULT_MAX_MODEL_LEN, "最大上下文长度"),
        ("--gpu-memory-utilization", float, DEFAULT_GPU_UTIL, "GPU 显存利用率"),
        ("--tokenizer-mode", str, DEFAULT_TOKENIZER_MODE, "tokenizer 模式"),
        ("--served-model-name", str, DEFAULT_SERVED_NAME, "对外暴露的模型名称"),
    )
    for flag, value_type, default_value, help_text in cli_options:
        parser.add_argument(flag, type=value_type, default=default_value, help=help_text)
    parser.add_argument(
        "--no-vllm",
        action="store_true",
        help="禁用 vLLM,仅启动界面预览模式",
    )

    args = parser.parse_args()

    # --no-vllm already took effect at import time through the sys.argv check
    # before the vllm import; here it is only reported.
    if args.no_vllm and not ENABLE_VLLM:
        print("[INFO] 已禁用 vLLM,运行在界面预览模式")

    _set_llm_args(
        model=args.model,
        trust_remote_code=True,
        tensor_parallel_size=args.tensor_parallel_size,
        tokenizer_mode=args.tokenizer_mode,
        max_model_len=args.max_model_len,
        served_model_name=args.served_model_name,
        gpu_memory_utilization=args.gpu_memory_utilization,
    )

    # Startup banner, printed one line at a time.
    separator = "=========================================="
    banner_lines = [separator, "Step Audio 2 Gradio Chat"]
    if ENABLE_VLLM:
        banner_lines += [
            "模式: vLLM 推理模式",
            f"模型: {args.model}",
            f"Tensor Parallel Size: {args.tensor_parallel_size}",
            f"Max Model Len: {args.max_model_len}",
            f"Tokenizer Mode: {args.tokenizer_mode}",
            f"Served Model Name: {args.served_model_name}",
        ]
    else:
        banner_lines.append("模式: 界面预览模式(无 vLLM)")
    banner_lines += [f"Gradio 地址: http://{args.host}:{args.port}", separator]
    for banner_line in banner_lines:
        print(banner_line)

    demo.queue().launch(
        server_name=args.host,
        server_port=args.port,
        share=False
    )