"""
Intelligent Hearing Aid - Audio Source Separation Interface
Oticon-inspired clean design with proper UX feedback.
"""
import streamlit as st
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import tempfile
import os
import io
import zipfile
import json
import soundfile as sf
import librosa
# Page configuration
# Runs at module import, ahead of every other Streamlit call in this file.
st.set_page_config(
    page_title="Audio Source Separator | Oticon Audio Explorers 2026",
    page_icon="🎧",
    layout="wide",
    initial_sidebar_state="collapsed"
)
# NOTE(review): the payload below is blank even though raw HTML is enabled --
# presumably a custom stylesheet was meant to be injected here; confirm.
st.markdown("""
""", unsafe_allow_html=True)
def create_speaker_radar(sources_info: list, selected_idx: int) -> go.Figure:
    """Create a clean polar chart showing speaker positions.

    Args:
        sources_info: One mapping per speaker. The keys read here are
            ``direction_deg``, ``gender`` and ``language``; each may be
            missing or ``None``.
        selected_idx: Zero-based index of the speaker to highlight.

    Returns:
        A Plotly figure with a small "head" circle in the centre and one
        numbered marker per speaker at radius 0.75. A speaker without a
        direction is drawn at 0° (front) and its hover text says "N/A".
    """
    fig = go.Figure()

    # Modern color palette with magenta accent
    colors = ['#6366f1', '#f59e0b', '#10b981', '#8b5cf6']  # indigo, amber, emerald, violet
    selected_color = '#9a1b5a'  # Oticon magenta

    # Head circle
    theta_head = np.linspace(0, 360, 100)
    fig.add_trace(go.Scatterpolar(
        r=[0.2] * 100,
        theta=theta_head,
        mode='lines',
        line=dict(color='#d1d5db', width=2),
        fill='toself',
        fillcolor='#f9fafb',
        hoverinfo='skip',
        showlegend=False
    ))

    # Plot speakers
    for i, info in enumerate(sources_info):
        is_selected = i == selected_idx

        direction = info.get('direction_deg')
        # Keep the hover honest: an unknown direction is still plotted at the
        # front, but must not be reported as a measured 0°.
        direction_text = "N/A" if direction is None else f"{direction:.0f}°"
        if direction is None:
            direction = 0.0

        color = selected_color if is_selected else colors[i % len(colors)]
        gender = info.get('gender') or 'unknown'
        symbol = 'diamond' if gender == 'male' else 'circle'

        # Plotly hover text uses <br> for line breaks.
        hover_text = (
            f"Speaker {i+1}<br>"
            f"Direction: {direction_text}<br>"
            f"Gender: {gender}<br>"
            f"Language: {(info.get('language') or '?').upper()}"
        )

        fig.add_trace(go.Scatterpolar(
            r=[0.75],
            theta=[direction],
            mode='markers+text',
            marker=dict(
                size=30 if is_selected else 24,
                color=color,
                symbol=symbol,
                line=dict(color='white', width=3)
            ),
            text=[str(i+1)],
            textposition='middle center',
            textfont=dict(color='white', size=12, family='Arial'),
            name=f"Speaker {i+1}",
            # <extra></extra> hides the secondary hover box, which would
            # otherwise repeat the trace name next to the text above.
            hovertemplate=hover_text + "<extra></extra>"
        ))

        # Spoke from the head circle out to the speaker marker.
        fig.add_trace(go.Scatterpolar(
            r=[0.2, 0.75],
            theta=[direction, direction],
            mode='lines',
            line=dict(
                color=color,
                width=2 if is_selected else 1,
                dash='solid' if is_selected else 'dot'
            ),
            hoverinfo='skip',
            showlegend=False
        ))

    fig.update_layout(
        polar=dict(
            radialaxis=dict(visible=False, range=[0, 1]),
            angularaxis=dict(
                tickmode='array',
                tickvals=[0, 90, 180, 270],
                ticktext=['Front', 'Right', 'Back', 'Left'],
                tickfont=dict(size=12, color='#6c757d'),
                direction='clockwise',
                rotation=90,
                gridcolor='#e9ecef',
                linecolor='#d1d5db'
            ),
            bgcolor='white'
        ),
        showlegend=False,
        paper_bgcolor='white',
        plot_bgcolor='white',
        margin=dict(l=60, r=60, t=40, b=40),
        height=380
    )
    return fig
