#!/usr/bin/env python3
"""
Step Audio R1 vLLM Gradio Interface
"""

import base64
import io
import json
import os

import gradio as gr
import httpx
from pydub import AudioSegment

API_BASE_URL = os.getenv("API_BASE_URL", "http://localhost:9999/v1")
MODEL_NAME = os.getenv("MODEL_NAME", "Step-Audio-R1")

# Maximum duration of a single audio part sent to the API, in milliseconds.
CHUNK_LENGTH_MS = 25000

# Markers separating the model's reasoning from its final answer in the
# streamed text. NOTE(review): the tag literals were empty in the reviewed
# source (which made the split raise ValueError); confirm these match the
# served model's chat template.
THINK_OPEN_TAG = "<think>"
THINK_CLOSE_TAG = "</think>"

# Title used in ChatMessage metadata to mark the reasoning bubble; such
# messages are never sent back to the API.
THINKING_TITLE = "⏳ Thinking Process"

# Extra detail appended to the status line for well-known HTTP failures.
HTTP_ERROR_HINTS = {
    404: " - vLLM service not ready",
    400: " - Bad request",
    500: " - Model error",
}


def process_audio(audio_path):
    """Convert an audio file to base64-encoded WAV chunks.

    The file is decoded with pydub, cut into pieces of at most
    ``CHUNK_LENGTH_MS`` milliseconds, and each piece is exported as WAV.

    Args:
        audio_path: Path of the audio file to load.

    Returns:
        A list of base64 strings, one per chunk. The list is empty when the
        path is missing or does not exist, or when decoding/exporting fails
        (the error is printed, not raised).
    """
    if not audio_path or not os.path.exists(audio_path):
        return []

    try:
        # pydub handles mp3, wav, etc. automatically if ffmpeg is installed.
        audio = AudioSegment.from_file(audio_path)

        # A clip no longer than one chunk (including an empty one) is kept
        # whole, exactly as before.
        chunks = [
            audio[start:start + CHUNK_LENGTH_MS]
            for start in range(0, len(audio), CHUNK_LENGTH_MS)
        ] or [audio]

        audio_data_list = []
        for chunk in chunks:
            buffer = io.BytesIO()
            chunk.export(buffer, format="wav")
            audio_data_list.append(base64.b64encode(buffer.getvalue()).decode())
        return audio_data_list
    except Exception as e:
        # Best-effort boundary: callers treat an empty list as "no audio".
        print(f"[DEBUG] Audio processing error: {e}")
        return []


def _field(item, name):
    """Read ``name`` from a history item that is either a dict or an object."""
    if isinstance(item, dict):
        return item.get(name)
    return getattr(item, name, None)


def _audio_parts(audio_data_list):
    """Wrap base64 WAV strings into ``input_audio`` message parts."""
    return [
        {
            "type": "input_audio",
            "input_audio": {"data": audio_data, "format": "wav"},
        }
        for audio_data in audio_data_list
    ]


def _is_audio_component(part):
    """Return True if ``part`` is a dict payload whose component is "audio"."""
    return isinstance(part, dict) and part.get("component") == "audio"


def _history_audio_parts(part):
    """Turn an audio component payload from history into ``input_audio`` parts.

    Returns an empty list when the payload carries no usable file path or the
    file cannot be processed.
    """
    value = part.get("value")
    audio_path = value.get("path") if isinstance(value, dict) else None
    if not audio_path:
        return []
    return _audio_parts(process_audio(audio_path))


