duyngtr
init code
5db4f1d
import streamlit as st
import json
import os
import asyncio
import time
from datetime import datetime, timedelta
import nats
from nats.errors import TimeoutError
from dotenv import load_dotenv
import plotly.graph_objects as go
from plotly.subplots import make_subplots
load_dotenv()

# NATS connection settings. NATS_HOST may be a comma-separated list of server
# URLs; whitespace around entries and empty entries are ignored so values such
# as "nats://a:4222, nats://b:4222," parse cleanly. An empty/blank NATS_HOST
# falls back to the local default instead of producing [""].
_DEFAULT_NATS_HOST = "nats://localhost:4222"
servers = [
    host.strip()
    for host in os.environ.get("NATS_HOST", _DEFAULT_NATS_HOST).split(",")
    if host.strip()
] or [_DEFAULT_NATS_HOST]
user = os.environ.get("NATS_USER", "")
password = os.environ.get("NATS_PWD", "")
class LogType:
    """Tags for the two kinds of captured log messages.

    The values are used as keys of the per-session message dict that
    ``parse_vad_log`` consumes.
    """

    VAD = "vad"  # messages whose "method" is "StreamVAD"
    ASR = "asr"  # messages whose "method" is "AskMeAnyThing"


# VAD series stored as floats: (output key, log metric key).
_VAD_FLOAT_FIELDS = (
    ("vad_score", "vad.vad_score"),
    ("smoothed_score", "vad.smoothed_score"),
    ("smoothed_score_ema", "vad.alt.smoothed_score_ema"),
    ("semantic_eos_prob", "vad.end_of_turn_prob"),
    ("loudness", "vad.loudness"),
)

# VAD series stored as 0/1 flags: a metric counts as set only when its logged
# value is exactly the string "True".
_VAD_FLAG_FIELDS = (
    ("has_speech", "vad.has_speech"),
    ("has_speech_without_loudness", "vad.alt.has_speech_without_loudness"),
    ("has_speech_without_negative_threshold", "vad.alt.has_speech_without_negative_threshold"),
    (
        "has_speech_without_loudness_and_negative_threshold",
        "vad.alt.has_speech_without_loudness_and_negative_threshold",
    ),
    ("end_of_speech_raw", "vad.end_of_speech_raw"),
    ("has_speech_with_ema", "vad.alt.has_speech_with_ema"),
)


def _to_float(value, default=0.0):
    """Return ``float(value)``, or *default* if *value* is missing or malformed."""
    if value is None:
        return default
    try:
        return float(value)
    except (TypeError, ValueError):
        return default


def _split_log_entry(entry):
    """Return the payload after the first comma of *entry*, or None if there is no comma."""
    _, sep, payload = entry.partition(",")
    return payload if sep else None


def _parse_vad_metrics(payload):
    """Parse ``"k1=v1 | k2=v2 | ..."`` into a dict of whitespace-stripped strings."""
    metrics = {}
    for chunk in payload.split("|"):
        key, sep, value = chunk.strip().partition("=")
        if sep:
            metrics[key.strip()] = value.strip()
    return metrics


def _parse_vad(log_item, vad_data):
    """Append one point per usable VAD log line of *log_item* to the series in *vad_data*."""
    vad_data["config"] = log_item.get("request", {}).get("config", {})
    vad_data["when"] = log_item.get("timestamp")
    # Latest vad.total_duration seen in each segment. A drop in total_duration
    # is treated as the start of a new segment (the counter presumably
    # restarts), so sum(segments) yields a time axis that keeps growing.
    segments = []
    for entry in log_item.get("logg", []):
        payload = _split_log_entry(entry)
        if payload is None:
            continue
        metrics = _parse_vad_metrics(payload)
        duration = _to_float(metrics.get("vad.total_duration"))
        if not duration:
            # Lines without a non-zero duration cannot be placed on the time axis.
            continue
        if not segments or segments[-1] > duration:
            segments.append(duration)
        else:
            segments[-1] = duration
        vad_data["timestamp"].append(sum(segments))

        for out_key, metric_key in _VAD_FLOAT_FIELDS:
            vad_data[out_key].append(_to_float(metrics.get(metric_key)))
        for out_key, metric_key in _VAD_FLAG_FIELDS:
            vad_data[out_key].append(1 if metrics.get(metric_key) == "True" else 0)

        # Keep only the first reported reason; once it is known, later points
        # no longer count as part of the speech region.
        if not vad_data["reason"] and metrics.get("vad.reason"):
            vad_data["reason"] = metrics["vad.reason"]
        speech_now = metrics.get("vad.speech_detected") == "True"
        vad_data["speech_region"].append(1 if speech_now and not vad_data["reason"] else 0)