def create_waveform_plot(audio: np.ndarray, sr: int, color: str = '#9a1b5a') -> go.Figure:
    """Create a minimal waveform visualization.

    Args:
        audio: Samples to draw. Signals longer than 3000 samples are thinned
            by striding so the figure stays light.
        sr: Sample rate in Hz, used to convert sample indices to seconds.
        color: Line colour. A ``#rrggbb`` value is also used, at 15% opacity,
            for the area fill; any other format falls back to a magenta fill.

    Returns:
        A Plotly figure with a single filled line trace.
    """
    max_points = 3000
    if len(audio) > max_points:
        step = len(audio) // max_points
        audio_plot = audio[::step]
        time = np.arange(len(audio_plot)) * step / sr
    else:
        audio_plot = audio
        time = np.arange(len(audio)) / sr

    # Convert hex to rgba for fill (Plotly doesn't support 8-digit hex)
    if color.startswith('#') and len(color) == 7:
        r = int(color[1:3], 16)
        g = int(color[3:5], 16)
        b = int(color[5:7], 16)
        fill_color = f'rgba({r},{g},{b},0.15)'
    else:
        fill_color = 'rgba(154,27,90,0.15)'  # fallback magenta

    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=time, y=audio_plot,
        mode='lines',
        line=dict(color=color, width=1),
        fill='tozeroy',
        fillcolor=fill_color,
        # <br> is Plotly's hover line break; <extra></extra> hides the
        # secondary box that would otherwise show the default trace name.
        hovertemplate='Time: %{x:.2f}s<br>Amplitude: %{y:.3f}<extra></extra>'
    ))
    fig.update_layout(
        xaxis=dict(
            title=dict(text='Time (s)', font=dict(size=11, color='#6c757d')),
            gridcolor='#f0f2f5',
            tickfont=dict(color='#6c757d', size=10),
            zeroline=False
        ),
        yaxis=dict(
            title=dict(text='Amplitude', font=dict(size=11, color='#6c757d')),
            gridcolor='#f0f2f5',
            tickfont=dict(color='#6c757d', size=10),
            zeroline=True,
            zerolinecolor='#e9ecef'
        ),
        paper_bgcolor='white',
        plot_bgcolor='white',
        margin=dict(l=50, r=20, t=20, b=40),
        height=120
    )
    return fig
def create_spectrogram(audio: np.ndarray, sr: int) -> go.Figure:
    """Create a clean spectrogram visualization.

    Computes an STFT (2048-point FFT, hop of 512 samples), converts the
    magnitudes to dB relative to the loudest bin, and keeps only the rows at
    or below 8 kHz.

    Args:
        audio: Signal passed to ``librosa.stft``.
        sr: Sample rate in Hz, used for the time and frequency axes.

    Returns:
        A Plotly heatmap figure (time on x, frequency on y, dB as colour).
    """
    n_fft = 2048
    hop_length = 512

    D = librosa.stft(audio, n_fft=n_fft, hop_length=hop_length)
    S_db = librosa.amplitude_to_db(np.abs(D), ref=np.max)
    times = librosa.times_like(S_db, sr=sr, hop_length=hop_length)
    freqs = librosa.fft_frequencies(sr=sr, n_fft=n_fft)

    # Drop everything above 8 kHz from both the data and its axis.
    freq_mask = freqs <= 8000
    S_db = S_db[freq_mask, :]
    freqs = freqs[freq_mask]

    # Magenta-based colorscale
    colorscale = [
        [0, '#f8f9fa'],
        [0.3, '#e9d5df'],
        [0.6, '#c77da2'],
        [0.8, '#9a1b5a'],
        [1, '#5a1035']
    ]

    fig = go.Figure(data=go.Heatmap(
        z=S_db, x=times, y=freqs,
        colorscale=colorscale,
        showscale=False,
        # <br> is Plotly's hover line break; <extra></extra> hides the
        # secondary box that would otherwise show the default trace name.
        hovertemplate=(
            'Time: %{x:.2f}s<br>'
            'Freq: %{y:.0f}Hz<br>'
            'Power: %{z:.1f}dB<extra></extra>'
        )
    ))
    fig.update_layout(
        xaxis=dict(
            title=dict(text='Time (s)', font=dict(size=11, color='#6c757d')),
            tickfont=dict(color='#6c757d', size=10)
        ),
        yaxis=dict(
            title=dict(text='Frequency (Hz)', font=dict(size=11, color='#6c757d')),
            tickfont=dict(color='#6c757d', size=10)
        ),
        paper_bgcolor='white',
        plot_bgcolor='white',
        margin=dict(l=60, r=20, t=20, b=40),
        height=150
    )
    return fig
