Step-Audio-R1 / app.py
moevis's picture
Update app.py
e6f110c verified
raw
history blame
15.2 kB
#!/usr/bin/env python3
"""
Step Audio R1 vLLM Gradio Interface
"""
import base64
import json
import os
import io
from pydub import AudioSegment
import gradio as gr
import httpx
# Endpoint of the OpenAI-compatible vLLM server; override via the API_BASE_URL env var.
API_BASE_URL = os.environ.get("API_BASE_URL", "http://localhost:9999/v1")
# Model name sent with every request; override via the MODEL_NAME env var
# (the __main__ block's --model flag can also replace it at startup).
MODEL_NAME = os.environ.get("MODEL_NAME", "Step-Audio-R1")
def process_audio(audio_path, chunk_length_ms=25000):
    """
    Load an audio file, split it into fixed-length chunks and encode each as WAV.

    The file is decoded with pydub (non-WAV formats need ffmpeg installed).
    Audio no longer than ``chunk_length_ms`` is kept as a single chunk; longer
    audio is cut into consecutive ``chunk_length_ms`` slices, the last of which
    may be shorter.

    Args:
        audio_path: Path to the audio file. ``None``/empty or a path that does
            not exist yields ``[]``.
        chunk_length_ms: Maximum chunk duration in milliseconds
            (default 25000, i.e. 25 seconds).

    Returns:
        A list of base64-encoded WAV payloads (``str``), one per chunk, in
        playback order. Returns ``[]`` when the path is missing or decoding /
        encoding fails (the error is printed, not raised).

    Raises:
        ValueError: If ``chunk_length_ms`` is not positive.
    """
    if chunk_length_ms <= 0:
        raise ValueError(f"chunk_length_ms must be positive, got {chunk_length_ms}")
    if not audio_path or not os.path.exists(audio_path):
        return []
    try:
        # pydub handles mp3, wav, etc. automatically if ffmpeg is installed.
        audio = AudioSegment.from_file(audio_path)
        # len() of an AudioSegment is its duration in ms, and slicing is by ms.
        if len(audio) > chunk_length_ms:
            chunks = [
                audio[start:start + chunk_length_ms]
                for start in range(0, len(audio), chunk_length_ms)
            ]
        else:
            chunks = [audio]

        encoded_chunks = []
        for chunk in chunks:
            buffer = io.BytesIO()
            chunk.export(buffer, format="wav")
            encoded_chunks.append(base64.b64encode(buffer.getvalue()).decode())
        return encoded_chunks
    except Exception as e:
        # Best-effort: callers treat an empty list as "no usable audio".
        print(f"[DEBUG] Audio processing error: {e}")
        return []
def _history_field(item, name):
    """Read key/attribute ``name`` from a history entry (plain dict or message object)."""
    if isinstance(item, dict):
        return item.get(name)
    return getattr(item, name, None)


def _input_audio_parts(audio_data_list):
    """Wrap base64 WAV strings as ``{"type": "input_audio", ...}`` content parts."""
    return [
        {"type": "input_audio", "input_audio": {"data": audio_data, "format": "wav"}}
        for audio_data in audio_data_list
    ]


def _is_audio_component(part):
    """Return True if ``part`` is a Gradio audio payload (a dict with ``component == "audio"``)."""
    return isinstance(part, dict) and part.get("component") == "audio"


def _audio_component_parts(part, error_label):
    """Re-encode the file behind a Gradio audio payload as ``input_audio`` parts.

    Expects the file path under ``part["value"]["path"]``. Returns ``[]`` when
    the path is absent, the file no longer exists, or encoding fails; a failure
    is printed as ``"<error_label>: <exception>"``.
    """
    value = part.get("value")
    audio_path = value.get("path") if isinstance(value, dict) else None
    if not audio_path or not os.path.exists(audio_path):
        return []
    try:
        return _input_audio_parts(process_audio(audio_path))
    except Exception as e:
        print(f"{error_label}: {e}")
        return []