def _parse_asr(log_item, asr_data):
    """Fill *asr_data* from the ``"<prefix>,<key>,<value>"`` log lines of *log_item*."""
    metrics = {}
    for entry in log_item.get("logg", []):
        payload = _split_log_entry(entry)
        if payload is None or "asr" not in payload:
            continue
        key, sep, value = payload.partition(",")
        if sep:
            metrics[key] = value
    asr_data["transcription"] = metrics.get("asr_final")
    # -1.0 marks a missing or unparseable confidence.
    asr_data["confidence"] = _to_float(metrics.get("asr_confidence"), -1.0)
    asr_data["response"] = log_item.get("response")


def parse_vad_log(message):
    """Parse one session's captured messages into plot-ready VAD/ASR data.

    Args:
        message: Dict keyed by ``LogType`` values; each value is a decoded
            NATS payload whose ``"logg"`` list holds ``"<prefix>,<payload>"``
            log lines. Other keys are ignored.

    Returns:
        ``{"vad": {...}, "asr": {...}}``.

        The ``"vad"`` part holds parallel per-line series (``timestamp``,
        scores, 0/1 flags, ``loudness``), plus ``reason`` (the first reported
        ``vad.reason``) and ``config``. It also holds ``when`` if VAD data was
        present.

        The ``"asr"`` part holds ``transcription``, ``confidence`` (-1.0 when
        absent from the log) and ``response``. ``confidence`` stays None when
        no ASR message was given.

        Malformed numeric values fall back to the field default instead of
        raising.
    """
    parsed_data = {
        "vad": {
            "timestamp": [],
            "vad_score": [],
            "smoothed_score": [],
            "smoothed_score_ema": [],
            "has_speech": [],
            "has_speech_without_loudness": [],
            "has_speech_without_negative_threshold": [],
            "has_speech_without_loudness_and_negative_threshold": [],
            "has_speech_with_ema": [],
            "end_of_speech_raw": [],
            "semantic_eos_prob": [],
            "speech_region": [],
            "loudness": [],
            "reason": None,
            "config": {},
        },
        "asr": {
            "transcription": "",
            "confidence": None,
            "response": "",
        },
    }
    for log_type, log_item in message.items():
        if log_type == LogType.VAD:
            _parse_vad(log_item, parsed_data["vad"])
        elif log_type == LogType.ASR:
            _parse_asr(log_item, parsed_data["asr"])
    return parsed_data