def create_comparison_bars(sources_info: list, selected_idx: int) -> go.Figure:
    """Create a clean bar comparison chart.

    Draws three side-by-side bar panels (selection score, pitch, energy) with
    one bar per speaker; the selected speaker's bars are magenta, the rest
    grey.

    Args:
        sources_info: One mapping per speaker. The keys read here are
            ``selection_score``, ``f0_hz`` / ``mean_f0_hz`` and ``energy``.
            A missing or ``None`` value is plotted as 0.
        selected_idx: Zero-based index of the speaker to highlight.

    Returns:
        A Plotly figure with a 1x3 subplot grid.
    """
    n = len(sources_info)
    speakers = [f"S{i+1}" for i in range(n)]
    colors = ['#9a1b5a' if i == selected_idx else '#d1d5db' for i in range(n)]

    fig = make_subplots(
        rows=1, cols=3,
        subplot_titles=('Selection Score', 'Pitch (Hz)', 'Energy'),
        horizontal_spacing=0.12
    )

    # `or 0` rather than a .get() default: a key that is present but None
    # would otherwise reach the chart (score) or break the scaling (energy).
    scores = [info.get('selection_score') or 0 for info in sources_info]
    f0s = [info.get('f0_hz') or info.get('mean_f0_hz') or 0 for info in sources_info]
    energies = [(info.get('energy') or 0) * 100 for info in sources_info]

    panels = [(scores, 'Score'), (f0s, 'Hz'), (energies, 'Energy')]
    for col, (values, label) in enumerate(panels, 1):
        fig.add_trace(go.Bar(
            x=speakers, y=values,
            marker_color=colors,
            showlegend=False,
            # Doubled braces keep Plotly's %{x}/%{y} placeholders literal
            # inside the f-string; <br> is the hover line break.
            hovertemplate=f'Speaker %{{x}}<br>{label}: %{{y:.1f}}<extra></extra>'
        ), row=1, col=col)

    fig.update_layout(
        paper_bgcolor='white',
        plot_bgcolor='white',
        height=220,
        margin=dict(l=40, r=40, t=50, b=30),
        font=dict(color='#1a3a5c')
    )
    fig.update_xaxes(tickfont=dict(color='#6c757d', size=10), gridcolor='#f0f2f5')
    fig.update_yaxes(tickfont=dict(color='#6c757d', size=10), gridcolor='#f0f2f5')

    # Subplot titles are stored as layout annotations; restyle them here.
    for annotation in fig['layout']['annotations']:
        annotation['font'] = dict(color='#1a1a2e', size=12)
    return fig
def create_timeline(sources_info: list, duration: float, selected_idx: int) -> go.Figure:
    """Create a simple audio timeline.

    Every speaker gets one horizontal bar that spans the full ``duration``,
    labelled with the language code and the first letter of the gender.

    Args:
        sources_info: One mapping per speaker. The keys read here are
            ``language`` and ``gender``; missing values are shown as "?".
        duration: Length of the recording in seconds.
        selected_idx: Zero-based index of the speaker to highlight.

    Returns:
        A Plotly figure with one horizontal bar trace per speaker.
    """
    fig = go.Figure()
    colors = ['#6366f1', '#f59e0b', '#10b981', '#8b5cf6']  # indigo, amber, emerald, violet

    for i, info in enumerate(sources_info):
        is_selected = i == selected_idx
        color = '#9a1b5a' if is_selected else colors[i % len(colors)]
        language = (info.get('language') or '?').upper()
        gender = (info.get('gender') or '?')

        fig.add_trace(go.Bar(
            x=[duration],
            y=[f"Speaker {i+1}"],
            orientation='h',
            marker=dict(color=color, opacity=1 if is_selected else 0.7),
            text=[f"{language} · {gender[0].upper()}"],
            textposition='inside',
            textfont=dict(color='white', size=11),
            # <br> is Plotly's hover line break; <extra></extra> hides the
            # secondary box that would otherwise show the default trace name.
            hovertemplate=f"Speaker {i+1}<br>Duration: {duration:.1f}s<extra></extra>",
            showlegend=False
        ))

    fig.update_layout(
        xaxis=dict(
            title=dict(text='Time (s)', font=dict(size=11, color='#6c757d')),
            range=[0, duration],
            gridcolor='#f0f2f5',
            tickfont=dict(color='#6c757d', size=10)
        ),
        yaxis=dict(tickfont=dict(color='#1a1a2e', size=11), gridcolor='#f0f2f5'),
        barmode='stack',
        paper_bgcolor='white',
        plot_bgcolor='white',
        height=180,
        margin=dict(l=100, r=20, t=20, b=40)
    )
    return fig