def format_messages(system, history, user_text, audio_data_list=None):
    """Build the OpenAI-style ``messages`` list for a chat completions request.

    Args:
        system: Optional system prompt; added first when non-empty.
        history: Prior chatbot entries (dicts or objects exposing ``role``,
            ``content`` and ``metadata``), or None. Entries whose metadata
            title is "⏳ Thinking Process" are skipped so earlier reasoning is
            not sent back. String content is kept verbatim; Gradio audio
            payloads are re-read from disk as ``input_audio`` parts; list
            content is normalized part by part. Entries that end up with no
            usable content are dropped instead of being sent empty.
        user_text: Text of the current turn (may be empty).
        audio_data_list: Base64 WAV chunks for the current turn (may be None).

    Returns:
        A list of ``{"role": ..., "content": ...}`` dicts. The current user
        turn, if any, comes last, with its audio parts before its text part.
    """
    messages = []
    if system:
        messages.append({"role": "system", "content": system})

    for item in history or []:
        # Filter out thinking process messages
        metadata = _history_field(item, "metadata")
        if isinstance(metadata, dict) and metadata.get("title") == "⏳ Thinking Process":
            continue
        role = _history_field(item, "role")
        content = _history_field(item, "content")
        if not role or content is None:
            continue

        # Test for str first: treating it as a dict (content.get) would raise.
        if isinstance(content, str):
            messages.append({"role": role, "content": content})
        elif _is_audio_component(content):
            parts = _audio_component_parts(content, "[ERROR] Failed to process history audio")
            if parts:
                messages.append({"role": role, "content": parts})
        elif isinstance(content, list):
            parts = []
            for part in content:
                if isinstance(part, str):
                    parts.append({"type": "text", "text": part})
                elif _is_audio_component(part):
                    parts.extend(
                        _audio_component_parts(part, "[ERROR] Failed to process history audio in list")
                    )
                elif isinstance(part, dict):
                    # Assumed to already be an API content part (e.g. {"type": "text", ...}).
                    parts.append(part)
            if parts:
                messages.append({"role": role, "content": parts})
        # Any other content shape (e.g. non-audio component dicts) is dropped.

    # Add the current user message
    current_parts = _input_audio_parts(audio_data_list or [])
    if user_text and current_parts:
        current_parts.append({"type": "text", "text": user_text})
        messages.append({"role": "user", "content": current_parts})
    elif user_text:
        messages.append({"role": "user", "content": user_text})
    elif current_parts:
        messages.append({"role": "user", "content": current_parts})
    return messages
def _redact_audio_for_log(messages):
    """Return a copy of ``messages`` that is safe to print.

    Each ``input_audio`` part's base64 ``data`` is replaced by a
    ``[BASE64_AUDIO_DATA_LEN_<n>]`` placeholder. The input is not mutated.
    """
    redacted = []
    for msg in messages:
        content = msg.get("content") if isinstance(msg, dict) else None
        if not isinstance(content, list):
            redacted.append(msg)
            continue
        parts = []
        for part in content:
            if isinstance(part, dict) and part.get("type") == "input_audio" and "input_audio" in part:
                audio_info = part["input_audio"].copy()
                if "data" in audio_info:
                    audio_info["data"] = f"[BASE64_AUDIO_DATA_LEN_{len(audio_info['data'])}]"
                part = {**part, "input_audio": audio_info}
            parts.append(part)
        redacted.append({**msg, "content": parts})
    return redacted


def _strip_think_prefix(text):
    """Drop a leading ``<think>`` tag (ignoring whitespace before it), if present."""
    stripped = text.lstrip()
    if stripped.startswith("<think>"):
        return stripped[len("<think>"):]
    return text


