Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import numpy as np | |
| import requests | |
| import base64 | |
| import tempfile | |
| import os | |
| import time | |
| import traceback | |
| import librosa | |
| from pydub import AudioSegment | |
| from streamlit_webrtc import webrtc_streamer, WebRtcMode, RTCConfiguration | |
| import av | |
| from utils.vad import get_speech_timestamps, collect_chunks, VadOptions | |
# Chat backend endpoint; the API_URL environment variable overrides the
# local default.
API_URL = os.environ.get("API_URL", "http://127.0.0.1:60808/chat")

# Seed the chat history the first time it is missing from the session state.
if "messages" not in st.session_state:
    st.session_state["messages"] = []
def run_vad(audio, sr):
    """Remove non-speech from a PCM clip using the project's VAD helpers.

    The samples are scaled by 1/32768 to float32, resampled to 16 kHz when
    ``sr`` differs, reduced to the chunks reported by
    ``get_speech_timestamps``, resampled back to ``sr`` and re-quantised to
    16-bit PCM.

    Args:
        audio: Array-like of PCM samples on the int16 scale (the code divides
            by 32768, so int16-range input is assumed).
        sr: Sampling rate of ``audio`` in Hz.

    Returns:
        A ``(duration_after_vad, pcm_bytes, elapsed)`` tuple.
        ``duration_after_vad`` is the length in seconds of the retained
        speech, ``pcm_bytes`` is the int16 PCM at ``sr`` and ``elapsed`` is
        the wall-clock processing time in seconds (4 decimals).  If anything
        fails, the traceback is printed and ``(-1, <input as int16 bytes>,
        elapsed)`` is returned instead of raising.
    """
    _st = time.time()
    # Keep a handle on the caller's samples: the working copy below becomes
    # float32 (and possibly resampled), which must not leak into the fallback.
    original = np.asarray(audio)
    try:
        samples = original.astype(np.float32) / 32768.0
        sampling_rate = 16000
        if sr != sampling_rate:
            samples = librosa.resample(
                samples, orig_sr=sr, target_sr=sampling_rate
            )

        vad_parameters = VadOptions()
        speech_chunks = get_speech_timestamps(samples, vad_parameters)
        samples = collect_chunks(samples, speech_chunks)
        duration_after_vad = samples.shape[0] / sampling_rate

        if sr != sampling_rate:
            # Resample back to the original sampling rate.
            samples = librosa.resample(
                samples, orig_sr=sampling_rate, target_sr=sr
            )

        # Clip before the cast: a value of +32768 (or resampler overshoot)
        # would otherwise wrap around to a negative int16.
        pcm = np.clip(np.round(samples * 32768.0), -32768, 32767)
        pcm_bytes = pcm.astype(np.int16).tobytes()
        return duration_after_vad, pcm_bytes, round(time.time() - _st, 4)
    except Exception:
        # Best-effort boundary: report the failure and hand back the
        # untouched input so the caller can still decide what to do.
        seconds = original.size / sr if sr else float("nan")
        msg = f"[asr vad error] audio_len: {seconds:.3f} s, trace: {traceback.format_exc()}"
        print(msg)
        fallback = original.astype(np.int16, copy=False).tobytes()
        return -1, fallback, round(time.time() - _st, 4)
def save_tmp_audio(audio_bytes, frame_rate=16000, sample_width=2, channels=1):
    """Write raw PCM bytes to a new temporary ``.wav`` file.

    Args:
        audio_bytes: Raw PCM sample data.
        frame_rate: Sampling rate in Hz (default 16000).
        sample_width: Bytes per sample (default 2, i.e. 16-bit).
        channels: Number of interleaved channels (default 1).

    Returns:
        Path of the WAV file.  The file is created with ``delete=False``,
        so it is not removed automatically; the caller owns it.

    Raises:
        Exception: whatever ``AudioSegment`` or its ``export`` raises; the
            temporary file is removed before the error propagates.
    """
    # Reserve a unique path and close the handle right away: the encoder
    # reopens the path by name, which is not possible on Windows while the
    # NamedTemporaryFile is still open.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmpfile:
        file_name = tmpfile.name

    try:
        segment = AudioSegment(
            data=audio_bytes,
            sample_width=sample_width,
            frame_rate=frame_rate,
            channels=channels,
        )
        segment.export(file_name, format="wav")
    except Exception:
        # Do not leave an empty or partial file behind on failure.
        if os.path.exists(file_name):
            os.unlink(file_name)
        raise
    return file_name
