Spaces:
Running
Running
#!/usr/bin/env python3
"""
Step Audio R1 vLLM Gradio Interface
"""
import base64
import json
import os
import io

from pydub import AudioSegment
import gradio as gr
import httpx

# vLLM OpenAI-compatible endpoint and default model; both overridable via env.
API_BASE_URL = os.getenv("API_BASE_URL", "http://localhost:9999/v1")
MODEL_NAME = os.getenv("MODEL_NAME", "Step-Audio-R1")
def process_audio(audio_path):
    """
    Convert an audio file into one or more base64-encoded WAV payloads.

    The file is decoded with pydub (any format ffmpeg supports) and sliced
    into chunks of at most 25 seconds each.

    Args:
        audio_path: Filesystem path to the source audio, or None/"".

    Returns:
        A list of base64-encoded WAV strings, one per chunk. Returns an
        empty list when the path is missing/nonexistent or decoding fails
        (best-effort: callers treat this as "no audio").
    """
    if not audio_path or not os.path.exists(audio_path):
        return []
    try:
        # pydub handles mp3, wav, etc. automatically if ffmpeg is installed.
        audio = AudioSegment.from_file(audio_path)

        # Slice into chunks of at most 25 s (25000 ms). A clip shorter than
        # one chunk yields a single slice from range(), so the original
        # redundant short-clip special case is unnecessary. (Degenerate
        # zero-length audio now yields [] instead of one empty chunk.)
        chunk_length_ms = 25000
        audio_data_list = []
        for start in range(0, len(audio), chunk_length_ms):
            buffer = io.BytesIO()
            audio[start:start + chunk_length_ms].export(buffer, format="wav")
            audio_data_list.append(base64.b64encode(buffer.getvalue()).decode())
        return audio_data_list
    except Exception as e:
        # Best-effort: an unreadable/unsupported file degrades to "no audio".
        print(f"[DEBUG] Audio processing error: {e}")
        return []
def _wav_part(audio_data):
    """Wrap one base64 WAV payload as an OpenAI `input_audio` content part."""
    return {
        "type": "input_audio",
        "input_audio": {
            "data": audio_data,
            "format": "wav"
        }
    }


def format_messages(system, history, user_text, audio_data_list=None):
    """
    Build the OpenAI-style `messages` payload for /chat/completions.

    Args:
        system: Optional system-prompt string (prepended when truthy).
        history: Prior chat turns; items are dicts or ChatMessage-like
            objects exposing `role`/`content`. Content may be a string,
            a serialized gradio component dict, or a list of parts.
        user_text: Current user text, possibly empty.
        audio_data_list: Base64 WAV chunks for the current turn, if any.

    Returns:
        A list of {"role", "content"} dicts.
    """
    messages = []
    if system:
        messages.append({"role": "system", "content": system})
    # Replay history, converting audio references into input_audio parts.
    for item in history or []:
        # Skip the transient "thinking" bubbles shown only in the UI.
        metadata = item.get("metadata") if isinstance(item, dict) else getattr(item, "metadata", None)
        if metadata and isinstance(metadata, dict) and metadata.get("title") == "⏳ Thinking Process":
            continue
        role = item.get("role") if isinstance(item, dict) else getattr(item, "role", None)
        content = item.get("content") if isinstance(item, dict) else getattr(item, "content", None)
        if not role or content is None:
            continue
        # A serialized gradio audio component looks like
        # {"component": "audio", "value": {"path": ...}}.
        # BUGFIX: guard with isinstance(dict) — the original called
        # content.get() on plain strings too, raising AttributeError before
        # the str branch below could ever run.
        # NOTE(review): non-dict component objects (e.g. a live gr.Audio
        # instance) are now skipped rather than crashing; gradio appears to
        # serialize them to dicts before history reaches us — TODO confirm.
        if isinstance(content, dict) and content.get("component") == "audio":
            audio_path = (content.get("value") or {}).get("path")
            if audio_path and os.path.exists(audio_path):
                try:
                    parts = [_wav_part(d) for d in process_audio(audio_path)]
                    messages.append({"role": role, "content": parts})
                except Exception as e:
                    print(f"[ERROR] Failed to process history audio: {e}")
        elif isinstance(content, str):
            messages.append({"role": role, "content": content})
        elif isinstance(content, list):
            # Already a list of parts, possibly mixed with audio components.
            safe_content = []
            for c in content:
                # BUGFIX: same isinstance(dict) guard as above — `c` may be
                # a plain string (handled further down).
                if isinstance(c, dict) and c.get("component") == "audio":
                    audio_path = (c.get("value") or {}).get("path")
                    if audio_path and os.path.exists(audio_path):
                        try:
                            safe_content.extend(_wav_part(d) for d in process_audio(audio_path))
                        except Exception as e:
                            print(f"[ERROR] Failed to process history audio in list: {e}")
                elif isinstance(c, dict):
                    safe_content.append(c)
                elif isinstance(c, str):
                    safe_content.append({"type": "text", "text": c})
            messages.append({"role": role, "content": safe_content})
    # Append the current user turn: audio chunks first, then the text part.
    if audio_data_list:
        content = [_wav_part(d) for d in audio_data_list]
        if user_text:
            content.append({"type": "text", "text": user_text})
        messages.append({"role": "user", "content": content})
    elif user_text:
        messages.append({"role": "user", "content": user_text})
    return messages