def chat(system_prompt, user_text, audio_file, history, max_tokens, temperature, top_p, model_name=None):
    """Stream a model reply into the Gradio chat history.

    Generator used as the "Send" handler. It posts a streaming request to
    ``API_BASE_URL + "/chat/completions"``. Text before ``</think>`` is shown in
    a "⏳ Thinking Process" assistant message. Text after it goes into a
    separate assistant message.

    Args:
        system_prompt: Optional system message.
        user_text: Text typed by the user (may be empty).
        audio_file: Path of the recorded/uploaded audio, or None.
        history: Current chatbot entries (dicts or gr.ChatMessage objects).
        max_tokens: Maximum number of generated tokens (cast to int).
        temperature: Sampling temperature forwarded to the API.
        top_p: Nucleus-sampling threshold forwarded to the API.
        model_name: Model to request; defaults to the module-level MODEL_NAME.

    Yields:
        ``(history, status)`` tuples, where ``status`` is a progress/error
        string or "" while tokens stream in.
    """
    # If model is not specified, use global configuration
    if model_name is None:
        model_name = MODEL_NAME
    if not user_text and not audio_file:
        yield history or [], "Please enter text or upload audio"
        return

    # Keep only well-formed entries: role/content dicts or message objects (e.g. gr.ChatMessage).
    history = [
        item for item in (history or [])
        if (isinstance(item, dict) and "role" in item and "content" in item)
        or (hasattr(item, "role") and hasattr(item, "content"))
    ]

    audio_data_list = []
    if audio_file:
        audio_data_list = process_audio(audio_file)
        if not audio_data_list:
            # process_audio returns [] on failure; answering without the audio
            # the user supplied would be misleading, so stop here.
            yield history, "❌ Failed to process audio file"
            return

    messages = format_messages(system_prompt, history, user_text, audio_data_list)
    if not messages:
        yield history, "Invalid input"
        return

    print(f"[DEBUG] Messages to API: {json.dumps(_redact_audio_for_log(messages), ensure_ascii=False, indent=2)}")

    # Show the user's turn immediately: audio message first, then any text.
    if audio_file:
        history.append({"role": "user", "content": gr.Audio(audio_file)})
    if user_text:
        history.append({"role": "user", "content": user_text})
    # Placeholder assistant message that receives the streamed reasoning text.
    history.append(gr.ChatMessage(
        role="assistant",
        content="",
        metadata={"title": "⏳ Thinking Process"}
    ))
    yield history, "Generating..."

    try:
        payload = {
            "model": model_name,
            "messages": messages,
            # Slider values may arrive as floats; the API expects an integer.
            "max_tokens": int(max_tokens),
            "temperature": temperature,
            "top_p": top_p,
            "stream": True,
            # vLLM-specific sampling extensions.
            "repetition_penalty": 1.07,
            # NOTE(review): model-specific stop token id — confirm against the tokenizer.
            "stop_token_ids": [151665],
        }
        with httpx.Client(base_url=API_BASE_URL, timeout=120) as client:
            with client.stream("POST", "/chat/completions", json=payload) as response:
                if response.status_code != 200:
                    error_msg = f"❌ API Error {response.status_code}"
                    if response.status_code == 404:
                        error_msg += " - vLLM service not ready"
                    elif response.status_code == 400:
                        error_msg += " - Bad request"
                    elif response.status_code == 500:
                        error_msg += " - Model error"
                    yield history, error_msg
                    return

                buffer = ""
                is_thinking = True
                print("[DEBUG] Start receiving stream...")
                for line in response.iter_lines():
                    if not line:
                        continue
                    if isinstance(line, bytes):
                        line = line.decode("utf-8")
                    # Server-sent events: only "data:" fields carry payloads
                    # (the space after the colon is optional per the SSE spec).
                    if not line.startswith("data:"):
                        continue
                    data_str = line[len("data:"):]
                    if data_str.strip() == "[DONE]":
                        print("[DEBUG] Stream finished [DONE]")
                        break
                    try:
                        data = json.loads(data_str)
                    except json.JSONDecodeError:
                        continue
                    if not isinstance(data, dict):
                        continue
                    choices = data.get("choices") or []
                    if not choices:
                        continue
                    delta = choices[0].get("delta") or {}
                    # content may be missing, "" or null (e.g. role-only chunks).
                    content = delta.get("content")
                    if not content:
                        continue
                    buffer += content

                    if is_thinking:
                        if "</think>" in buffer:
                            is_thinking = False
                            think_content, response_content = buffer.split("</think>", 1)
                            history[-1].content = _strip_think_prefix(think_content).strip()
                            history.append({"role": "assistant", "content": response_content})
                        else:
                            history[-1].content = _strip_think_prefix(buffer)
                    else:
                        # Already split: the last entry is the answer message.
                        history[-1]["content"] = buffer.split("</think>", 1)[1]
                    yield history, ""
    except httpx.ConnectError:
        yield history, "❌ Cannot connect to vLLM API"
    except Exception as e:
        yield history, f"❌ Error: {str(e)}"