def _process_utterance(samples, sample_rate=16000, timeout=120):
    """Run VAD on ``samples`` and, if speech remains, query the chat API.

    Plays back the trimmed user audio, posts it to ``API_URL``, plays back
    the response body and records both turns in
    ``st.session_state.messages``.

    Args:
        samples: Array of PCM samples for one utterance.
        sample_rate: Rate passed to ``run_vad`` (default 16000).
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        ``True`` when a request was sent and answered, ``False`` when VAD
        reported no usable speech (duration <= 0) and nothing was sent.

    Raises:
        requests.RequestException: on connection errors, timeouts or a
            non-2xx response.
    """
    duration_after_vad, vad_audio_bytes, _vad_time = run_vad(samples, sample_rate)
    if duration_after_vad <= 0:
        return False

    st.session_state.messages.append({"role": "user", "content": "User audio"})
    st.audio(save_tmp_audio(vad_audio_bytes), format="audio/wav")

    response = requests.post(API_URL, data=vad_audio_bytes, timeout=timeout)
    # Without this check an HTTP error page would be saved and played as audio.
    response.raise_for_status()

    st.audio(save_tmp_audio(response.content), format="audio/wav")
    st.session_state.messages.append(
        {"role": "assistant", "content": "Assistant response"}
    )
    return True


def main():
    """Render the demo page: microphone capture, chat round-trip, history.

    Audio frames from the WebRTC receiver are accumulated in
    ``st.session_state.audio_buffer``; once at least 16000 samples are
    buffered the utterance is processed.  The "Process Audio" button
    processes whatever is currently buffered.  The chat history is written
    out at the end of every run.
    """
    st.title("Chat Mini-Omni Demo")
    status = st.empty()

    if "audio_buffer" not in st.session_state:
        st.session_state.audio_buffer = []

    webrtc_ctx = webrtc_streamer(
        key="speech-to-text",
        mode=WebRtcMode.SENDONLY,
        audio_receiver_size=1024,
        rtc_configuration=RTCConfiguration(
            {"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]}
        ),
        media_stream_constraints={"video": False, "audio": True},
    )

    # NOTE(review): the buffer is treated as 16 kHz mono int16 throughout.
    # The frame's actual sample rate and channel layout are never inspected
    # -- confirm they match, or resample/downmix before buffering.
    if webrtc_ctx.audio_receiver:
        while True:
            try:
                audio_frame = webrtc_ctx.audio_receiver.get_frame(timeout=1)
                sound_chunk = np.frombuffer(audio_frame.to_ndarray(), dtype="int16")
                st.session_state.audio_buffer.extend(sound_chunk)

                if len(st.session_state.audio_buffer) >= 16000:
                    samples = np.array(st.session_state.audio_buffer)
                    st.session_state.audio_buffer = []
                    _process_utterance(samples)
            except Exception as e:
                # Any failure (presumably including no frame arriving within
                # the timeout) ends the capture loop for this script run.
                print(f"Error in audio processing: {e}")
                break

    if st.button("Process Audio"):
        if st.session_state.audio_buffer:
            samples = np.array(st.session_state.audio_buffer)
            st.session_state.audio_buffer = []
            try:
                handled = _process_utterance(samples)
            except requests.RequestException as exc:
                status.error(f"Request to {API_URL} failed: {exc}")
            else:
                if not handled:
                    status.warning("No speech detected in the buffered audio.")

    for message in st.session_state.messages:
        speaker = "User" if message["role"] == "user" else "Assistant"
        st.write(f"{speaker}: {message['content']}")
# Script entry point: build the page only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()