def create_vad_plot(parsed_data, message_index):
    """Build the two-row Plotly figure for one session's VAD metrics.

    Row 1 plots every VAD series against the cumulative time axis. Loudness
    goes on a secondary y-axis. Shaded bands mark the configured
    ``threshold`` (up to 1.0) and ``negative_threshold`` (down to 0.0), and
    each band is drawn only when that key is present in the config.

    Row 2 is a scatter of loudness against the ``has_speech`` flag.

    Args:
        parsed_data: The ``"vad"`` sub-dict produced by ``parse_vad_log``.
        message_index: Unused; kept so existing callers keep working.

    Returns:
        A ``plotly.graph_objects.Figure``.
    """
    fig = make_subplots(
        rows=2,
        cols=1,
        specs=[[{"secondary_y": True}], [{"secondary_y": False}]],
    )
    timestamps = parsed_data["timestamp"]

    # Row-1 time series: (data key, legend name, line colour, dash style).
    # Tuple order is the trace / legend order.
    line_series = (
        ("vad_score", "VAD Score", "#1f77b4", None),
        ("smoothed_score", "Smoothed Score", "#ff7f0e", None),
        ("smoothed_score_ema", "Smoothed Score (EMA)", "#affa00", "dash"),
        ("has_speech", "Has Speech", "#2ca02c", None),
        ("has_speech_without_loudness", "Has Speech (w/o loudness prefilter)", "#d62728", "dash"),
        ("has_speech_without_negative_threshold", "Has Speech (w/o negative threshold)", "#97a832", "dash"),
        (
            "has_speech_without_loudness_and_negative_threshold",
            "Has Speech (w/o loudness + negative threshold )",
            "#3234a8",
            "dash",
        ),
        ("has_speech_with_ema", "Has Speech (with EMA)", "#f542c8", "dash"),
        ("end_of_speech_raw", "End of speech (raw)", "#d51238", "dash"),
        ("semantic_eos_prob", "EOS prob (semantic)", "#5914d9", "dash"),
        ("speech_region", "Speech region", "#fffefa", None),
    )
    for key, name, color, dash in line_series:
        line = dict(color=color, width=2)
        if dash is not None:
            line["dash"] = dash
        fig.add_trace(
            go.Scatter(x=timestamps, y=parsed_data[key], name=name, mode="lines", line=line),
            row=1,
            col=1,
        )

    # Loudness is in dB, while the other series appear to live in [0, 1]
    # (the threshold bands below stop at 1.0), so it gets the secondary
    # y-axis. row/col are given explicitly so secondary_y binds to row 1.
    fig.add_trace(
        go.Scatter(
            x=timestamps,
            y=parsed_data["loudness"],
            name="Loudness (dB)",
            mode="lines",
            line=dict(color="#9467bd", width=2),
            opacity=0.6,
        ),
        row=1,
        col=1,
        secondary_y=True,
    )

    # Row 2: how the has_speech decision relates to loudness.
    fig.add_trace(
        go.Scatter(
            x=parsed_data["loudness"],
            y=parsed_data["has_speech"],
            name="db to score",
            mode="markers",
        ),
        row=2,
        col=1,
    )

    fig.update_layout(
        title="VAD Metrics",
        xaxis_title="Time (seconds)",
        yaxis_title="Value",
        hovermode="closest",
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="right",
            x=1,
        ),
        height=500,
        margin=dict(t=100, b=50, l=50, r=50),
    )
    fig.update_yaxes(title_text="Loudness (dB)", row=1, col=1, secondary_y=True)
    fig.update_xaxes(title_text="Loudness (dB)", row=2, col=1)
    fig.update_yaxes(title_text="Has speech", row=2, col=1)

    # Threshold bands belong to the score axis of row 1 only. add_hrect
    # otherwise defaults to every subplot, which would also paint the row-2
    # scatter.
    config = parsed_data.get("config") or {}
    threshold = config.get("threshold")
    if threshold is not None:
        fig.add_hrect(
            y0=threshold, y1=1.0, line_width=0, fillcolor="green", opacity=0.2, row=1, col=1
        )
    negative_threshold = config.get("negative_threshold")
    if negative_threshold is not None:
        fig.add_hrect(
            y0=negative_threshold, y1=0.0, line_width=0, fillcolor="lightgray", opacity=0.2, row=1, col=1
        )

    fig.update_xaxes(showgrid=True, gridwidth=1, gridcolor="lightgray")
    fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor="lightgray")
    return fig
async def capture_log(device_id, stop_event):
    """Stream VAD/ASR messages for *device_id* from NATS until *stop_event* is set.

    Subscribes to the ``buddyos-vad`` and ``maika.ai.nlu`` subjects and
    decodes each payload as JSON. It keeps only dict messages whose
    ``request.device_id`` or ``request.deviceId`` equals *device_id*.
    Payloads that are not JSON objects are skipped.

    Yields:
        ``(log_type, session_id, message)`` tuples:

        * ``(LogType.VAD, request["session_id"], message)`` for ``StreamVAD``
          messages that carry a non-empty ``logg`` list.
        * ``(LogType.ASR, request["messageId"], message)`` for
          ``AskMeAnyThing`` messages.

        ``session_id`` is None when the respective request field is missing.

    The background pump tasks are cancelled and the NATS connection is closed
    when the generator finishes or is closed.
    """
    nc = await nats.connect(
        servers=servers,
        user=user,
        password=password,
    )
    message_queue = asyncio.Queue()
    pump_tasks = []

    async def pump(subscription):
        # Forward raw messages into the shared queue; the short timeout lets
        # the loop notice stop_event promptly.
        while not stop_event.is_set():
            try:
                msg = await subscription.next_msg(timeout=0.1)
            except TimeoutError:  # nats.errors.TimeoutError (imported at file top)
                continue
            await message_queue.put(msg)

    try:
        for subject in ("buddyos-vad", "maika.ai.nlu"):
            subscription = await nc.subscribe(subject)
            # Keep references so the tasks are not garbage-collected and can
            # be cancelled in `finally`.
            pump_tasks.append(asyncio.create_task(pump(subscription)))

        while not stop_event.is_set():
            try:
                raw = await asyncio.wait_for(message_queue.get(), timeout=0.1)
            except asyncio.TimeoutError:
                continue  # re-check stop_event
            try:
                message = json.loads(raw.data)
            except ValueError:
                # JSONDecodeError and UnicodeDecodeError are both ValueErrors;
                # one bad payload must not end the capture.
                continue
            if not isinstance(message, dict):
                continue
            request = message.get("request", {})
            if not isinstance(request, dict):
                continue
            if device_id not in (request.get("device_id"), request.get("deviceId")):
                continue

            method = message.get("method")
            if method == "StreamVAD":
                if message.get("logg", []):
                    print("Yield VAD")
                    yield LogType.VAD, request.get("session_id"), message
            elif method == "AskMeAnyThing":
                print("Yield ASR")
                yield LogType.ASR, request.get("messageId"), message
    finally:
        for task in pump_tasks:
            task.cancel()
        await asyncio.gather(*pump_tasks, return_exceptions=True)
        await nc.close()