def process_audio(
    audio_path: str,
    approach: str = "ica",
    whisper_model: str = "small",
    hf_token: str | None = None,
    progress_callback=None,
) -> dict:
    """Process audio through the separation pipeline with progress updates.

    Args:
        audio_path: Path of the recording to separate.
        approach: Name handed to ``approaches.get_approach`` to pick the
            pipeline class.
        whisper_model: Forwarded to the pipeline's ``run`` as ``whisper_model``.
        hf_token: Forwarded to ``run`` only for the ``"ica_deeplearning"``
            approach, and only when non-empty.
        progress_callback: Optional ``callback(fraction, message)`` invoked at
            fixed milestones (0.05, 0.15, 0.9, 1.0).

    Returns:
        The pipeline's result as a dict, extended with:
        ``output_dir`` (temporary directory holding the written files),
        ``sources_audio`` (one array per ``source_<n>.wav``),
        ``original_audio`` (first channel of the input) and
        ``sr`` (sample rate of the input file).

    Raises:
        Whatever the pipeline or the audio reads raise. The temporary output
        directory is deleted before the exception propagates; on success the
        caller owns it and is responsible for removing it.
    """
    import shutil

    from approaches import get_approach

    def report(fraction: float, message: str) -> None:
        if progress_callback:
            progress_callback(fraction, message)

    output_dir = tempfile.mkdtemp()
    try:
        report(0.05, "Loading audio file...")
        approach_class = get_approach(approach)
        pipeline = approach_class()

        report(0.15, "Processing audio and separating sources...")
        # Run selected approach pipeline
        run_kwargs = {
            "input_file": audio_path,
            "output_dir": output_dir,
            "whisper_model": whisper_model,
        }
        if approach == "ica_deeplearning" and hf_token:
            run_kwargs["hf_token"] = hf_token
        pipeline_output = pipeline.run(**run_kwargs)
        results = (
            pipeline_output.to_dict()
            if hasattr(pipeline_output, "to_dict")
            else dict(pipeline_output)
        )

        report(0.9, "Finalizing results...")
        results['output_dir'] = output_dir
        results['sources_audio'] = []
        for i in range(results['n_speakers']):
            source_path = os.path.join(output_dir, f"source_{i+1}.wav")
            audio, _ = sf.read(source_path)
            results['sources_audio'].append(audio)

        original_audio, input_sr = sf.read(audio_path, always_2d=True)
        results['original_audio'] = original_audio[:, 0]
        results['sr'] = input_sr
    except BaseException:
        # mkdtemp() never cleans up after itself. Nobody receives the path
        # when we fail, so remove it here (also on KeyboardInterrupt and
        # other non-Exception unwinds) and let the error continue upwards.
        shutil.rmtree(output_dir, ignore_errors=True)
        raise

    report(1.0, "Complete!")
    return results
def create_download_zip(results: dict) -> bytes:
    """Bundle the separated audio and the analysis metadata into a ZIP.

    The archive holds ``speaker_<n>.wav`` for each ``source_<n>.wav`` found in
    ``results['output_dir']``, ``selected_speaker.wav`` if ``output.wav``
    exists there, and ``results.json`` with every entry of ``results`` except
    the output directory, the in-memory audio arrays and the sample rate.

    Returns:
        The ZIP archive as raw bytes.
    """
    excluded_keys = ('output_dir', 'sources_audio', 'original_audio', 'sr')

    base_dir = results['output_dir']
    # (path on disk, name inside the archive), in the order they are written.
    members = [
        (os.path.join(base_dir, f"source_{number}.wav"), f"speaker_{number}.wav")
        for number in range(1, results['n_speakers'] + 1)
    ]
    members.append((os.path.join(base_dir, "output.wav"), "selected_speaker.wav"))

    buffer = io.BytesIO()
    with zipfile.ZipFile(buffer, 'w', zipfile.ZIP_DEFLATED) as archive:
        for disk_path, archive_name in members:
            # Files the pipeline did not produce are skipped.
            if os.path.exists(disk_path):
                archive.write(disk_path, archive_name)

        metadata = {
            key: value for key, value in results.items()
            if key not in excluded_keys
        }
        archive.writestr("results.json", json.dumps(metadata, indent=2))

    return buffer.getvalue()
