moevis commited on
Commit
9b74786
·
verified ·
1 Parent(s): e04e1a0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +66 -40
app.py CHANGED
@@ -6,6 +6,8 @@ Step Audio R1 vLLM Gradio Interface
6
  import base64
7
  import json
8
  import os
 
 
9
 
10
  import gradio as gr
11
  import httpx
@@ -13,18 +15,44 @@ import httpx
13
  API_BASE_URL = os.getenv("API_BASE_URL", "http://localhost:9999/v1")
14
  MODEL_NAME = os.getenv("MODEL_NAME", "Step-Audio-R1")
15
 
16
- def encode_audio(audio_path):
17
- """编码音频为base64"""
 
 
 
18
  if not audio_path or not os.path.exists(audio_path):
19
- return None
 
20
  try:
21
- with open(audio_path, "rb") as f:
22
- return base64.b64encode(f.read()).decode()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  except Exception as e:
24
- print(f"[DEBUG] Audio error: {e}")
25
- return None
26
 
27
- def format_messages(system, history, user_text, audio_data=None, audio_format="wav"):
28
  """Format message list"""
29
  messages = []
30
  if system:
@@ -43,37 +71,40 @@ def format_messages(system, history, user_text, audio_data=None, audio_format="w
43
  messages.append({"role": item.role, "content": item.content})
44
 
45
  # 添加当前用户消息
46
- if user_text and audio_data:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  messages.append({
48
  "role": "user",
49
- "content": [
50
- {
51
- "type": "input_audio",
52
- "input_audio": {
53
- "data": audio_data,
54
- "format": audio_format
55
- }
56
- },
57
- {
58
- "type": "text",
59
- "text": user_text
60
- }
61
- ]
62
  })
63
  elif user_text:
64
  messages.append({"role": "user", "content": user_text})
65
- elif audio_data:
 
 
 
 
 
 
 
 
 
66
  messages.append({
67
  "role": "user",
68
- "content": [
69
- {
70
- "type": "input_audio",
71
- "input_audio": {
72
- "data": audio_data,
73
- "format": audio_format
74
- }
75
- }
76
- ]
77
  })
78
 
79
  return messages
@@ -99,14 +130,11 @@ def chat(system_prompt, user_text, audio_file, history, max_tokens, temperature,
99
  history = clean_history
100
 
101
  # Process audio
102
- audio_data = None
103
- audio_format = "wav"
104
  if audio_file:
105
- audio_data = encode_audio(audio_file)
106
- if audio_file.lower().endswith(".mp3"):
107
- audio_format = "mp3"
108
 
109
- messages = format_messages(system_prompt, history, user_text, audio_data, audio_format)
110
  if not messages:
111
  return history or [], "Invalid input"
112
 
@@ -249,8 +277,6 @@ with gr.Blocks(title="Step Audio R1") as demo:
249
  submit_btn = gr.Button("Send", variant="primary", scale=2)
250
  clear_btn = gr.Button("Clear", scale=1)
251
 
252
- # 事件绑定 - 函数将在启动时定义
253
- # 直接绑定 chat 函数;不要传递外部的 `model_to_use`,chat 使用默认的 `MODEL_NAME` 或内部参数
254
  submit_btn.click(
255
  fn=chat,
256
  inputs=[system_prompt, user_text, audio_file, chatbot, max_tokens, temperature, top_p],
 
6
  import base64
7
  import json
8
  import os
9
+ import io
10
+ from pydub import AudioSegment
11
 
12
  import gradio as gr
13
  import httpx
 
15
  API_BASE_URL = os.getenv("API_BASE_URL", "http://localhost:9999/v1")
16
  MODEL_NAME = os.getenv("MODEL_NAME", "Step-Audio-R1")
17
 
18
+ def process_audio(audio_path):
19
+ """
20
+ Process audio: convert to wav, split if > 25s.
21
+ Returns a list of base64 encoded wav strings.
22
+ """
23
  if not audio_path or not os.path.exists(audio_path):
24
+ return []
25
+
26
  try:
27
+ # Load audio (pydub handles mp3, wav, etc. automatically if ffmpeg is installed)
28
+ audio = AudioSegment.from_file(audio_path)
29
+
30
+ # Split into chunks of 25 seconds (25000 ms)
31
+ chunk_length_ms = 25000
32
+ chunks = []
33
+
34
+ if len(audio) > chunk_length_ms:
35
+ for i in range(0, len(audio), chunk_length_ms):
36
+ chunk = audio[i:i + chunk_length_ms]
37
+ chunks.append(chunk)
38
+ else:
39
+ chunks.append(audio)
40
+
41
+ # Convert chunks to base64 wav
42
+ audio_data_list = []
43
+ for chunk in chunks:
44
+ buffer = io.BytesIO()
45
+ chunk.export(buffer, format="wav")
46
+ encoded = base64.b64encode(buffer.getvalue()).decode()
47
+ audio_data_list.append(encoded)
48
+
49
+ return audio_data_list
50
+
51
  except Exception as e:
52
+ print(f"[DEBUG] Audio processing error: {e}")
53
+ return []
54
 
55
+ def format_messages(system, history, user_text, audio_data_list=None):
56
  """Format message list"""
57
  messages = []
58
  if system:
 
71
  messages.append({"role": item.role, "content": item.content})
72
 
73
  # 添加当前用户消息
74
+ if user_text and audio_data_list:
75
+ content = []
76
+ for audio_data in audio_data_list:
77
+ content.append({
78
+ "type": "input_audio",
79
+ "input_audio": {
80
+ "data": audio_data,
81
+ "format": "wav"
82
+ }
83
+ })
84
+ content.append({
85
+ "type": "text",
86
+ "text": user_text
87
+ })
88
+
89
  messages.append({
90
  "role": "user",
91
+ "content": content
 
 
 
 
 
 
 
 
 
 
 
 
92
  })
93
  elif user_text:
94
  messages.append({"role": "user", "content": user_text})
95
+ elif audio_data_list:
96
+ content = []
97
+ for audio_data in audio_data_list:
98
+ content.append({
99
+ "type": "input_audio",
100
+ "input_audio": {
101
+ "data": audio_data,
102
+ "format": "wav"
103
+ }
104
+ })
105
  messages.append({
106
  "role": "user",
107
+ "content": content
 
 
 
 
 
 
 
 
108
  })
109
 
110
  return messages
 
130
  history = clean_history
131
 
132
  # Process audio
133
+ audio_data_list = []
 
134
  if audio_file:
135
+ audio_data_list = process_audio(audio_file)
 
 
136
 
137
+ messages = format_messages(system_prompt, history, user_text, audio_data_list)
138
  if not messages:
139
  return history or [], "Invalid input"
140
 
 
277
  submit_btn = gr.Button("Send", variant="primary", scale=2)
278
  clear_btn = gr.Button("Clear", scale=1)
279
 
 
 
280
  submit_btn.click(
281
  fn=chat,
282
  inputs=[system_prompt, user_text, audio_file, chatbot, max_tokens, temperature, top_p],