#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Gradio multimodal chat UI for Step-Audio-2.

Runs inference by instantiating ``vllm.LLM`` directly inside this process
(no separate API server).  Set the ``ENABLE_VLLM`` environment variable to
``false`` to start a UI-only preview that never loads the model.
"""

import argparse
import base64
import os
import threading
import time
import traceback
from typing import Optional, Tuple

import gradio as gr

# Whether to actually load vLLM.
# NOTE(review): the import below runs at module-import time, so the
# ``--no-vllm`` CLI flag (parsed much later, in main()) cannot prevent it.
# To avoid importing vLLM at all, set ENABLE_VLLM=false in the environment.
ENABLE_VLLM = os.getenv("ENABLE_VLLM", "true").lower() in ("true", "1", "yes")

if ENABLE_VLLM:
    from vllm import LLM, SamplingParams
else:
    LLM = None
    SamplingParams = None
    print("[INFO] 运行在界面预览模式,不加载 vLLM")

# Default configuration; each value may be overridden by env var or CLI flag.
DEFAULT_MODEL_ID = os.getenv("MODEL_NAME", "stepfun-ai/Step-Audio-2-mini-Think")
DEFAULT_MODEL_PATH = os.getenv("MODEL_PATH", DEFAULT_MODEL_ID)
DEFAULT_TP = int(os.getenv("TENSOR_PARALLEL_SIZE", "4"))
DEFAULT_MAX_MODEL_LEN = int(os.getenv("MAX_MODEL_LEN", "8192"))
DEFAULT_GPU_UTIL = float(os.getenv("GPU_MEMORY_UTILIZATION", "0.9"))
DEFAULT_TOKENIZER_MODE = os.getenv("TOKENIZER_MODE", "step_audio_2")
DEFAULT_SERVED_NAME = os.getenv("SERVED_MODEL_NAME", "step-audio-2-mini-think")

# Lazily-created LLM singleton; creation is serialized by _llm_lock.
_llm: Optional["LLM"] = None
_llm_lock = threading.Lock()

# Keyword arguments used to construct the LLM (see _set_llm_args / _get_llm).
LLM_ARGS = {
    "model": DEFAULT_MODEL_PATH,
    "trust_remote_code": True,
    "tensor_parallel_size": DEFAULT_TP,
    "tokenizer_mode": DEFAULT_TOKENIZER_MODE,
    "max_model_len": DEFAULT_MAX_MODEL_LEN,
    "served_model_name": DEFAULT_SERVED_NAME,
    "gpu_memory_utilization": DEFAULT_GPU_UTIL,
}


def encode_audio_to_base64(audio_path: Optional[str]) -> Optional[dict]:
    """Encode an audio file as ``{"data": <base64 str>, "format": <ext>}``.

    Returns None when *audio_path* is None or the file cannot be read —
    a bad upload degrades to "no audio" instead of crashing the request.
    """
    if audio_path is None:
        return None
    try:
        with open(audio_path, "rb") as audio_file:
            audio_data = audio_file.read()
        audio_base64 = base64.b64encode(audio_data).decode('utf-8')
        # Infer the container format from the file extension.
        ext = os.path.splitext(audio_path)[1].lower().lstrip('.')
        if not ext:
            ext = "wav"  # default format
        return {
            "data": audio_base64,
            "format": ext,
        }
    except Exception as e:
        # Deliberately broad: log and treat as "no audio".
        print(f"Error encoding audio: {e}")
        return None


def format_messages(
    system_prompt: str,
    chat_history: list,
    user_text: str,
    audio_file: Optional[str],
) -> list:
    """Build an OpenAI-chat-format message list from the UI state.

    *chat_history* is a list of ``(user, assistant)`` pairs as kept by
    ``gr.Chatbot``.  The current turn is appended last; audio is attached
    as an ``input_audio`` content part.
    """
    messages = []

    # System prompt (only if non-blank).
    if system_prompt and system_prompt.strip():
        messages.append({
            "role": "system",
            "content": system_prompt.strip(),
        })

    # Prior turns.
    for human, assistant in chat_history:
        if human:
            messages.append({"role": "user", "content": human})
        if assistant:
            messages.append({"role": "assistant", "content": assistant})

    # Current user turn: optional text part plus optional audio part.
    content_parts = []
    if user_text and user_text.strip():
        content_parts.append({
            "type": "text",
            "text": user_text.strip(),
        })
    if audio_file:
        audio_data = encode_audio_to_base64(audio_file)
        if audio_data:
            content_parts.append({
                "type": "input_audio",
                "input_audio": audio_data,
            })

    if content_parts:
        # A lone text part is sent as a plain string for compatibility.
        if len(content_parts) == 1 and content_parts[0]["type"] == "text":
            messages.append({
                "role": "user",
                "content": content_parts[0]["text"],
            })
        else:
            messages.append({
                "role": "user",
                "content": content_parts,
            })

    return messages


def chat_predict(
    system_prompt: str,
    user_text: str,
    audio_file: Optional[str],
    chat_history: list,
    max_tokens: int,
    temperature: float,
    top_p: float,
) -> Tuple[list, str]:
    """Run one chat turn on the local vLLM engine.

    Returns the updated chat history and a status string for the UI.
    In preview mode (vLLM disabled) a mock reply is produced instead.
    """
    if not user_text and not audio_file:
        return chat_history, "⚠ 请提供文本或音频输入"

    # Preview mode: echo a canned response without touching the model.
    if not ENABLE_VLLM:
        user_display = user_text if user_text else "[音频输入]"
        mock_response = f"[预览模式] 这是一个模拟回复。您说: {user_text[:50] if user_text else '音频'}"
        chat_history.append((user_display, mock_response))
        return chat_history, "✓ 预览模式(未启用 vLLM)"

    messages = format_messages(system_prompt, chat_history, user_text, audio_file)
    if not messages:
        return chat_history, "⚠ 无有效输入"

    try:
        llm = _get_llm()
        sampling_params = SamplingParams(
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
        )
        start_time = time.time()
        outputs = llm.chat(messages, sampling_params=sampling_params, use_tqdm=False)
        latency = time.time() - start_time

        if not outputs or not outputs[0].outputs:
            return chat_history, "⚠ 模型未返回结果"

        assistant_message = outputs[0].outputs[0].text
        user_display = user_text if user_text else "[音频输入]"
        chat_history.append((user_display, assistant_message))
        status = f"✓ 推理完成(耗时 {latency:.2f}s)"
        return chat_history, status
    except Exception as e:
        # Surface the failure in the UI but keep the app alive.
        traceback.print_exc()
        return chat_history, f"✗ 推理失败: {e}"


def _get_llm() -> "LLM":
    """Return the process-wide LLM singleton, creating it on first use.

    Uses double-checked locking so concurrent Gradio requests trigger at
    most one (expensive) model load.

    Raises:
        RuntimeError: if vLLM is disabled (preview mode).
    """
    if not ENABLE_VLLM:
        raise RuntimeError("vLLM 未启用,无法加载模型")
    global _llm
    if _llm is not None:
        return _llm
    with _llm_lock:
        if _llm is not None:
            return _llm
        print(f"[LLM] 初始化中,参数: {LLM_ARGS}")
        _llm = LLM(**LLM_ARGS)
    return _llm


def _set_llm_args(**kwargs) -> None:
    """Replace the LLM construction arguments and drop any loaded model.

    The next _get_llm() call will rebuild the engine with the new config.
    """
    global LLM_ARGS, _llm
    LLM_ARGS = kwargs
    _llm = None  # force a reload under the new configuration


def check_model_status() -> str:
    """Return a human-readable description of the model's load state."""
    if not ENABLE_VLLM:
        return "⚙ 界面预览模式(vLLM 未启用)"
    model_path = LLM_ARGS["model"]
    if _llm is None:
        return f"等待加载:{model_path}"
    return f"✓ 已加载模型:{model_path}"