# Gradio Interface
# Two-column layout: generation settings on the left, the conversation and its
# inputs on the right. The buttons are wired to handlers at the end of the block.
with gr.Blocks(title="Step Audio R1") as demo:
    gr.Markdown("# Step Audio R1 Chat")
    with gr.Row():
        # Left Configuration
        with gr.Column(scale=1):
            with gr.Accordion("Configuration", open=True):
                # Default system prompt (Chinese), roughly: "You are a voice
                # assistant with very rich experience in audio processing."
                system_prompt = gr.Textbox(
                    label="System Prompt",
                    lines=2,
                    value="你是一个语音助手,你有非常丰富的音频处理经验。"
                )
                # Sampling controls, passed to chat() as max_tokens / temperature / top_p.
                # NOTE(review): upper bound 7192 may have been intended as 8192 — confirm.
                max_tokens = gr.Slider(1, 7192, value=1024, label="Max Tokens")
                temperature = gr.Slider(0.0, 2.0, value=0.7, label="Temperature")
                top_p = gr.Slider(0.0, 1.0, value=0.9, label="Top P")
            # Read-only status line; receives the second value chat() yields.
            status = gr.Textbox(label="Status", interactive=False)
        # Right Chat
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(label="Chat History", height=450)
            user_text = gr.Textbox(label="Input", lines=2, placeholder="Enter message...")
            # type="filepath": the handler receives a path on disk (or None).
            audio_file = gr.Audio(label="Audio", type="filepath", sources=["microphone", "upload"])
            with gr.Row():
                submit_btn = gr.Button("Send", variant="primary", scale=2)
                clear_btn = gr.Button("Clear", scale=1)
    # chat() is a generator, so the chatbot and status update as it yields.
    # model_name is not wired to an input, so chat() falls back to MODEL_NAME.
    submit_btn.click(
        fn=chat,
        inputs=[system_prompt, user_text, audio_file, chatbot, max_tokens, temperature, top_p],
        outputs=[chatbot, status]
    )
    # Reset the conversation, the text box and the audio input.
    clear_btn.click(
        fn=lambda: ([], "", None),
        outputs=[chatbot, user_text, audio_file]
    )
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Step Audio R1 vLLM Gradio Interface")
    parser.add_argument("--host", default="0.0.0.0", help="Address the Gradio server binds to")
    parser.add_argument("--port", type=int, default=7860, help="Port the Gradio server listens on")
    parser.add_argument("--model", default=MODEL_NAME, help="Model name sent to the API")
    parser.add_argument(
        "--api-base",
        default=API_BASE_URL,
        help="Base URL of the OpenAI-compatible vLLM API (defaults to the API_BASE_URL setting)",
    )
    args = parser.parse_args()
    # Update the module-level settings. chat() looks these globals up at call
    # time, so the new values apply to every request.
    if args.model:
        MODEL_NAME = args.model
    if args.api_base:
        API_BASE_URL = args.api_base
    print(f"启动Gradio: http://{args.host}:{args.port}")
    print(f"API地址: {API_BASE_URL}")
    print(f"模型: {MODEL_NAME}")
    demo.launch(server_name=args.host, server_port=args.port, share=False)