def chat(system_prompt, user_text, audio_file, history, max_tokens, temperature, top_p, model_name=None):
    """
    Streaming chat handler for the Gradio submit button.

    Converts the optional audio file to base64 WAV chunks, posts the full
    conversation to the vLLM /chat/completions endpoint with stream=True,
    and incrementally yields (history, status) tuples as tokens arrive.
    The assistant reply is expected to start with a <think>…</think>
    block, which is rendered as a separate collapsible "Thinking Process"
    message before the final answer.

    Args:
        system_prompt: System prompt string (may be empty).
        user_text: Current user text input (may be empty).
        audio_file: Filepath of the uploaded/recorded audio, or None.
        history: Gradio chatbot history (dicts and/or ChatMessage objects).
        max_tokens / temperature / top_p: Sampling parameters forwarded to the API.
        model_name: Optional model override; defaults to the global MODEL_NAME.

    Yields:
        (history, status_message) tuples for the chatbot and status outputs.
    """
    # If model is not specified, use global configuration
    if model_name is None:
        model_name = MODEL_NAME
    if not user_text and not audio_file:
        yield history or [], "Please enter text or upload audio"
        return
    # Ensure history is a list and formatted correctly
    history = history or []
    clean_history = []
    for item in history:
        if isinstance(item, dict) and 'role' in item and 'content' in item:
            clean_history.append(item)
        elif hasattr(item, "role") and hasattr(item, "content"):
            # Keep ChatMessage object
            clean_history.append(item)
    history = clean_history
    # Process audio: convert the current upload into base64 WAV chunks.
    audio_data_list = []
    if audio_file:
        audio_data_list = process_audio(audio_file)
    messages = format_messages(system_prompt, history, user_text, audio_data_list)
    if not messages:
        yield history or [], "Invalid input"
        return
    # Debug: print the outgoing payload with base64 blobs replaced by their
    # length, so the log stays readable.
    debug_messages = []
    for msg in messages:
        if isinstance(msg, dict) and isinstance(msg.get("content"), list):
            new_content = []
            for item in msg["content"]:
                if isinstance(item, dict) and item.get("type") == "input_audio":
                    item_copy = item.copy()
                    if "input_audio" in item_copy:
                        audio_info = item_copy["input_audio"].copy()
                        if "data" in audio_info:
                            audio_info["data"] = f"[BASE64_AUDIO_DATA_LEN_{len(audio_info['data'])}]"
                        item_copy["input_audio"] = audio_info
                    new_content.append(item_copy)
                else:
                    new_content.append(item)
            msg_copy = msg.copy()
            msg_copy["content"] = new_content
            debug_messages.append(msg_copy)
        else:
            debug_messages.append(msg)
    print(f"[DEBUG] Messages to API: {json.dumps(debug_messages, ensure_ascii=False, indent=2)}")
    # Update history with user message immediately
    if audio_file:
        # 1. Add audio message
        history.append({"role": "user", "content": gr.Audio(audio_file)})
        # 2. If text exists, add text message
        if user_text:
            history.append({"role": "user", "content": user_text})
    else:
        # Text only
        history.append({"role": "user", "content": user_text})
    # Add thinking placeholder; format_messages() filters it out of future
    # API payloads by matching this metadata title.
    history.append(gr.ChatMessage(
        role="assistant",
        content="",
        metadata={"title": "⏳ Thinking Process"}
    ))
    yield history, "Generating..."
    try:
        with httpx.Client(base_url=API_BASE_URL, timeout=120) as client:
            # Use client.stream for better streaming control
            with client.stream("POST", "/chat/completions", json={
                "model": model_name,
                "messages": messages,
                "max_tokens": max_tokens,
                "temperature": temperature,
                "top_p": top_p,
                "stream": True,
                "repetition_penalty": 1.07,
                # NOTE(review): 151665 is presumably the model's end-of-think
                # or special stop token id — confirm against the tokenizer.
                "stop_token_ids": [151665]
            }) as response:
                if response.status_code != 200:
                    error_msg = f"❌ API Error {response.status_code}"
                    if response.status_code == 404:
                        error_msg += " - vLLM service not ready"
                    elif response.status_code == 400:
                        error_msg += " - Bad request"
                    elif response.status_code == 500:
                        error_msg += " - Model error"
                    yield history, error_msg
                    return
                # Process streaming response (SSE: lines of "data: {...}").
                buffer = ""
                is_thinking = True
                print("[DEBUG] Start receiving stream...")
                for line in response.iter_lines():
                    if not line:
                        continue
                    # Ensure line is string format
                    if isinstance(line, bytes):
                        line = line.decode('utf-8')
                    else:
                        line = str(line)
                    if line.startswith('data: '):
                        data_str = line[6:]
                        if data_str.strip() == '[DONE]':
                            print("[DEBUG] Stream finished [DONE]")
                            break
                        try:
                            data = json.loads(data_str)
                            if 'choices' in data and len(data['choices']) > 0:
                                delta = data['choices'][0].get('delta', {})
                                if 'content' in delta:
                                    content = delta['content']
                                    buffer += content
                                    if is_thinking:
                                        if "</think>" in buffer:
                                            # Think block just closed: split the
                                            # accumulated buffer once and switch
                                            # to answer mode.
                                            is_thinking = False
                                            parts = buffer.split("</think>", 1)
                                            think_content = parts[0]
                                            response_content = parts[1]
                                            if think_content.startswith("<think>"):
                                                think_content = think_content[len("<think>"):].strip()
                                            # Update thinking message — history[-1]
                                            # is still the ChatMessage placeholder,
                                            # hence attribute access.
                                            history[-1].content = think_content
                                            # Add response message (a plain dict).
                                            history.append({"role": "assistant", "content": response_content})
                                        else:
                                            # Update thinking message
                                            current_think = buffer
                                            if current_think.startswith("<think>"):
                                                current_think = current_think[len("<think>"):]
                                            history[-1].content = current_think
                                    else:
                                        # Already split, just update response
                                        # message — history[-1] is now the dict
                                        # appended above, hence key access.
                                        parts = buffer.split("</think>", 1)
                                        response_content = parts[1]
                                        history[-1]["content"] = response_content
                                    yield history, ""
                        except json.JSONDecodeError:
                            # Ignore partial/non-JSON SSE lines.
                            continue
    except httpx.ConnectError:
        yield history, "❌ Cannot connect to vLLM API"
    except Exception as e:
        yield history, f"❌ Error: {str(e)}"
