""" Intelligent Hearing Aid - Audio Source Separation Interface Oticon-inspired clean design with proper UX feedback. """ import streamlit as st import numpy as np import plotly.graph_objects as go from plotly.subplots import make_subplots import tempfile import os import io import zipfile import json import soundfile as sf import librosa # Page configuration st.set_page_config( page_title="Audio Source Separator | Oticon Audio Explorers 2026", page_icon="🎧", layout="wide", initial_sidebar_state="collapsed" ) st.markdown(""" """, unsafe_allow_html=True) def create_speaker_radar(sources_info: list, selected_idx: int) -> go.Figure: """Create a clean polar chart showing speaker positions.""" fig = go.Figure() # Modern color palette with magenta accent colors = ['#6366f1', '#f59e0b', '#10b981', '#8b5cf6'] # indigo, amber, emerald, violet selected_color = '#9a1b5a' # Oticon magenta # Head circle theta_head = np.linspace(0, 360, 100) fig.add_trace(go.Scatterpolar( r=[0.2] * 100, theta=theta_head, mode='lines', line=dict(color='#d1d5db', width=2), fill='toself', fillcolor='#f9fafb', hoverinfo='skip', showlegend=False )) # Plot speakers for i, info in enumerate(sources_info): is_selected = i == selected_idx direction = info.get('direction_deg') if direction is None: direction = 0.0 color = selected_color if is_selected else colors[i % len(colors)] gender = info.get('gender') or 'unknown' symbol = 'diamond' if gender == 'male' else 'circle' hover_text = ( f"Speaker {i+1}
" f"Direction: {direction:.0f}°
" f"Gender: {gender}
" f"Language: {(info.get('language') or '?').upper()}" ) fig.add_trace(go.Scatterpolar( r=[0.75], theta=[direction], mode='markers+text', marker=dict( size=30 if is_selected else 24, color=color, symbol=symbol, line=dict(color='white', width=3) ), text=[str(i+1)], textposition='middle center', textfont=dict(color='white', size=12, family='Arial'), name=f"Speaker {i+1}", hovertemplate=hover_text + "" )) fig.add_trace(go.Scatterpolar( r=[0.2, 0.75], theta=[direction, direction], mode='lines', line=dict(color=color, width=2 if is_selected else 1, dash='solid' if is_selected else 'dot'), hoverinfo='skip', showlegend=False )) fig.update_layout( polar=dict( radialaxis=dict(visible=False, range=[0, 1]), angularaxis=dict( tickmode='array', tickvals=[0, 90, 180, 270], ticktext=['Front', 'Right', 'Back', 'Left'], tickfont=dict(size=12, color='#6c757d'), direction='clockwise', rotation=90, gridcolor='#e9ecef', linecolor='#d1d5db' ), bgcolor='white' ), showlegend=False, paper_bgcolor='white', plot_bgcolor='white', margin=dict(l=60, r=60, t=40, b=40), height=380 ) return fig def create_waveform_plot(audio: np.ndarray, sr: int, color: str = '#9a1b5a') -> go.Figure: """Create a minimal waveform visualization.""" max_points = 3000 if len(audio) > max_points: step = len(audio) // max_points audio_plot = audio[::step] time = np.arange(len(audio_plot)) * step / sr else: audio_plot = audio time = np.arange(len(audio)) / sr # Convert hex to rgba for fill (Plotly doesn't support 8-digit hex) if color.startswith('#') and len(color) == 7: r = int(color[1:3], 16) g = int(color[3:5], 16) b = int(color[5:7], 16) fill_color = f'rgba({r},{g},{b},0.15)' else: fill_color = 'rgba(154,27,90,0.15)' # fallback magenta fig = go.Figure() fig.add_trace(go.Scatter( x=time, y=audio_plot, mode='lines', line=dict(color=color, width=1), fill='tozeroy', fillcolor=fill_color, hovertemplate='Time: %{x:.2f}s
Amplitude: %{y:.3f}' )) fig.update_layout( xaxis=dict( title=dict(text='Time (s)', font=dict(size=11, color='#6c757d')), gridcolor='#f0f2f5', tickfont=dict(color='#6c757d', size=10), zeroline=False ), yaxis=dict( title=dict(text='Amplitude', font=dict(size=11, color='#6c757d')), gridcolor='#f0f2f5', tickfont=dict(color='#6c757d', size=10), zeroline=True, zerolinecolor='#e9ecef' ), paper_bgcolor='white', plot_bgcolor='white', margin=dict(l=50, r=20, t=20, b=40), height=120 ) return fig def create_spectrogram(audio: np.ndarray, sr: int) -> go.Figure: """Create a clean spectrogram visualization.""" n_fft = 2048 hop_length = 512 D = librosa.stft(audio, n_fft=n_fft, hop_length=hop_length) S_db = librosa.amplitude_to_db(np.abs(D), ref=np.max) times = librosa.times_like(S_db, sr=sr, hop_length=hop_length) freqs = librosa.fft_frequencies(sr=sr, n_fft=n_fft) freq_mask = freqs <= 8000 S_db = S_db[freq_mask, :] freqs = freqs[freq_mask] # Magenta-based colorscale colorscale = [ [0, '#f8f9fa'], [0.3, '#e9d5df'], [0.6, '#c77da2'], [0.8, '#9a1b5a'], [1, '#5a1035'] ] fig = go.Figure(data=go.Heatmap( z=S_db, x=times, y=freqs, colorscale=colorscale, showscale=False, hovertemplate='Time: %{x:.2f}s
Freq: %{y:.0f}Hz
Power: %{z:.1f}dB' )) fig.update_layout( xaxis=dict( title=dict(text='Time (s)', font=dict(size=11, color='#6c757d')), tickfont=dict(color='#6c757d', size=10) ), yaxis=dict( title=dict(text='Frequency (Hz)', font=dict(size=11, color='#6c757d')), tickfont=dict(color='#6c757d', size=10) ), paper_bgcolor='white', plot_bgcolor='white', margin=dict(l=60, r=20, t=20, b=40), height=150 ) return fig def create_comparison_bars(sources_info: list, selected_idx: int) -> go.Figure: """Create a clean bar comparison chart.""" n = len(sources_info) speakers = [f"S{i+1}" for i in range(n)] colors = ['#9a1b5a' if i == selected_idx else '#d1d5db' for i in range(n)] fig = make_subplots( rows=1, cols=3, subplot_titles=('Selection Score', 'Pitch (Hz)', 'Energy'), horizontal_spacing=0.12 ) scores = [info.get('selection_score', 0) for info in sources_info] f0s = [info.get('f0_hz') or info.get('mean_f0_hz') or 0 for info in sources_info] energies = [info.get('energy', 0) * 100 for info in sources_info] for col, data in enumerate([(scores, 'Score'), (f0s, 'Hz'), (energies, 'Energy')], 1): fig.add_trace(go.Bar( x=speakers, y=data[0], marker_color=colors, showlegend=False, hovertemplate=f'Speaker %{{x}}
{data[1]}: %{{y:.1f}}' ), row=1, col=col) fig.update_layout( paper_bgcolor='white', plot_bgcolor='white', height=220, margin=dict(l=40, r=40, t=50, b=30), font=dict(color='#1a3a5c') ) fig.update_xaxes(tickfont=dict(color='#6c757d', size=10), gridcolor='#f0f2f5') fig.update_yaxes(tickfont=dict(color='#6c757d', size=10), gridcolor='#f0f2f5') for annotation in fig['layout']['annotations']: annotation['font'] = dict(color='#1a1a2e', size=12) return fig def create_timeline(sources_info: list, duration: float, selected_idx: int) -> go.Figure: """Create a simple audio timeline.""" fig = go.Figure() colors = ['#6366f1', '#f59e0b', '#10b981', '#8b5cf6'] # indigo, amber, emerald, violet for i, info in enumerate(sources_info): is_selected = i == selected_idx color = '#9a1b5a' if is_selected else colors[i % len(colors)] language = (info.get('language') or '?').upper() gender = (info.get('gender') or '?') fig.add_trace(go.Bar( x=[duration], y=[f"Speaker {i+1}"], orientation='h', marker=dict(color=color, opacity=1 if is_selected else 0.7), text=[f"{language} · {gender[0].upper()}"], textposition='inside', textfont=dict(color='white', size=11), hovertemplate=f"Speaker {i+1}
Duration: {duration:.1f}s", showlegend=False )) fig.update_layout( xaxis=dict( title=dict(text='Time (s)', font=dict(size=11, color='#6c757d')), range=[0, duration], gridcolor='#f0f2f5', tickfont=dict(color='#6c757d', size=10) ), yaxis=dict(tickfont=dict(color='#1a1a2e', size=11), gridcolor='#f0f2f5'), barmode='stack', paper_bgcolor='white', plot_bgcolor='white', height=180, margin=dict(l=100, r=20, t=20, b=40) ) return fig def process_audio( audio_path: str, approach: str = "ica", whisper_model: str = "small", hf_token: str | None = None, progress_callback=None, ) -> dict: """Process audio through the separation pipeline with progress updates.""" from approaches import get_approach output_dir = tempfile.mkdtemp() if progress_callback: progress_callback(0.05, "Loading audio file...") approach_class = get_approach(approach) pipeline = approach_class() if progress_callback: progress_callback(0.15, "Processing audio and separating sources...") # Run selected approach pipeline run_kwargs = { "input_file": audio_path, "output_dir": output_dir, "whisper_model": whisper_model, } if approach == "ica_deeplearning" and hf_token: run_kwargs["hf_token"] = hf_token pipeline_output = pipeline.run(**run_kwargs) results = pipeline_output.to_dict() if hasattr(pipeline_output, "to_dict") else dict(pipeline_output) if progress_callback: progress_callback(0.9, "Finalizing results...") results['output_dir'] = output_dir results['sources_audio'] = [] for i in range(results['n_speakers']): source_path = os.path.join(output_dir, f"source_{i+1}.wav") audio, _ = sf.read(source_path) results['sources_audio'].append(audio) original_audio, input_sr = sf.read(audio_path, always_2d=True) results['original_audio'] = original_audio[:, 0] results['sr'] = input_sr if progress_callback: progress_callback(1.0, "Complete!") return results def create_download_zip(results: dict) -> bytes: """Create ZIP with all outputs.""" zip_buffer = io.BytesIO() with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zf: output_dir = results['output_dir'] for i in range(results['n_speakers']): source_path = os.path.join(output_dir, f"source_{i+1}.wav") if os.path.exists(source_path): zf.write(source_path, f"speaker_{i+1}.wav") output_path = os.path.join(output_dir, "output.wav") if os.path.exists(output_path): zf.write(output_path, "selected_speaker.wav") results_json = {k: v for k, v in results.items() if k not in ['output_dir', 'sources_audio', 'original_audio', 'sr']} zf.writestr("results.json", json.dumps(results_json, indent=2)) return zip_buffer.getvalue() def get_direction_label(direction: float) -> str: """Convert direction to human-readable label.""" if direction < 22.5 or direction > 337.5: return "Front" elif direction < 67.5: return "Front-Right" elif direction < 112.5: return "Right" elif direction < 157.5: return "Back-Right" elif direction < 202.5: return "Back" elif direction < 247.5: return "Back-Left" elif direction < 292.5: return "Left" else: return "Front-Left" def main(): """Main application.""" # Header st.markdown("""