def warmup_model() -> str:
    """Eagerly load the model and report the resulting status."""
    if not ENABLE_VLLM:
        return "⚙ 界面预览模式(vLLM 未启用)"
    try:
        _get_llm()
        return check_model_status()
    except Exception as exc:
        traceback.print_exc()
        return f"✗ 模型加载失败: {exc}"


# ---------------------------------------------------------------------------
# Gradio interface
# ---------------------------------------------------------------------------
with gr.Blocks(title="Step Audio 2 Chat", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """ # Step Audio 2 Chat Interface 支持文本和音频输入的聊天界面,直接在本地 vLLM 引擎上推理。 """
    )

    # Model status row.
    with gr.Row():
        status_text = gr.Textbox(
            label="模型状态",
            value="检查中...",
            interactive=False,
        )
        check_btn = gr.Button("加载/检查模型", variant="secondary")

    with gr.Row():
        # Left column: inputs and sampling controls.
        with gr.Column(scale=1):
            gr.Markdown("### 输入设置")
            system_prompt = gr.Textbox(
                label="System Prompt",
                placeholder="输入系统提示词...",
                lines=3,
                value="",
            )
            user_text = gr.Textbox(
                label="文本输入",
                placeholder="输入您的消息...",
                lines=3,
            )
            audio_file = gr.Audio(
                label="音频输入",
                type="filepath",
                sources=["upload", "microphone"],
            )
            with gr.Row():
                max_tokens = gr.Slider(
                    label="Max Tokens",
                    minimum=1,
                    maximum=8192,
                    value=2048,
                    step=1,
                )
            with gr.Row():
                temperature = gr.Slider(
                    label="Temperature",
                    minimum=0.0,
                    maximum=2.0,
                    value=0.7,
                    step=0.1,
                )
                top_p = gr.Slider(
                    label="Top P",
                    minimum=0.0,
                    maximum=1.0,
                    value=0.9,
                    step=0.05,
                )
            submit_btn = gr.Button("提交", variant="primary", size="lg")
            clear_btn = gr.Button("清空", variant="secondary")

        # Right column: chat history.
        with gr.Column(scale=1):
            gr.Markdown("### 聊天历史")
            chatbot = gr.Chatbot(
                label="对话",
                height=600,
                show_copy_button=True,
            )

    # Event wiring.
    check_btn.click(fn=warmup_model, outputs=status_text)
    submit_btn.click(
        fn=chat_predict,
        inputs=[
            system_prompt, user_text, audio_file, chatbot,
            max_tokens, temperature, top_p,
        ],
        outputs=[chatbot, status_text],
    )
    clear_btn.click(
        fn=lambda: ([], "", None),
        outputs=[chatbot, user_text, audio_file],
    )

    # Show the model status as soon as the page loads.
    demo.load(fn=check_model_status, outputs=status_text)