def format_messages(system, history, user_text, audio_data_list=None):
    """Build the ``messages`` list for the chat-completions request.

    Args:
        system: Optional system prompt; added first when non-empty.
        history: Previous chat items (dicts or objects exposing ``role``,
            ``content`` and optionally ``metadata``). Items marked as the
            thinking process are skipped.
        user_text: Text of the current user turn, may be empty.
        audio_data_list: Base64 WAV chunks of the current user turn, if any.

    Returns:
        A list of message dicts. Audio parts of the current turn precede its
        text part.
    """
    messages = []
    if system:
        messages.append({"role": "system", "content": system})

    # Replay the conversation history.
    for item in history or []:
        # The reasoning bubble is UI-only and must not be fed back.
        metadata = _field(item, "metadata")
        if isinstance(metadata, dict) and metadata.get("title") == THINKING_TITLE:
            continue

        role = _field(item, "role")
        content = _field(item, "content")
        if not role or content is None:
            continue

        if _is_audio_component(content):
            parts = _history_audio_parts(content)
            # Skip the turn entirely if its audio could not be recovered
            # rather than sending a message with empty content.
            if parts:
                messages.append({"role": role, "content": parts})
        elif isinstance(content, str):
            messages.append({"role": role, "content": content})
        elif isinstance(content, list):
            # A list of parts: audio components, ready-made dicts, or strings.
            safe_content = []
            for part in content:
                if _is_audio_component(part):
                    safe_content.extend(_history_audio_parts(part))
                elif isinstance(part, dict):
                    safe_content.append(part)
                elif isinstance(part, str):
                    safe_content.append({"type": "text", "text": part})
            if safe_content:
                messages.append({"role": role, "content": safe_content})

    # Append the current user message.
    current_parts = _audio_parts(audio_data_list or [])
    if user_text and current_parts:
        current_parts.append({"type": "text", "text": user_text})
        messages.append({"role": "user", "content": current_parts})
    elif user_text:
        messages.append({"role": "user", "content": user_text})
    elif current_parts:
        messages.append({"role": "user", "content": current_parts})

    return messages


def _is_chat_item(item):
    """Return True for history entries that carry both a role and content."""
    if isinstance(item, dict):
        return "role" in item and "content" in item
    return hasattr(item, "role") and hasattr(item, "content")


def _redact_audio(messages):
    """Copy ``messages`` with base64 audio payloads replaced by a length tag."""
    redacted = []
    for msg in messages:
        content = msg.get("content") if isinstance(msg, dict) else None
        if not isinstance(content, list):
            redacted.append(msg)
            continue

        parts = []
        for part in content:
            is_audio = isinstance(part, dict) and part.get("type") == "input_audio"
            audio = part.get("input_audio") if is_audio else None
            if isinstance(audio, dict) and "data" in audio:
                tag = f"[BASE64_AUDIO_DATA_LEN_{len(audio['data'])}]"
                part = {**part, "input_audio": {**audio, "data": tag}}
            parts.append(part)
        redacted.append({**msg, "content": parts})
    return redacted


def _split_thinking(buffer):
    """Split streamed text into ``(thinking, answer)``.

    ``answer`` is None while the closing think tag has not arrived yet. A
    leading opening tag is removed from the thinking text.
    """
    thinking, closed, answer = buffer.partition(THINK_CLOSE_TAG)
    if thinking.startswith(THINK_OPEN_TAG):
        thinking = thinking[len(THINK_OPEN_TAG):]
    if not closed:
        return thinking, None
    return thinking.strip(), answer