OTICON Audio Explorers 2026

Audio Source Separator

Separate and analyze individual speakers from multi-channel hearing aid recordings

""", unsafe_allow_html=True) st.markdown("
", unsafe_allow_html=True) st.markdown("### Separation Approach") approach_options = ["ica", "frankenstein", "ica_deeplearning"] selected_approach = st.selectbox( "Choose approach", options=approach_options, index=0, format_func=lambda x: x.replace("_", "+").upper(), help="Select which pipeline variant to run. Default is ICA." ) hf_token = None if selected_approach == "ica_deeplearning": hf_token_input = st.text_input( "Hugging Face Token (optional)", type="password", help="Needed only if your ICA+DeepLearning run uses Pyannote diarization.", placeholder="hf_..." ) hf_token = hf_token_input.strip() or None # File upload section with clear label st.markdown("""

Upload Recording

Select a 4-channel WAV file from your hearing aid microphone array

""", unsafe_allow_html=True) uploaded_file = st.file_uploader( "Choose audio file", type=['wav'], help="4-channel WAV format (Left Front, Left Rear, Right Front, Right Rear)", label_visibility="collapsed" ) if uploaded_file is not None: with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as tmp: tmp.write(uploaded_file.read()) tmp_path = tmp.name try: audio, sr = sf.read(tmp_path, always_2d=True) n_channels = audio.shape[1] duration = len(audio) / sr if n_channels != 4: st.error(f"Expected 4 channels, got {n_channels}. Please upload a valid hearing aid recording.") return st.markdown("
", unsafe_allow_html=True) # File info cards with proper labels st.markdown("### Recording Details") col1, col2, col3, col4 = st.columns(4) with col1: st.markdown(f"""