def _parse_args() -> argparse.Namespace:
    """Define and parse the command-line interface."""
    parser = argparse.ArgumentParser(description="Step Audio 2 Gradio Chat Interface")
    parser.add_argument(
        "--host", type=str, default="0.0.0.0", help="服务器主机地址"
    )
    parser.add_argument(
        "--port", type=int, default=7860, help="服务器端口"
    )
    parser.add_argument(
        "--model", type=str, default=DEFAULT_MODEL_PATH, help="模型名称或本地路径"
    )
    parser.add_argument(
        "--tensor-parallel-size", type=int, default=DEFAULT_TP, help="张量并行数量"
    )
    parser.add_argument(
        "--max-model-len", type=int, default=DEFAULT_MAX_MODEL_LEN, help="最大上下文长度"
    )
    parser.add_argument(
        "--gpu-memory-utilization", type=float, default=DEFAULT_GPU_UTIL, help="GPU 显存利用率"
    )
    parser.add_argument(
        "--tokenizer-mode", type=str, default=DEFAULT_TOKENIZER_MODE, help="tokenizer 模式"
    )
    parser.add_argument(
        "--served-model-name", type=str, default=DEFAULT_SERVED_NAME, help="对外暴露的模型名称"
    )
    parser.add_argument(
        "--no-vllm", action="store_true", help="禁用 vLLM,仅启动界面预览模式"
    )
    return parser.parse_args()


def main() -> None:
    """Parse CLI arguments, configure the engine, and launch the app."""
    # BUG FIX: the original used a meaningless module-level ``global``
    # statement; here the entrypoint is a function, where it is required.
    global ENABLE_VLLM

    args = _parse_args()

    # --no-vllm overrides the environment variable (but cannot undo the
    # module-level vLLM import; see the NOTE at the top of the file).
    if args.no_vllm:
        ENABLE_VLLM = False
        print("[INFO] 已禁用 vLLM,运行在界面预览模式")

    _set_llm_args(
        model=args.model,
        trust_remote_code=True,
        tensor_parallel_size=args.tensor_parallel_size,
        tokenizer_mode=args.tokenizer_mode,
        max_model_len=args.max_model_len,
        served_model_name=args.served_model_name,
        gpu_memory_utilization=args.gpu_memory_utilization,
    )

    print("==========================================")
    print("Step Audio 2 Gradio Chat")
    if ENABLE_VLLM:
        print("模式: vLLM 推理模式")
        print(f"模型: {args.model}")
        print(f"Tensor Parallel Size: {args.tensor_parallel_size}")
        print(f"Max Model Len: {args.max_model_len}")
        print(f"Tokenizer Mode: {args.tokenizer_mode}")
        print(f"Served Model Name: {args.served_model_name}")
    else:
        print("模式: 界面预览模式(无 vLLM)")
    print(f"Gradio 地址: http://{args.host}:{args.port}")
    print("==========================================")

    demo.queue().launch(
        server_name=args.host,
        server_port=args.port,
        share=False,
    )


if __name__ == "__main__":
    main()