Spaces:
Build error
Build error
| import streamlit as st | |
| import json | |
| import os | |
| import asyncio | |
| import time | |
| from datetime import datetime, timedelta | |
| import nats | |
| from nats.errors import TimeoutError | |
| from dotenv import load_dotenv | |
| import plotly.graph_objects as go | |
| from plotly.subplots import make_subplots | |
load_dotenv()

# Endpoint used when NATS_HOST is unset or lists no usable server.
DEFAULT_NATS_HOST = "nats://localhost:4222"

# NATS_HOST may list several servers separated by commas. Surrounding
# whitespace is stripped and empty entries (e.g. from a trailing comma) are
# dropped, so the client is never handed a blank or space-prefixed URL.
servers = [
    host.strip()
    for host in os.environ.get("NATS_HOST", DEFAULT_NATS_HOST).split(",")
    if host.strip()
] or [DEFAULT_NATS_HOST]
# Credentials passed to nats.connect(); empty strings when the variables are unset.
user = os.environ.get("NATS_USER", "")
password = os.environ.get("NATS_PWD", "")
class LogType:
    """String tags naming which pipeline a captured log message came from.

    ``capture_log`` yields these tags, the app stores each raw message under
    its tag per session, and ``parse_vad_log`` dispatches on them.
    """

    VAD, ASR = "vad", "asr"
def _to_float(value, default=0.0):
    """Convert *value* to ``float``.

    Returns ``float(default)`` when *value* is ``None`` or cannot be parsed
    (for example ``""`` or ``"None"``) instead of raising.
    """
    if value is None:
        return float(default)
    try:
        return float(value)
    except (TypeError, ValueError):
        return float(default)


def _flag(metrics, key):
    """Return 1 if ``metrics[key]`` is exactly the string ``"True"``, else 0."""
    return 1 if metrics.get(key) == "True" else 0


def _parse_metrics(metrics_str):
    """Split ``"k1=v1 | k2=v2 | ..."`` into a dict of whitespace-stripped strings.

    Segments without ``=`` are ignored; only the first ``=`` separates the key
    from the value.
    """
    metrics = {}
    for segment in metrics_str.split("|"):
        key, sep, value = segment.partition("=")
        if sep:
            metrics[key.strip()] = value.strip()
    return metrics


def _parse_vad_item(log_item, vad_data):
    """Append the samples of one raw VAD message to the lists in *vad_data*."""
    request = log_item.get("request") or {}
    vad_data["config"] = request.get("config") or {}
    vad_data["when"] = log_item.get("timestamp")

    # vad.total_duration appears to restart from a smaller value when a new
    # stream segment begins (NOTE(review): inferred from the logic below).
    # The last value of each segment is kept, and their running sum is the
    # plotted time so the x-axis keeps increasing.
    durations = []
    for entry in log_item.get("logg") or []:
        # Entries are "<prefix>,<metrics>"; only the metrics part is used.
        _prefix, sep, metrics_str = entry.partition(",")
        if not sep:
            continue
        metrics = _parse_metrics(metrics_str)

        timestamp = _to_float(metrics.get("vad.total_duration"))
        if not timestamp:
            continue
        if not durations:
            durations.append(timestamp)
        if durations[-1] > timestamp:
            durations.append(timestamp)
        else:
            durations[-1] = timestamp
        vad_data["timestamp"].append(sum(durations))

        vad_data["vad_score"].append(_to_float(metrics.get("vad.vad_score")))
        vad_data["smoothed_score"].append(_to_float(metrics.get("vad.smoothed_score")))
        vad_data["smoothed_score_ema"].append(_to_float(metrics.get("vad.alt.smoothed_score_ema")))
        vad_data["semantic_eos_prob"].append(_to_float(metrics.get("vad.end_of_turn_prob")))
        vad_data["loudness"].append(_to_float(metrics.get("vad.loudness")))

        vad_data["has_speech"].append(_flag(metrics, "vad.has_speech"))
        vad_data["has_speech_without_loudness"].append(
            _flag(metrics, "vad.alt.has_speech_without_loudness"))
        vad_data["has_speech_without_negative_threshold"].append(
            _flag(metrics, "vad.alt.has_speech_without_negative_threshold"))
        vad_data["has_speech_without_loudness_and_negative_threshold"].append(
            _flag(metrics, "vad.alt.has_speech_without_loudness_and_negative_threshold"))
        vad_data["end_of_speech_raw"].append(_flag(metrics, "vad.end_of_speech_raw"))
        vad_data["has_speech_with_ema"].append(_flag(metrics, "vad.alt.has_speech_with_ema"))

        # Only the first reported end-of-speech reason is kept.
        if not vad_data.get("reason") and metrics.get("vad.reason"):
            vad_data["reason"] = metrics["vad.reason"]
        # Once a reason has been recorded, later samples are no longer counted
        # as part of the speech region.
        vad_data["speech_region"].append(
            1 if metrics.get("vad.speech_detected") == "True" and not vad_data["reason"] else 0
        )