Duration

{duration:.1f}sec

""", unsafe_allow_html=True) with col2: st.markdown(f"""

Sample Rate

{sr//1000}kHz

""", unsafe_allow_html=True) with col3: st.markdown(f"""

Channels

{n_channels}

""", unsafe_allow_html=True) with col4: st.markdown(f"""

File Size

{uploaded_file.size / (1024*1024):.1f}MB

""", unsafe_allow_html=True) # Audio preview with label st.markdown("#### Preview") mono = np.mean(audio, axis=1) audio_bytes = io.BytesIO() sf.write(audio_bytes, mono, sr, format='WAV') st.audio(audio_bytes.getvalue(), format='audio/wav') st.markdown("
", unsafe_allow_html=True) # Initialize session state for processing if 'processing' not in st.session_state: st.session_state.processing = False # Process button with proper state management col_btn, col_space = st.columns([1, 2]) with col_btn: analyze_clicked = st.button( "Analyze Audio" if not st.session_state.processing else "Processing...", type="primary", disabled=st.session_state.processing, use_container_width=True ) if analyze_clicked and not st.session_state.processing: st.session_state.processing = True st.rerun() # Show processing UI if st.session_state.processing and 'results' not in st.session_state: st.markdown("""

Processing Audio

Separating sources and analyzing speakers...

""", unsafe_allow_html=True) progress_bar = st.progress(0) status_text = st.empty() def update_progress(value, text): progress_bar.progress(value) status_text.markdown(f"

{text}

", unsafe_allow_html=True) try: results = process_audio( tmp_path, approach=selected_approach, hf_token=hf_token, progress_callback=update_progress, ) st.session_state['results'] = results st.session_state.processing = False st.rerun() except Exception as e: st.session_state.processing = False st.error(f"Error processing audio: {str(e)}") return # Display results if 'results' in st.session_state: results = st.session_state['results'] sources_info = results['sources'] selected_idx = results['talker_of_interest'] - 1 st.divider() st.markdown("## Analysis Results") st.caption(f"Approach: {(results.get('approach') or selected_approach).replace('_', '+').upper()}") # Two column layout col_left, col_right = st.columns([1, 1]) with col_left: st.markdown("### Speaker Positions") st.markdown("

Spatial location of detected speakers relative to the listener

", unsafe_allow_html=True) radar_fig = create_speaker_radar(sources_info, selected_idx) st.plotly_chart(radar_fig, use_container_width=True) with col_right: st.markdown("#### Speaker Comparison") st.markdown("

Key metrics used for target speaker selection

", unsafe_allow_html=True) comparison_fig = create_comparison_bars(sources_info, selected_idx) st.plotly_chart(comparison_fig, use_container_width=True) st.markdown("#### Activity Timeline") st.markdown("

Speaker presence throughout the recording

", unsafe_allow_html=True) timeline_fig = create_timeline(sources_info, results['duration_seconds'], selected_idx) st.plotly_chart(timeline_fig, use_container_width=True) st.divider() st.markdown("## Separated Speakers") st.markdown("

Individual audio streams extracted from the recording

", unsafe_allow_html=True) # Speaker colors - matching radar colors = ['#6366f1', '#f59e0b', '#10b981', '#8b5cf6'] for i, info in enumerate(sources_info): is_selected = i == selected_idx color = '#9a1b5a' if is_selected else colors[i % len(colors)] # Speaker card card_class = "speaker-card selected" if is_selected else "speaker-card" badge = 'TARGET' if is_selected else '' st.markdown(f"""

Speaker {i+1}{badge}

""", unsafe_allow_html=True) # Metrics c1, c2, c3, c4 = st.columns(4) direction = info.get('direction_deg') if direction is None: c1.metric("Direction", "N/A") else: c1.metric("Direction", f"{direction:.0f}° ({get_direction_label(direction)})") c2.metric("Gender", (info.get('gender') or 'unknown').title()) c3.metric("Language", (info.get('language') or '?').upper()) score = info.get('selection_score') c4.metric("Score", f"{score:.1f}" if score is not None else "N/A") # Audio + download col_audio, col_dl = st.columns([4, 1]) source_path = os.path.join(results['output_dir'], f"source_{i+1}.wav") with col_audio: if os.path.exists(source_path): st.audio(source_path, format='audio/wav') with col_dl: if os.path.exists(source_path): with open(source_path, 'rb') as f: st.download_button( "Download", data=f.read(), file_name=f"speaker_{i+1}.wav", mime="audio/wav", key=f"dl_{i}" ) # Transcription transcription = info.get('transcription') or info.get('transcript') or '' if transcription: with st.expander("View Transcription"): st.write(f"

{transcription}

", unsafe_allow_html=True) # Waveform if i < len(results.get('sources_audio', [])): with st.expander("View Waveform & Spectrogram"): wf = create_waveform_plot(results['sources_audio'][i], results['sr'], color) st.plotly_chart(wf, use_container_width=True) spec = create_spectrogram(results['sources_audio'][i], results['sr']) st.plotly_chart(spec, use_container_width=True) # Download section st.divider() st.markdown("## Export") st.markdown("

Download separated audio files and analysis data

", unsafe_allow_html=True) c1, c2, c3 = st.columns(3) with c1: zip_data = create_download_zip(results) st.download_button( "Download All (ZIP)", data=zip_data, file_name="separated_audio.zip", mime="application/zip", use_container_width=True ) with c2: output_path = os.path.join(results['output_dir'], "output.wav") if os.path.exists(output_path): with open(output_path, 'rb') as f: st.download_button( "Download Target Speaker", data=f.read(), file_name="target_speaker.wav", mime="audio/wav", use_container_width=True ) with c3: results_json = {k: v for k, v in results.items() if k not in ['output_dir', 'sources_audio', 'original_audio', 'sr']} st.download_button( "Download Analysis (JSON)", data=json.dumps(results_json, indent=2), file_name="analysis.json", mime="application/json", use_container_width=True ) # Raw JSON with st.expander("View Raw Analysis Data"): display_results = {k: v for k, v in results.items() if k not in ["input_file", 'output_dir', 'sources_audio', 'original_audio', 'sr']} st.json(display_results) # Reset button st.markdown("
", unsafe_allow_html=True) if st.button("Analyze Another Recording"): for key in ['results', 'processing']: if key in st.session_state: del st.session_state[key] st.rerun() finally: if os.path.exists(tmp_path): os.unlink(tmp_path) if __name__ == "__main__": main()