def get_direction_label(direction: float) -> str:
    """Convert direction to human-readable label.

    Args:
        direction: Angle in degrees, measured clockwise with 0 at the front.
            Any real value is accepted; it is wrapped into ``[0, 360)`` first,
            so ``-90`` and ``270`` both yield ``"Left"``.

    Returns:
        One of eight 45-degree sector names: "Front", "Front-Right", "Right",
        "Back-Right", "Back", "Back-Left", "Left" or "Front-Left".
    """
    # Wrap first: without this every negative angle is < 22.5 and would be
    # reported as "Front".
    angle = direction % 360

    # "Front" straddles 0 degrees, so its upper half is checked separately.
    if angle > 337.5:
        return "Front"

    # Exclusive upper bound of each sector, in ascending order.
    sectors = (
        (22.5, "Front"),
        (67.5, "Front-Right"),
        (112.5, "Right"),
        (157.5, "Back-Right"),
        (202.5, "Back"),
        (247.5, "Back-Left"),
        (292.5, "Left"),
    )
    for upper_bound, label in sectors:
        if angle < upper_bound:
            return label
    return "Front-Left"
def _discard_results() -> None:
    """Forget the cached analysis and delete the temporary directory it points to."""
    import shutil

    stale = st.session_state.pop('results', None)
    st.session_state.pop('results_source', None)
    if stale and stale.get('output_dir'):
        shutil.rmtree(stale['output_dir'], ignore_errors=True)


def _exportable_results(results: dict, *extra_hidden: str) -> dict:
    """Return `results` without the output directory, audio arrays and sample rate."""
    hidden = {'output_dir', 'sources_audio', 'original_audio', 'sr', *extra_hidden}
    return {key: value for key, value in results.items() if key not in hidden}


def _render_speaker_card(results: dict, i: int, info: dict, is_selected: bool, color: str) -> None:
    """Render one separated speaker: heading, metrics, audio, transcript and plots."""
    badge = 'TARGET' if is_selected else ''
    st.markdown(f"#### Speaker {i+1}" + (f" · {badge}" if badge else ""))

    # Metrics
    c1, c2, c3, c4 = st.columns(4)
    direction = info.get('direction_deg')
    if direction is None:
        c1.metric("Direction", "N/A")
    else:
        c1.metric("Direction", f"{direction:.0f}° ({get_direction_label(direction)})")
    c2.metric("Gender", (info.get('gender') or 'unknown').title())
    c3.metric("Language", (info.get('language') or '?').upper())
    score = info.get('selection_score')
    c4.metric("Score", f"{score:.1f}" if score is not None else "N/A")

    # Audio + download
    col_audio, col_dl = st.columns([4, 1])
    source_path = os.path.join(results['output_dir'], f"source_{i+1}.wav")
    if os.path.exists(source_path):
        with col_audio:
            st.audio(source_path, format='audio/wav')
        with col_dl:
            with open(source_path, 'rb') as f:
                st.download_button(
                    "Download",
                    data=f.read(),
                    file_name=f"speaker_{i+1}.wav",
                    mime="audio/wav",
                    key=f"dl_{i}"
                )

    # Transcription
    transcription = info.get('transcription') or info.get('transcript') or ''
    if transcription:
        with st.expander("View Transcription"):
            # The transcript is pipeline output, not trusted markup, so it is
            # rendered without unsafe_allow_html.
            st.write(transcription)

    # Waveform
    if i < len(results.get('sources_audio', [])):
        with st.expander("View Waveform & Spectrogram"):
            wf = create_waveform_plot(results['sources_audio'][i], results['sr'], color)
            st.plotly_chart(wf, use_container_width=True)
            spec = create_spectrogram(results['sources_audio'][i], results['sr'])
            st.plotly_chart(spec, use_container_width=True)