# Initialize session state (persists across Streamlit reruns):
#   vad_logs   -> {session_id: {LogType value: decoded NATS message}}
#   capturing  -> True while a capture should run on each rerun
#   stop_event -> asyncio.Event handed to capture_log to request shutdown
#   last_error -> error text from the last failed capture, shown once
if "vad_logs" not in st.session_state:
    st.session_state.vad_logs = {}
if "capturing" not in st.session_state:
    st.session_state.capturing = False
if "stop_event" not in st.session_state:
    st.session_state.stop_event = None
if "last_error" not in st.session_state:
    st.session_state.last_error = None

st.title("VAD Metrics Viz")

# A failed capture reruns the script immediately (to refresh the sidebar
# buttons), so its error is stashed in session state and displayed here.
if st.session_state.last_error:
    st.error(st.session_state.last_error)
    st.session_state.last_error = None

# Device ID input
device_id = st.sidebar.text_input("Device ID", value="AIMWL25350000006")

# Control buttons
if st.sidebar.button(
    "Capture",
    disabled=st.session_state.capturing,
    use_container_width=True,
    type="secondary",
):
    if device_id:
        st.session_state.capturing = True
        st.session_state.vad_logs = {}
        st.session_state.stop_event = asyncio.Event()
        st.rerun()

if st.sidebar.button(
    "Stop Capturing + Clear history",
    disabled=not st.session_state.capturing,
    use_container_width=True,
    type="primary",
):
    if st.session_state.stop_event:
        st.session_state.vad_logs = {}
        st.session_state.stop_event.set()
        st.session_state.capturing = False
        st.rerun()

# Display VAD plots, one pair of expanders per captured session.
for session_id, message in st.session_state.vad_logs.items():
    parsed_data = parse_vad_log(message)
    vad_data = parsed_data["vad"]
    asr_data = parsed_data["asr"]
    with st.expander("VAD Configuration"):
        st.code(json.dumps(vad_data["config"], indent=2, ensure_ascii=False), language="json", wrap_lines=True)
    with st.expander(f"📊 VAD Log: session `{session_id}`", expanded=True):
        if vad_data["timestamp"]:
            fig = create_vad_plot(vad_data, session_id)
            # Single-quoted keys inside the f-strings keep this valid before
            # Python 3.12 (PEP 701).
            st.code(f"EOS reason: {vad_data['reason']}")
            st.plotly_chart(fig, use_container_width=True)
        if asr_data["confidence"] is not None:
            st.markdown(f"Transcription: {asr_data['transcription']}")
            st.markdown(f"Confidence {asr_data['confidence']}")

# Capture logs if active
if st.session_state.capturing and device_id:
    status_placeholder = st.sidebar.empty()
    status_placeholder.info(f"🔴 Capturing logs for device: {device_id}")

    async def run_capture():
        """Store each captured message under its session id, then rerun to redraw."""
        async for log_type, session_id, log_message in capture_log(device_id, st.session_state.stop_event):
            if not session_id:
                continue  # nothing to attach the message to; keep listening
            st.session_state.vad_logs.setdefault(session_id, {})[log_type] = log_message
            print(f"Add to session state: {log_type}")
            # st.rerun() restarts the script, ending this capture pass; since
            # `capturing` is still True, the next run captures again.
            st.rerun()

    try:
        asyncio.run(run_capture())
    except Exception as e:
        st.session_state.last_error = f"Error: {e}"
        st.session_state.capturing = False
        st.rerun()