def _parse_asr_item(log_item, asr_data):
    """Fill *asr_data* from one raw ASR message's ``"logg"`` entries and response."""
    metrics = {}
    for entry in log_item.get("logg") or []:
        _prefix, sep, metrics_str = entry.partition(",")
        if not sep or "asr" not in metrics_str:
            continue
        # Expected shape "<prefix>,<key>,<value>": only the first comma after
        # the prefix separates key from value, so values may contain commas.
        key, sep, value = metrics_str.partition(",")
        if sep:
            metrics[key] = value
    asr_data["transcription"] = metrics.get("asr_final", None)
    # -1.0 marks "no confidence reported" (or an unparsable value).
    asr_data["confidence"] = _to_float(metrics.get("asr_confidence"), -1)
    asr_data["response"] = log_item.get("response", None)


def parse_vad_log(message):
    """Parse the VAD and ASR messages captured for one session.

    Args:
        message: Mapping from a ``LogType`` tag to the raw (JSON-decoded) NATS
            message captured for that tag. Each raw message is expected to
            carry a ``"logg"`` list of ``"<prefix>,<payload>"`` strings.

    Returns:
        ``{"vad": {...}, "asr": {...}}``. ``"vad"`` holds parallel per-sample
        lists for plotting plus ``reason``, ``config`` and, when a VAD message
        was present, ``when``. ``"asr"`` holds ``transcription``,
        ``confidence`` and ``response``.

    Malformed numeric fields fall back to a default (0.0, or -1.0 for ASR
    confidence) instead of raising, so one bad log line cannot break rendering.
    """
    parsed_data = {
        "vad": {
            'timestamp': [],
            'vad_score': [],
            'smoothed_score': [],
            'smoothed_score_ema': [],
            'has_speech': [],
            'has_speech_without_loudness': [],
            'has_speech_without_negative_threshold': [],
            'has_speech_without_loudness_and_negative_threshold': [],
            'has_speech_with_ema': [],
            'end_of_speech_raw': [],
            'semantic_eos_prob': [],
            'speech_region': [],
            'loudness': [],
            'reason': None,
            'config': {}
        },
        "asr": {
            "transcription": "",
            "confidence": None,
            "response": ""
        }
    }
    for log_type, log_item in message.items():
        if log_type == LogType.VAD:
            _parse_vad_item(log_item, parsed_data["vad"])
        elif log_type == LogType.ASR:
            _parse_asr_item(log_item, parsed_data["asr"])
    return parsed_data
# Line traces drawn on the primary y-axis of row 1, in legend/draw order:
# (key in parsed_data, legend label, line colour, dash style or None for solid).
_VAD_LINE_TRACES = (
    ("vad_score", "VAD Score", "#1f77b4", None),
    ("smoothed_score", "Smoothed Score", "#ff7f0e", None),
    ("smoothed_score_ema", "Smoothed Score (EMA)", "#affa00", "dash"),
    ("has_speech", "Has Speech", "#2ca02c", None),
    ("has_speech_without_loudness", "Has Speech (w/o loudness prefilter)", "#d62728", "dash"),
    ("has_speech_without_negative_threshold", "Has Speech (w/o negative threshold)", "#97a832", "dash"),
    ("has_speech_without_loudness_and_negative_threshold",
     "Has Speech (w/o loudness + negative threshold)", "#3234a8", "dash"),
    ("has_speech_with_ema", "Has Speech (with EMA)", "#f542c8", "dash"),
    ("end_of_speech_raw", "End of speech (raw)", "#d51238", "dash"),
    ("semantic_eos_prob", "EOS prob (semantic)", "#5914d9", "dash"),
    ("speech_region", "Speech region", "#fffefa", None),
)