def chat(system_prompt, user_text, audio_file, history, max_tokens, temperature, top_p, model_name=None):
    """Send the current turn to the API and stream the reply into the chat.

    This is a generator used as a Gradio event handler. Every yield is a
    ``(history, status)`` pair: the updated chat history and a status string.

    Args:
        system_prompt: System prompt text.
        user_text: Text typed by the user, may be empty.
        audio_file: Path of the recorded/uploaded audio, or None.
        history: Current chat history from the Chatbot component.
        max_tokens: Generation limit; converted to int before sending.
        temperature: Sampling temperature.
        top_p: Nucleus sampling value.
        model_name: Model to request; defaults to the global ``MODEL_NAME``.
    """
    # If model is not specified, use global configuration.
    if model_name is None:
        model_name = MODEL_NAME

    if not user_text and not audio_file:
        yield history or [], "Please enter text or upload audio"
        return

    # Keep only well-formed history entries (dicts or ChatMessage objects).
    history = [item for item in history or [] if _is_chat_item(item)]

    audio_data_list = process_audio(audio_file) if audio_file else []
    if not user_text and not audio_data_list:
        # Audio was supplied but could not be decoded; without this guard the
        # request would go out with no new user turn at all.
        yield history, "❌ Audio processing failed"
        return

    messages = format_messages(system_prompt, history, user_text, audio_data_list)
    if not messages:
        yield history or [], "Invalid input"
        return

    # Debug: print the request with audio payloads shortened. default=str
    # keeps this log line from raising on non-JSON history parts.
    debug_messages = _redact_audio(messages)
    print(f"[DEBUG] Messages to API: {json.dumps(debug_messages, ensure_ascii=False, indent=2, default=str)}")

    # Show the user's turn immediately: audio first, then text if present.
    if audio_file:
        history.append({"role": "user", "content": gr.Audio(audio_file)})
    if user_text:
        history.append({"role": "user", "content": user_text})

    # Placeholder bubble that receives the model's reasoning as it streams.
    thinking_msg = gr.ChatMessage(
        role="assistant",
        content="",
        metadata={"title": THINKING_TITLE},
    )
    history.append(thinking_msg)

    yield history, "Generating..."

    payload = {
        "model": model_name,
        "messages": messages,
        # Sliders may hand back floats; the API expects an integer here.
        "max_tokens": int(max_tokens),
        "temperature": temperature,
        "top_p": top_p,
        "stream": True,
        "repetition_penalty": 1.07,
        "stop_token_ids": [151665],
    }

    try:
        with httpx.Client(base_url=API_BASE_URL, timeout=120) as client:
            # client.stream() hands lines over as they arrive; client.post()
            # would buffer the entire reply before the first UI update.
            with client.stream("POST", "/chat/completions", json=payload) as response:
                if response.status_code != 200:
                    error_msg = f"❌ API Error {response.status_code}"
                    error_msg += HTTP_ERROR_HINTS.get(response.status_code, "")
                    yield history, error_msg
                    return

                buffer = ""
                answer_msg = None  # created once the closing think tag shows up

                for line in response.iter_lines():
                    if not line or not line.startswith("data:"):
                        continue

                    data_str = line[len("data:"):].strip()
                    if data_str == "[DONE]":
                        break

                    try:
                        data = json.loads(data_str)
                    except json.JSONDecodeError:
                        continue

                    choices = data.get("choices") if isinstance(data, dict) else None
                    if not choices:
                        continue

                    delta = choices[0].get("delta") or {}
                    piece = delta.get("content")
                    if not piece:
                        # Role-only or finish chunks carry no text.
                        continue

                    buffer += piece
                    think_text, answer_text = _split_thinking(buffer)

                    if answer_text is None:
                        # Still reasoning: keep updating the thinking bubble.
                        thinking_msg.content = think_text
                    elif answer_msg is None:
                        # First chunk past the closing tag: finalize the
                        # reasoning and open the answer bubble.
                        thinking_msg.content = think_text
                        answer_msg = {"role": "assistant", "content": answer_text}
                        history.append(answer_msg)
                    else:
                        answer_msg["content"] = answer_text

                    yield history, ""

    except httpx.ConnectError:
        yield history, "❌ Cannot connect to vLLM API"
    except httpx.TimeoutException:
        yield history, "❌ Request timed out"
    except Exception as e:
        # Top-level UI boundary: surface the error instead of crashing.
        yield history, f"❌ Error: {str(e)}"


# Gradio Interface
with gr.Blocks(title="Step Audio R1") as demo:
    gr.Markdown("# Step Audio R1 Chat")

    with gr.Row():
        # Left: configuration
        with gr.Column(scale=1):
            with gr.Accordion("Configuration", open=True):
                system_prompt = gr.Textbox(
                    label="System Prompt",
                    lines=2,
                    value="你是一个语音助手,你有非常丰富的音频处理经验。"
                )
                max_tokens = gr.Slider(1, 7192, value=1024, label="Max Tokens")
                temperature = gr.Slider(0.0, 2.0, value=0.7, label="Temperature")
                top_p = gr.Slider(0.0, 1.0, value=0.9, label="Top P")

            status = gr.Textbox(label="Status", interactive=False)

        # Right: chat
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(label="Chat History", height=450)
            user_text = gr.Textbox(label="Input", lines=2, placeholder="Enter message...")
            audio_file = gr.Audio(label="Audio", type="filepath", sources=["microphone", "upload"])

            with gr.Row():
                submit_btn = gr.Button("Send", variant="primary", scale=2)
                clear_btn = gr.Button("Clear", scale=1)

    submit_btn.click(
        fn=chat,
        inputs=[system_prompt, user_text, audio_file, chatbot, max_tokens, temperature, top_p],
        outputs=[chatbot, status]
    )

    clear_btn.click(
        fn=lambda: ([], "", None),
        outputs=[chatbot, user_text, audio_file]
    )


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--host", default="0.0.0.0")
    parser.add_argument("--port", type=int, default=7860)
    parser.add_argument("--model", default=MODEL_NAME)
    args = parser.parse_args()

    # Update the global model name; chat() reads it at call time.
    if args.model:
        MODEL_NAME = args.model

    print(f"启动Gradio: http://{args.host}:{args.port}")
    print(f"API地址: {API_BASE_URL}")
    print(f"模型: {MODEL_NAME}")

    demo.launch(server_name=args.host, server_port=args.port, share=False)