# Gradio Interface: two-column layout — sampling configuration on the left,
# chat transcript plus text/audio inputs on the right.
with gr.Blocks(title="Step Audio R1") as demo:
    gr.Markdown("# Step Audio R1 Chat")
    with gr.Row():
        # Left Configuration
        with gr.Column(scale=1):
            with gr.Accordion("Configuration", open=True):
                system_prompt = gr.Textbox(
                    label="System Prompt",
                    lines=2,
                    value="你是一个语音助手,你有非常丰富的音频处理经验。"
                )
                max_tokens = gr.Slider(1, 7192, value=1024, label="Max Tokens")
                temperature = gr.Slider(0.0, 2.0, value=0.7, label="Temperature")
                top_p = gr.Slider(0.0, 1.0, value=0.9, label="Top P")
                # Read-only status line updated by chat()'s second yield value.
                status = gr.Textbox(label="Status", interactive=False)
        # Right Chat
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(label="Chat History", height=450)
            user_text = gr.Textbox(label="Input", lines=2, placeholder="Enter message...")
            audio_file = gr.Audio(label="Audio", type="filepath", sources=["microphone", "upload"])
            with gr.Row():
                submit_btn = gr.Button("Send", variant="primary", scale=2)
                clear_btn = gr.Button("Clear", scale=1)
    # Send: stream chat() output into the chatbot and status boxes.
    submit_btn.click(
        fn=chat,
        inputs=[system_prompt, user_text, audio_file, chatbot, max_tokens, temperature, top_p],
        outputs=[chatbot, status]
    )
    # Clear: reset transcript, text input, and audio widget.
    clear_btn.click(
        fn=lambda: ([], "", None),
        outputs=[chatbot, user_text, audio_file]
    )
if __name__ == "__main__":
    import argparse

    # CLI: bind address/port for the Gradio server plus an optional model
    # name override.
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", default="0.0.0.0")
    parser.add_argument("--port", type=int, default=7860)
    parser.add_argument("--model", default=MODEL_NAME)
    args = parser.parse_args()
    # Update the global model name (this runs at module level, so plain
    # assignment rebinds the global).
    if args.model:
        MODEL_NAME = args.model
    print(f"启动Gradio: http://{args.host}:{args.port}")
    print(f"API地址: {API_BASE_URL}")
    print(f"模型: {MODEL_NAME}")
    demo.launch(server_name=args.host, server_port=args.port, share=False)