def create_vad_plot(parsed_data, message_index):
    """Build a two-row Plotly figure for one session's parsed VAD samples.

    Row 1 plots every series in ``_VAD_LINE_TRACES`` against the cumulative
    time axis, with loudness on a secondary y-axis. It also shades the band
    from the configured ``threshold`` up to 1.0 and from
    ``negative_threshold`` down to 0.0. Row 2 scatters ``has_speech`` against
    loudness.

    Args:
        parsed_data: The ``"vad"`` dict produced by ``parse_vad_log``.
        message_index: Unused; kept so existing callers keep working.

    Returns:
        The assembled ``plotly.graph_objects.Figure``.
    """
    fig = make_subplots(
        rows=2, cols=1,
        specs=[[{"secondary_y": True}], [{"secondary_y": False}]],
    )
    timestamps = parsed_data["timestamp"]

    for key, label, color, dash in _VAD_LINE_TRACES:
        line = {"color": color, "width": 2}
        if dash:
            line["dash"] = dash
        fig.add_trace(
            go.Scatter(x=timestamps, y=parsed_data[key], name=label, mode="lines", line=line),
            row=1, col=1,
        )

    # Loudness is in dB (typically negative), so it gets its own y-axis.
    fig.add_trace(
        go.Scatter(
            x=timestamps,
            y=parsed_data["loudness"],
            name="Loudness (dB)",
            mode="lines",
            line=dict(color="#9467bd", width=2),
            opacity=0.6,
        ),
        row=1, col=1, secondary_y=True,
    )
    fig.add_trace(
        go.Scatter(
            x=parsed_data["loudness"],
            y=parsed_data["has_speech"],
            name="db to score",
            mode="markers",
        ),
        row=2, col=1,
    )

    fig.update_layout(
        title="VAD Metrics",
        xaxis_title="Time (seconds)",
        yaxis_title="Value",
        hovermode="closest",
        legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
        height=500,
        margin=dict(t=100, b=50, l=50, r=50),
    )
    fig.update_yaxes(title_text="Loudness (dB)", row=1, col=1, secondary_y=True)
    fig.update_xaxes(title_text="Loudness (dB)", row=2, col=1)
    fig.update_yaxes(title_text="Has Speech", row=2, col=1)

    # The threshold bands describe the score axis of row 1 only; without an
    # explicit row/col, add_hrect would also paint them over the row-2 scatter.
    config = parsed_data.get("config") or {}
    fig.add_hrect(
        y0=config.get("threshold", 0),
        y1=1.0,
        row=1, col=1,
        line_width=0,
        fillcolor="green",
        opacity=0.2,
    )
    fig.add_hrect(
        y0=config.get("negative_threshold", 0),
        y1=0.0,
        row=1, col=1,
        line_width=0,
        fillcolor="lightgray",
        opacity=0.2,
    )

    fig.update_xaxes(showgrid=True, gridwidth=1, gridcolor="lightgray")
    fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor="lightgray")
    return fig
async def capture_log(device_id, stop_event):
    """Stream VAD and ASR log messages for one device from NATS.

    Subscribes to ``buddyos-vad`` and ``maika.ai.nlu`` and yields
    ``(log_type, session_id, message)`` for JSON messages whose
    ``request.device_id`` or ``request.deviceId`` equals *device_id*:

    * ``StreamVAD`` messages with a non-empty ``"logg"`` list are yielded as
      ``LogType.VAD`` with ``request.session_id``;
    * ``AskMeAnyThing`` messages are yielded as ``LogType.ASR`` with
      ``request.messageId``.

    VAD ``"logg"`` entries are expected to look like
    ``"0.116,vad.vad_score=0.03 | vad.smoothed_score=0.03 | vad.has_speech=False | ..."``.

    Args:
        device_id: Device identifier to filter on.
        stop_event: ``asyncio.Event`` (or ``None``). The generator finishes once
            it is set; it is checked about every 0.1 s while idle.

    Messages that are not valid JSON objects are skipped rather than aborting
    the capture. The NATS connection is closed when the generator finishes or
    is closed.
    """
    nc = await nats.connect(
        servers=servers,
        user=user,
        password=password,
    )
    message_queue = asyncio.Queue()

    async def _enqueue(msg):
        # Subscription callback: hand every message to the consumer loop below.
        await message_queue.put(msg)

    def _stopped():
        return stop_event is not None and stop_event.is_set()

    try:
        # Callback subscriptions replace the previous fire-and-forget polling
        # tasks, whose unreferenced Task objects could be garbage-collected.
        await nc.subscribe("buddyos-vad", cb=_enqueue)
        await nc.subscribe("maika.ai.nlu", cb=_enqueue)

        while not _stopped():
            try:
                msg = await asyncio.wait_for(message_queue.get(), timeout=0.1)
            except asyncio.TimeoutError:
                continue  # idle: re-check stop_event

            try:
                message = json.loads(msg.data)
            except ValueError:  # includes JSONDecodeError and UnicodeDecodeError
                print(f"Skipping non-JSON message on {msg.subject}")
                continue
            if not isinstance(message, dict):
                continue
            request = message.get("request")
            if not isinstance(request, dict):
                continue
            if device_id not in (request.get("device_id"), request.get("deviceId")):
                continue

            method = message.get("method")
            if method == "StreamVAD":
                if message.get("logg"):
                    print("Yield VAD")
                    yield LogType.VAD, request.get("session_id"), message
            elif method == "AskMeAnyThing":
                print("Yield ASR")
                yield LogType.ASR, request.get("messageId"), message
    finally:
        await nc.close()