def _render_results(results: dict, selected_approach: str) -> None:
    """Render the overview charts, the per-speaker cards and the export section."""
    sources_info = results['sources']
    selected_idx = results['talker_of_interest'] - 1  # stored one-based

    st.divider()
    st.markdown("## Analysis Results")
    st.caption(f"Approach: {(results.get('approach') or selected_approach).replace('_', '+').upper()}")

    # Two column layout
    col_left, col_right = st.columns([1, 1])
    with col_left:
        st.markdown("### Speaker Positions")
        st.markdown("Spatial location of detected speakers relative to the listener")
        radar_fig = create_speaker_radar(sources_info, selected_idx)
        st.plotly_chart(radar_fig, use_container_width=True)
    with col_right:
        st.markdown("#### Speaker Comparison")
        st.markdown("Key metrics used for target speaker selection")
        comparison_fig = create_comparison_bars(sources_info, selected_idx)
        st.plotly_chart(comparison_fig, use_container_width=True)

        st.markdown("#### Activity Timeline")
        st.markdown("Speaker presence throughout the recording")
        timeline_fig = create_timeline(sources_info, results['duration_seconds'], selected_idx)
        st.plotly_chart(timeline_fig, use_container_width=True)

    st.divider()
    st.markdown("## Separated Speakers")
    st.markdown("Individual audio streams extracted from the recording")

    # Speaker colors - matching radar
    colors = ['#6366f1', '#f59e0b', '#10b981', '#8b5cf6']
    for i, info in enumerate(sources_info):
        is_selected = i == selected_idx
        color = '#9a1b5a' if is_selected else colors[i % len(colors)]
        _render_speaker_card(results, i, info, is_selected, color)

    # Download section
    st.divider()
    st.markdown("## Export")
    st.markdown("Download separated audio files and analysis data")
    c1, c2, c3 = st.columns(3)
    with c1:
        zip_data = create_download_zip(results)
        st.download_button(
            "Download All (ZIP)",
            data=zip_data,
            file_name="separated_audio.zip",
            mime="application/zip",
            use_container_width=True
        )
    with c2:
        output_path = os.path.join(results['output_dir'], "output.wav")
        if os.path.exists(output_path):
            with open(output_path, 'rb') as f:
                st.download_button(
                    "Download Target Speaker",
                    data=f.read(),
                    file_name="target_speaker.wav",
                    mime="audio/wav",
                    use_container_width=True
                )
    with c3:
        st.download_button(
            "Download Analysis (JSON)",
            data=json.dumps(_exportable_results(results), indent=2),
            file_name="analysis.json",
            mime="application/json",
            use_container_width=True
        )

    # Raw JSON
    with st.expander("View Raw Analysis Data"):
        st.json(_exportable_results(results, "input_file"))

    # Reset button
    st.markdown("<br>", unsafe_allow_html=True)
    if st.button("Analyze Another Recording"):
        _discard_results()
        st.session_state.pop('processing', None)
        st.rerun()


def main():
    """Main application.

    Renders the approach selector and the uploader, validates that the upload
    has four channels, runs the separation pipeline when "Analyze Audio" is
    pressed, and displays the results.

    Session-state keys used:
        ``processing``: True while an analysis has been requested and not yet
            finished.
        ``results``: dict returned by ``process_audio`` for the last analysis.
        ``results_source``: (file name, size) of the upload ``results`` belongs
            to, so results are not shown against a different file.
        ``processing_error``: message of a failed analysis, shown once.
    """
    # Header
    st.markdown(
        "OTICON Audio Explorers 2026\n\n"
        "# Audio Source Separator\n\n"
        "Separate and analyze individual speakers from multi-channel hearing aid recordings"
    )
    st.markdown("<br>", unsafe_allow_html=True)

    st.markdown("### Separation Approach")
    approach_options = ["ica", "frankenstein", "ica_deeplearning"]
    selected_approach = st.selectbox(
        "Choose approach",
        options=approach_options,
        index=0,
        format_func=lambda x: x.replace("_", "+").upper(),
        help="Select which pipeline variant to run. Default is ICA."
    )

    hf_token = None
    if selected_approach == "ica_deeplearning":
        hf_token_input = st.text_input(
            "Hugging Face Token (optional)",
            type="password",
            help="Needed only if your ICA+DeepLearning run uses Pyannote diarization.",
            placeholder="hf_..."
        )
        hf_token = hf_token_input.strip() or None

    # File upload section with clear label
    st.markdown(
        "### Upload Recording\n\n"
        "Select a 4-channel WAV file from your hearing aid microphone array"
    )
    uploaded_file = st.file_uploader(
        "Choose audio file",
        type=['wav'],
        help="4-channel WAV format (Left Front, Left Rear, Right Front, Right Rear)",
        label_visibility="collapsed"
    )
    if uploaded_file is None:
        return

    # Results computed for a different upload must not be shown for this one.
    upload_key = (uploaded_file.name, uploaded_file.size)
    if 'results' in st.session_state and st.session_state.get('results_source') != upload_key:
        _discard_results()

    with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as tmp:
        # getvalue() returns the whole buffer regardless of the read position.
        tmp.write(uploaded_file.getvalue())
        tmp_path = tmp.name

    try:
        audio, sr = sf.read(tmp_path, always_2d=True)
        n_channels = audio.shape[1]
        duration = len(audio) / sr

        if n_channels != 4:
            st.error(f"Expected 4 channels, got {n_channels}. Please upload a valid hearing aid recording.")
            return

        st.markdown("<br>", unsafe_allow_html=True)

        # File info cards with proper labels
        st.markdown("### Recording Details")
        col1, col2, col3, col4 = st.columns(4)
        col1.metric("Duration", f"{duration:.1f} sec")
        # Divide rather than floor so that 44100 Hz reads 44.1 kHz, not 44 kHz.
        col2.metric("Sample Rate", f"{sr / 1000:g} kHz")
        # NOTE(review): this card's body is blank in the file; the validated
        # channel count is shown here -- confirm the intended content.
        col3.metric("Channels", f"{n_channels}")
        col4.metric("File Size", f"{uploaded_file.size / (1024*1024):.1f} MB")

        # Audio preview with label
        st.markdown("#### Preview")
        mono = np.mean(audio, axis=1)
        audio_bytes = io.BytesIO()
        sf.write(audio_bytes, mono, sr, format='WAV')
        st.audio(audio_bytes.getvalue(), format='audio/wav')

        st.markdown("<br>", unsafe_allow_html=True)

        # Initialize session state for processing
        if 'processing' not in st.session_state:
            st.session_state.processing = False

        # Process button with proper state management
        col_btn, _ = st.columns([1, 2])
        with col_btn:
            analyze_clicked = st.button(
                "Analyze Audio" if not st.session_state.processing else "Processing...",
                type="primary",
                disabled=st.session_state.processing,
                use_container_width=True
            )

        # A failure is reported on the run *after* it happened, so the button
        # above has already been re-rendered in its enabled state.
        pending_error = st.session_state.pop('processing_error', None)
        if pending_error:
            st.error(pending_error)

        if analyze_clicked and not st.session_state.processing:
            # Drop any earlier results first: the processing step below only
            # runs when none are cached, and the button would otherwise stay
            # disabled with nothing happening.
            _discard_results()
            st.session_state.processing = True
            st.rerun()

        # Show processing UI
        if st.session_state.processing and 'results' not in st.session_state:
            st.markdown(
                "#### Processing Audio\n\n"
                "Separating sources and analyzing speakers..."
            )
            progress_bar = st.progress(0)
            status_text = st.empty()

            def update_progress(value, text):
                progress_bar.progress(value)
                status_text.markdown(text)

            try:
                results = process_audio(
                    tmp_path,
                    approach=selected_approach,
                    hf_token=hf_token,
                    progress_callback=update_progress,
                )
            except Exception as e:
                st.session_state.processing = False
                st.session_state['processing_error'] = f"Error processing audio: {e}"
                st.rerun()

            st.session_state['results'] = results
            st.session_state['results_source'] = upload_key
            st.session_state.processing = False
            st.rerun()

        # Display results
        if 'results' in st.session_state:
            _render_results(st.session_state['results'], selected_approach)
    finally:
        # The upload is re-materialised on every script run, so the temporary
        # copy is always removed, including on early return and st.rerun().
        if os.path.exists(tmp_path):
            os.unlink(tmp_path)
# Run the app only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()