# --- Session state ------------------------------------------------------------
#   vad_logs   -> {session_id: {LogType tag: raw NATS message}}
#   capturing  -> whether the capture loop should run on this script run
#   stop_event -> asyncio.Event handed to capture_log; set by the stop button
if "vad_logs" not in st.session_state:
    st.session_state.vad_logs = {}
if "capturing" not in st.session_state:
    st.session_state.capturing = False
if "stop_event" not in st.session_state:
    st.session_state.stop_event = None

st.title("VAD Metrics Viz")

# A capture error is stored in session state and shown here, because the
# st.rerun() issued right after the failure would otherwise wipe the message.
if "capture_error" in st.session_state:
    st.error(st.session_state.capture_error)
    del st.session_state.capture_error

# --- Sidebar controls -----------------------------------------------------------
device_id = st.sidebar.text_input("Device ID", value="AIMWL25350000006")

if st.sidebar.button(
    "Capture",
    disabled=st.session_state.capturing,
    use_container_width=True,
    type="secondary",
):
    if device_id:
        st.session_state.capturing = True
        st.session_state.vad_logs = {}
        st.session_state.stop_event = asyncio.Event()
        st.rerun()

if st.sidebar.button(
    "Stop Capturing + Clear history",
    disabled=not st.session_state.capturing,
    use_container_width=True,
    type="primary",
):
    # Clear history unconditionally, as the button label promises.
    st.session_state.vad_logs = {}
    if st.session_state.stop_event:
        st.session_state.stop_event.set()
    st.session_state.capturing = False
    st.rerun()

# --- Per-session plots ------------------------------------------------------------
for session_id, message in st.session_state.vad_logs.items():
    try:
        parsed_data = parse_vad_log(message)
    except (AttributeError, KeyError, TypeError, ValueError) as err:
        # One malformed session must not take down the whole page.
        st.warning(f"Could not parse logs for session `{session_id}`: {err}")
        continue
    vad = parsed_data["vad"]
    asr = parsed_data["asr"]

    with st.expander("VAD Configuration"):
        st.code(json.dumps(vad["config"], indent=2), language="json", wrap_lines=True)
    with st.expander(f"📊 VAD Log: session `{session_id}`", expanded=True):
        if vad["timestamp"]:
            fig = create_vad_plot(vad, session_id)
            # Single quotes inside the f-string keep this valid before Python 3.12.
            st.code(f"EOS reason: {vad['reason']}")
            st.plotly_chart(fig, use_container_width=True)
        if asr["confidence"] is not None:
            st.markdown(f"Transcription: {asr['transcription']}")
            st.markdown(f"Confidence {asr['confidence']}")

# --- Live capture ---------------------------------------------------------------
if st.session_state.capturing and device_id:
    status_placeholder = st.sidebar.empty()
    status_placeholder.info(f"🔴 Capturing logs for device: {device_id}")

    async def run_capture():
        """Store each captured message under its session and rerun to redraw."""
        async for log_type, session_id, raw_message in capture_log(
            device_id, st.session_state.stop_event
        ):
            if not session_id:
                continue
            st.session_state.vad_logs.setdefault(session_id, {})[log_type] = raw_message
            print(f"Add to session state: {log_type}")
            st.rerun()

    try:
        asyncio.run(run_capture())
    except Exception as e:
        st.session_state.capture_error = f"Error: {e}"
        st.session_state.capturing = False
        st.rerun()