| """ |
| Intelligent Hearing Aid - Audio Source Separation Interface |
| |
| Oticon-inspired clean design with proper UX feedback. |
| """ |
|
|
import html
import io
import json
import math
import os
import shutil
import tempfile
import zipfile

import librosa
import numpy as np
import plotly.graph_objects as go
import soundfile as sf
import streamlit as st
from plotly.subplots import make_subplots
|
|
| |
# Page-level settings (browser tab title/icon, wide layout, sidebar collapsed
# on load). Issued before any other Streamlit call in this file.
_PAGE_SETTINGS = {
    "page_title": "Audio Source Separator | Oticon Audio Explorers 2026",
    "page_icon": "🎧",
    "layout": "wide",
    "initial_sidebar_state": "collapsed",
}
st.set_page_config(**_PAGE_SETTINGS)
|
|
# Global CSS theme, injected at module level: light page background, magenta
# (#9a1b5a) accents for buttons, expanders and the progress bar, a forced light
# look for select boxes / text inputs / JSON viewers, and the custom classes
# (.info-card, .speaker-card, .speaker-badge, .speaker-header,
# .progress-section) referenced by the HTML snippets rendered in main().
st.markdown("""
<style>
    /* Base - clean light background */
    .stApp {
        background: #f8f9fa;
    }

    /* Headers - dark, professional */
    h1, h2, h3:not(.speaker-header) {
        color: #1a1a2e !important;
        font-weight: 600 !important;
    }

    .speaker-header {
        font-weight: 600 !important;
    }

    /* ALL buttons - consistent magenta style */
    .stButton > button,
    .stDownloadButton > button {
        background: #9a1b5a !important;
        color: white !important;
        border: none !important;
        border-radius: 8px !important;
        font-weight: 500 !important;
        padding: 0.6rem 1.2rem !important;
        transition: background 0.2s ease !important;
    }

    .stButton > button:hover,
    .stDownloadButton > button:hover {
        background: #7d1649 !important;
        color: white !important;
        border: none !important;
    }

    .stButton > button:focus,
    .stDownloadButton > button:focus {
        background: #9a1b5a !important;
        color: white !important;
        box-shadow: 0 0 0 2px rgba(154, 27, 90, 0.3) !important;
    }

    .stButton > button:active,
    .stDownloadButton > button:active {
        background: #6d1340 !important;
        color: white !important;
    }

    .stButton > button:disabled {
        background: #d4a5bb !important;
        color: white !important;
        cursor: not-allowed !important;
    }

    /* Selectbox - match magenta branding */
    [data-testid="stSelectbox"] label {
        color: #1a1a2e !important;
        font-weight: 500 !important;
    }

    [data-testid="stSelectbox"] [data-baseweb="select"] > div {
        border: 1px solid #d1d5db !important;
        border-radius: 8px !important;
        background: white !important;
        transition: border-color 0.2s ease, box-shadow 0.2s ease !important;
    }

    [data-testid="stSelectbox"] [data-baseweb="select"] > div:hover {
        border-color: #9a1b5a !important;
    }

    [data-testid="stSelectbox"] [data-baseweb="select"] > div:focus-within {
        border-color: #9a1b5a !important;
        box-shadow: 0 0 0 2px rgba(154, 27, 90, 0.25) !important;
    }

    [data-testid="stSelectbox"] [data-baseweb="select"] span,
    [data-testid="stSelectbox"] [data-baseweb="select"] input {
        color: #1a1a2e !important;
    }

    [data-testid="stSelectbox"] [data-baseweb="select"] svg {
        fill: #9a1b5a !important;
    }

    /* Ensure selected value text is visible */
    [data-testid="stSelectbox"] [data-baseweb="select"] > div > div,
    [data-testid="stSelectbox"] [data-baseweb="select"] > div > div > div {
        color: #1a1a2e !important;
    }

    /* Dropdown menu styling (rendered in portal/popover) */
    div[data-baseweb="popover"] [role="listbox"] {
        background: white !important;
        border: 1px solid #e9ecef !important;
        border-radius: 8px !important;
        box-shadow: 0 8px 24px rgba(0, 0, 0, 0.08) !important;
    }

    div[data-baseweb="popover"] [role="option"] {
        background: white !important;
        color: #1a1a2e !important;
    }

    div[data-baseweb="popover"] [role="option"]:hover {
        background: #fdf5f8 !important;
        color: #7d1649 !important;
    }

    div[data-baseweb="popover"] [role="option"][aria-selected="true"] {
        background: #f4d7e4 !important;
        color: #7d1649 !important;
        font-weight: 600 !important;
    }

    /* Text input (HF token) - match branding */
    [data-testid="stTextInput"] label {
        color: #1a1a2e !important;
        font-weight: 500 !important;
    }

    [data-testid="stTextInput"] input {
        background: white !important;
        color: #1a1a2e !important;
        border: 1px solid #d1d5db !important;
        border-radius: 8px !important;
    }

    [data-testid="stTextInput"] input::placeholder {
        color: #9ca3af !important;
        opacity: 1 !important;
    }

    [data-testid="stTextInput"] input:hover {
        border-color: #9a1b5a !important;
    }

    [data-testid="stTextInput"] input:focus {
        border-color: #9a1b5a !important;
        box-shadow: 0 0 0 2px rgba(154, 27, 90, 0.25) !important;
        outline: none !important;
    }

    /* Cards */
    .info-card {
        background: white;
        border: 1px solid #e9ecef;
        border-radius: 12px;
        padding: 24px;
        margin-bottom: 16px;
    }

    .info-card h4 {
        color: #6c757d;
        margin: 0 0 8px 0;
        font-size: 0.8rem;
        font-weight: 500;
        text-transform: uppercase;
        letter-spacing: 0.5px;
    }

    .info-card .value {
        color: #1a1a2e;
        font-size: 1.8rem;
        font-weight: 600;
        margin: 0;
    }

    .info-card .unit {
        color: #6c757d;
        font-size: 0.9rem;
        margin-left: 4px;
    }

    /* Speaker cards */
    .speaker-card {
        background: white;
        border: 1px solid #e9ecef;
        border-radius: 12px;
        padding: 20px;
        margin: 12px 0;
    }

    .speaker-card.selected {
        border: 2px solid #9a1b5a;
        background: #fdf5f8;
    }

    .speaker-badge {
        display: inline-block;
        background: #9a1b5a;
        color: white;
        font-size: 0.75rem;
        font-weight: 600;
        padding: 4px 10px;
        border-radius: 12px;
        margin-left: 8px;
    }

    /* Progress section */
    .progress-section {
        background: white;
        border: 1px solid #e9ecef;
        border-radius: 12px;
        padding: 32px;
        text-align: center;
    }

    /* Metrics */
    [data-testid="stMetricLabel"] {
        color: #6c757d !important;
        font-size: 0.85rem !important;
        text-transform: uppercase !important;
        letter-spacing: 0.5px !important;
    }

    [data-testid="stMetricValue"] {
        color: #1a1a2e !important;
        font-weight: 600 !important;
    }

    /* Dividers */
    hr {
        border-color: #e9ecef !important;
    }

    /* Hide Streamlit branding */
    #MainMenu {visibility: hidden;}
    footer {visibility: hidden;}

    /* EXPANDERS - Magenta header with white text */
    [data-testid="stExpander"] {
        border: none !important;
        border-radius: 8px !important;
        overflow: hidden !important;
    }

    [data-testid="stExpander"] > details {
        border: none !important;
    }

    [data-testid="stExpander"] > details > summary {
        background: #9a1b5a !important;
        color: white !important;
        border-radius: 8px !important;
        padding: 12px 16px !important;
        font-weight: 500 !important;
    }

    [data-testid="stExpander"] > details > summary:hover {
        background: #7d1649 !important;
    }

    [data-testid="stExpander"] > details > summary span,
    [data-testid="stExpander"] > details > summary p {
        color: white !important;
    }

    [data-testid="stExpander"] > details > summary svg {
        fill: white !important;
        stroke: white !important;
    }

    /* Expander content - white background */
    [data-testid="stExpander"] > details > div {
        background: white !important;
        border: 1px solid #e9ecef !important;
        border-top: none !important;
        border-radius: 0 0 8px 8px !important;
        padding: 16px !important;
    }

    /* Progress bar - magenta */
    .stProgress > div > div > div {
        background: #9a1b5a !important;
    }

    /* File uploader */
    [data-testid="stFileUploader"] {
        background: white;
        border: 2px dashed #d1d5db;
        border-radius: 12px;
        padding: 20px;
    }

    [data-testid="stFileUploader"]:hover {
        border-color: #9a1b5a;
    }

    /* Audio player styling */
    audio {
        border-radius: 8px;
    }

    /* JSON viewer - force light theme */
    [data-testid="stJson"],
    .stJson {
        background: #f8f9fa !important;
        border-radius: 8px !important;
        padding: 16px !important;
    }

    [data-testid="stJson"] *,
    .stJson * {
        background: transparent !important;
    }

    /* JSON text colors - make keys visible */
    [data-testid="stJson"] span {
        color: #1a1a2e !important;
    }

    /* Override any dark theme in JSON */
    pre, code {
        background: #f8f9fa !important;
        color: #1a1a2e !important;
    }

    /* Muted text class */
    .text-muted {
        color: #6c757d !important;
    }
</style>
""", unsafe_allow_html=True)
|
|
|
|
def create_speaker_radar(sources_info: list, selected_idx: int) -> go.Figure:
    """Build the polar "radar" figure locating each speaker around the listener.

    A small filled ring at the centre stands for the listener's head. Every
    entry of ``sources_info`` becomes a numbered marker at a fixed radius,
    joined to the head by a spoke. The entry at ``selected_idx`` is drawn
    larger, in magenta, with a solid (instead of dotted) spoke.

    Args:
        sources_info: Per-speaker dicts. Reads the optional keys
            ``direction_deg``, ``gender`` and ``language``; a missing or
            ``None`` direction is plotted at 0 degrees.
        selected_idx: Zero-based index of the speaker to highlight.

    Returns:
        The configured Plotly figure.
    """
    palette = ['#6366f1', '#f59e0b', '#10b981', '#8b5cf6']
    highlight = '#9a1b5a'

    radar = go.Figure()

    # Listener's head: a closed ring of 100 points at radius 0.2.
    radar.add_trace(go.Scatterpolar(
        r=[0.2] * 100,
        theta=np.linspace(0, 360, 100),
        mode='lines',
        line=dict(color='#d1d5db', width=2),
        fill='toself',
        fillcolor='#f9fafb',
        hoverinfo='skip',
        showlegend=False
    ))

    for idx, speaker in enumerate(sources_info):
        number = idx + 1
        chosen = idx == selected_idx

        bearing = speaker.get('direction_deg')
        bearing = 0.0 if bearing is None else bearing
        tint = highlight if chosen else palette[idx % len(palette)]

        # Diamonds mark speakers tagged 'male'; everything else is a circle.
        gender = speaker.get('gender') or 'unknown'
        shape = 'diamond' if gender == 'male' else 'circle'

        language = (speaker.get('language') or '?').upper()
        tooltip = "<br>".join([
            f"<b>Speaker {number}</b>",
            f"Direction: {bearing:.0f}°",
            f"Gender: {gender}",
            f"Language: {language}",
        ])

        # Numbered marker for the speaker itself.
        radar.add_trace(go.Scatterpolar(
            r=[0.75],
            theta=[bearing],
            mode='markers+text',
            marker=dict(
                size=30 if chosen else 24,
                color=tint,
                symbol=shape,
                line=dict(color='white', width=3)
            ),
            text=[str(number)],
            textposition='middle center',
            textfont=dict(color='white', size=12, family='Arial'),
            name=f"Speaker {number}",
            hovertemplate=tooltip + "<extra></extra>"
        ))

        # Spoke from the edge of the head out to the marker.
        spoke_style = dict(
            color=tint,
            width=2 if chosen else 1,
            dash='solid' if chosen else 'dot'
        )
        radar.add_trace(go.Scatterpolar(
            r=[0.2, 0.75],
            theta=[bearing, bearing],
            mode='lines',
            line=spoke_style,
            hoverinfo='skip',
            showlegend=False
        ))

    # Compass-style axis: 0 degrees at the top ("Front"), angles increase clockwise.
    compass_axis = dict(
        tickmode='array',
        tickvals=[0, 90, 180, 270],
        ticktext=['Front', 'Right', 'Back', 'Left'],
        tickfont=dict(size=12, color='#6c757d'),
        direction='clockwise',
        rotation=90,
        gridcolor='#e9ecef',
        linecolor='#d1d5db'
    )
    radar.update_layout(
        polar=dict(
            radialaxis=dict(visible=False, range=[0, 1]),
            angularaxis=compass_axis,
            bgcolor='white'
        ),
        showlegend=False,
        paper_bgcolor='white',
        plot_bgcolor='white',
        margin=dict(l=60, r=60, t=40, b=40),
        height=380
    )

    return radar
|
|
|
|
def create_waveform_plot(audio: np.ndarray, sr: int, color: str = '#9a1b5a') -> go.Figure:
    """Render a compact amplitude-over-time plot of ``audio``.

    Signals longer than 3000 samples are decimated by plain striding (every
    ``len(audio) // 3000``-th sample is kept) before plotting.

    Args:
        audio: Sample sequence to plot.
        sr: Sample rate used to convert sample indices to seconds.
        color: Line colour. A 7-character ``#rrggbb`` value is also used, at
            15% opacity, for the area fill; any other format falls back to a
            magenta fill.

    Returns:
        The configured Plotly figure.
    """
    point_budget = 3000
    total = len(audio)

    if total > point_budget:
        stride = total // point_budget
        samples = audio[::stride]
        seconds = np.arange(len(samples)) * stride / sr
    else:
        samples = audio
        seconds = np.arange(total) / sr

    # Derive a translucent fill from the line colour when it is plain hex.
    is_plain_hex = color.startswith('#') and len(color) == 7
    if is_plain_hex:
        red, green, blue = (int(color[pos:pos + 2], 16) for pos in (1, 3, 5))
        area_fill = f'rgba({red},{green},{blue},0.15)'
    else:
        area_fill = 'rgba(154,27,90,0.15)'

    muted_title = dict(size=11, color='#6c757d')
    muted_ticks = dict(color='#6c757d', size=10)

    waveform = go.Figure()
    waveform.add_trace(go.Scatter(
        x=seconds, y=samples,
        mode='lines',
        line=dict(color=color, width=1),
        fill='tozeroy',
        fillcolor=area_fill,
        hovertemplate='Time: %{x:.2f}s<br>Amplitude: %{y:.3f}<extra></extra>'
    ))

    waveform.update_layout(
        xaxis=dict(
            title=dict(text='Time (s)', font=muted_title),
            gridcolor='#f0f2f5',
            tickfont=muted_ticks,
            zeroline=False
        ),
        yaxis=dict(
            title=dict(text='Amplitude', font=muted_title),
            gridcolor='#f0f2f5',
            tickfont=muted_ticks,
            zeroline=True,
            zerolinecolor='#e9ecef'
        ),
        paper_bgcolor='white',
        plot_bgcolor='white',
        margin=dict(l=50, r=20, t=20, b=40),
        height=120
    )

    return waveform
|
|
|
|
def create_spectrogram(audio: np.ndarray, sr: int) -> go.Figure:
    """Render a magnitude spectrogram of ``audio`` as a heatmap.

    The STFT uses a 2048-point FFT with a hop of 512 samples. Magnitudes are
    converted to dB relative to the loudest bin, and only bins at or below
    8 kHz are shown.

    Args:
        audio: Signal passed to ``librosa.stft``.
        sr: Sample rate, used for the time and frequency axes.

    Returns:
        The configured Plotly figure.
    """
    fft_size, hop = 2048, 512
    max_freq_hz = 8000

    magnitude = np.abs(librosa.stft(audio, n_fft=fft_size, hop_length=hop))
    level_db = librosa.amplitude_to_db(magnitude, ref=np.max)

    frame_times = librosa.times_like(level_db, sr=sr, hop_length=hop)
    bin_freqs = librosa.fft_frequencies(sr=sr, n_fft=fft_size)

    # Crop to the displayed band.
    keep = bin_freqs <= max_freq_hz
    level_db = level_db[keep, :]
    bin_freqs = bin_freqs[keep]

    # Light grey (quiet) through brand magenta to dark plum (loud).
    magenta_ramp = [
        [0, '#f8f9fa'],
        [0.3, '#e9d5df'],
        [0.6, '#c77da2'],
        [0.8, '#9a1b5a'],
        [1, '#5a1035']
    ]

    heatmap = go.Heatmap(
        z=level_db, x=frame_times, y=bin_freqs,
        colorscale=magenta_ramp,
        showscale=False,
        hovertemplate='Time: %{x:.2f}s<br>Freq: %{y:.0f}Hz<br>Power: %{z:.1f}dB<extra></extra>'
    )
    spectrogram = go.Figure(data=heatmap)

    muted_title = dict(size=11, color='#6c757d')
    muted_ticks = dict(color='#6c757d', size=10)
    spectrogram.update_layout(
        xaxis=dict(
            title=dict(text='Time (s)', font=muted_title),
            tickfont=muted_ticks
        ),
        yaxis=dict(
            title=dict(text='Frequency (Hz)', font=muted_title),
            tickfont=muted_ticks
        ),
        paper_bgcolor='white',
        plot_bgcolor='white',
        margin=dict(l=60, r=20, t=20, b=40),
        height=150
    )

    return spectrogram
|
|
|
|
def create_comparison_bars(sources_info: list, selected_idx: int) -> go.Figure:
    """Create a three-panel bar chart comparing the detected speakers.

    The panels show, per speaker, the selection score, the pitch in Hz and the
    energy (multiplied by 100 for display). The bar belonging to
    ``selected_idx`` is magenta; all others are grey.

    Args:
        sources_info: Per-speaker dicts. Reads the optional keys
            ``selection_score``, ``f0_hz`` (falling back to ``mean_f0_hz``)
            and ``energy``. A key that is missing *or* present with a ``None``
            value is plotted as 0.
        selected_idx: Zero-based index of the highlighted speaker.

    Returns:
        The configured Plotly figure.
    """
    n = len(sources_info)
    speakers = [f"S{i+1}" for i in range(n)]
    colors = ['#9a1b5a' if i == selected_idx else '#d1d5db' for i in range(n)]

    fig = make_subplots(
        rows=1, cols=3,
        subplot_titles=('Selection Score', 'Pitch (Hz)', 'Energy'),
        horizontal_spacing=0.12
    )

    # `or 0` (rather than a .get() default) also covers keys that are present
    # but None, which would otherwise break the `* 100` below.
    scores = [info.get('selection_score') or 0 for info in sources_info]
    f0s = [info.get('f0_hz') or info.get('mean_f0_hz') or 0 for info in sources_info]
    energies = [(info.get('energy') or 0) * 100 for info in sources_info]

    panels = ((scores, 'Score'), (f0s, 'Hz'), (energies, 'Energy'))
    for col, (values, label) in enumerate(panels, start=1):
        fig.add_trace(go.Bar(
            x=speakers, y=values,
            marker_color=colors,
            showlegend=False,
            hovertemplate=f'Speaker %{{x}}<br>{label}: %{{y:.1f}}<extra></extra>'
        ), row=1, col=col)

    fig.update_layout(
        paper_bgcolor='white',
        plot_bgcolor='white',
        height=220,
        margin=dict(l=40, r=40, t=50, b=30),
        font=dict(color='#1a3a5c')
    )

    fig.update_xaxes(tickfont=dict(color='#6c757d', size=10), gridcolor='#f0f2f5')
    fig.update_yaxes(tickfont=dict(color='#6c757d', size=10), gridcolor='#f0f2f5')

    # Subplot titles are stored as layout annotations; restyle them here.
    for annotation in fig['layout']['annotations']:
        annotation['font'] = dict(color='#1a1a2e', size=12)

    return fig
|
|
|
|
def create_timeline(sources_info: list, duration: float, selected_idx: int) -> go.Figure:
    """Draw one full-length horizontal bar per speaker.

    Every bar spans the whole ``duration`` and is labelled with the speaker's
    upper-cased language and the first letter of their gender ("?" when a
    value is missing).

    Args:
        sources_info: Per-speaker dicts; reads the optional keys ``language``
            and ``gender``.
        duration: Bar length and x-axis extent, in seconds.
        selected_idx: Zero-based index of the speaker drawn in opaque magenta.

    Returns:
        The configured Plotly figure.
    """
    palette = ['#6366f1', '#f59e0b', '#10b981', '#8b5cf6']
    muted_ticks = dict(color='#6c757d', size=10)

    timeline = go.Figure()

    for number, speaker in enumerate(sources_info, start=1):
        position = number - 1
        chosen = position == selected_idx
        tint = '#9a1b5a' if chosen else palette[position % len(palette)]

        lang_code = (speaker.get('language') or '?').upper()
        gender_initial = (speaker.get('gender') or '?')[0].upper()

        bar = go.Bar(
            x=[duration],
            y=[f"Speaker {number}"],
            orientation='h',
            marker=dict(color=tint, opacity=1 if chosen else 0.7),
            text=[f"{lang_code} · {gender_initial}"],
            textposition='inside',
            textfont=dict(color='white', size=11),
            hovertemplate=f"Speaker {number}<br>Duration: {duration:.1f}s<extra></extra>",
            showlegend=False
        )
        timeline.add_trace(bar)

    time_axis = dict(
        title=dict(text='Time (s)', font=dict(size=11, color='#6c757d')),
        range=[0, duration],
        gridcolor='#f0f2f5',
        tickfont=muted_ticks
    )
    speaker_axis = dict(tickfont=dict(color='#1a1a2e', size=11), gridcolor='#f0f2f5')

    timeline.update_layout(
        xaxis=time_axis,
        yaxis=speaker_axis,
        barmode='stack',
        paper_bgcolor='white',
        plot_bgcolor='white',
        height=180,
        margin=dict(l=100, r=20, t=20, b=40)
    )

    return timeline
|
|
|
|
def process_audio(
    audio_path: str,
    approach: str = "ica",
    whisper_model: str = "small",
    hf_token: str | None = None,
    progress_callback=None,
) -> dict:
    """Run the selected separation pipeline on ``audio_path`` and collect results.

    A fresh temporary directory receives the pipeline output. On success that
    directory is kept (its path is returned under ``'output_dir'``) so the
    caller can serve the generated files; the caller owns its cleanup. If
    anything fails, the directory is removed before the exception propagates.

    Args:
        audio_path: Path of the input recording.
        approach: Name handed to ``approaches.get_approach``.
        whisper_model: Forwarded to the pipeline's ``run`` as ``whisper_model``.
        hf_token: Forwarded as ``hf_token`` only when ``approach`` is
            ``"ica_deeplearning"`` and the token is non-empty.
        progress_callback: Optional ``callable(fraction, message)`` invoked at
            fixed milestones (0.05, 0.15, 0.9, 1.0).

    Returns:
        The pipeline's result dict extended with ``'output_dir'``,
        ``'sources_audio'`` (one array per ``source_<n>.wav``),
        ``'original_audio'`` (first channel of the input) and ``'sr'``
        (sample rate of the input file).

    Raises:
        Whatever the pipeline or ``soundfile`` raises; nothing is swallowed.
    """
    from approaches import get_approach

    def report(fraction, message):
        if progress_callback:
            progress_callback(fraction, message)

    output_dir = tempfile.mkdtemp()
    try:
        report(0.05, "Loading audio file...")

        approach_class = get_approach(approach)
        pipeline = approach_class()

        report(0.15, "Processing audio and separating sources...")

        run_kwargs = {
            "input_file": audio_path,
            "output_dir": output_dir,
            "whisper_model": whisper_model,
        }
        if approach == "ica_deeplearning" and hf_token:
            run_kwargs["hf_token"] = hf_token

        pipeline_output = pipeline.run(**run_kwargs)
        results = pipeline_output.to_dict() if hasattr(pipeline_output, "to_dict") else dict(pipeline_output)

        report(0.9, "Finalizing results...")

        results['output_dir'] = output_dir
        results['sources_audio'] = []

        # A missing source file is deliberately an error: skipping it would
        # shift the remaining entries out of step with results['sources'].
        for i in range(results['n_speakers']):
            source_path = os.path.join(output_dir, f"source_{i+1}.wav")
            audio, _ = sf.read(source_path)
            results['sources_audio'].append(audio)

        original_audio, input_sr = sf.read(audio_path, always_2d=True)
        results['original_audio'] = original_audio[:, 0]
        # NOTE(review): this is the INPUT file's rate; the rate of the
        # separated source files is discarded above. Confirm the pipeline
        # never resamples, since callers plot the sources using this value.
        results['sr'] = input_sr

        report(1.0, "Complete!")
    except BaseException:
        # Nobody else holds a reference to the directory yet, so it would be
        # orphaned. BaseException so interrupts clean up too; always re-raised.
        shutil.rmtree(output_dir, ignore_errors=True)
        raise

    return results
|
|
|
|
def create_download_zip(results: dict) -> bytes:
    """Bundle the separated audio and the analysis metadata into a ZIP archive.

    Archive contents, in order:

    * ``speaker_<n>.wav`` for each ``source_<n>.wav`` found in the output
      directory (``n`` from 1 to ``results['n_speakers']``; missing files are
      skipped),
    * ``selected_speaker.wav`` if ``output.wav`` exists,
    * ``results.json`` holding ``results`` minus the in-memory/bookkeeping
      keys ``output_dir``, ``sources_audio``, ``original_audio`` and ``sr``.

    Args:
        results: Result dict; must contain ``'output_dir'`` and
            ``'n_speakers'``.

    Returns:
        The raw bytes of the deflate-compressed archive.
    """
    excluded = {'output_dir', 'sources_audio', 'original_audio', 'sr'}
    base_dir = results['output_dir']

    # (path on disk, name inside the archive)
    members = [
        (os.path.join(base_dir, f"source_{number}.wav"), f"speaker_{number}.wav")
        for number in range(1, results['n_speakers'] + 1)
    ]
    members.append((os.path.join(base_dir, "output.wav"), "selected_speaker.wav"))

    metadata = {key: value for key, value in results.items() if key not in excluded}

    buffer = io.BytesIO()
    with zipfile.ZipFile(buffer, 'w', zipfile.ZIP_DEFLATED) as archive:
        for disk_path, archive_name in members:
            if os.path.exists(disk_path):
                archive.write(disk_path, archive_name)
        archive.writestr("results.json", json.dumps(metadata, indent=2))

    return buffer.getvalue()
|
|
|
|
# Compass sectors in clockwise order starting at the front; each is 45 degrees
# wide and centred on a multiple of 45.
_DIRECTION_LABELS = (
    "Front", "Front-Right", "Right", "Back-Right",
    "Back", "Back-Left", "Left", "Front-Left",
)


def get_direction_label(direction: float) -> str:
    """Convert a direction in degrees to a human-readable compass label.

    0 degrees is "Front" and angles increase clockwise (90 is "Right"). The
    angle is first wrapped into ``[0, 360)``, so negative values and values of
    360 or more are labelled by their equivalent angle (``-90`` is "Left",
    ``450`` is "Right"). Sectors are half-open: a boundary value such as 22.5
    belongs to the sector that starts there ("Front-Right").

    Args:
        direction: Angle in degrees; any finite number.

    Returns:
        One of the eight sector labels, or ``"Unknown"`` when ``direction`` is
        NaN or infinite.
    """
    angle = direction % 360.0
    if math.isnan(angle):  # NaN input, or +/-inf (inf % 360 is NaN)
        return "Unknown"

    # Shift by half a sector so that "Front" (337.5..22.5) maps to index 0.
    # The final modulo folds the wrap-around sector (and an angle that float
    # rounding pushed to exactly 360.0) back onto "Front".
    sector = int((angle + 22.5) // 45.0) % len(_DIRECTION_LABELS)
    return _DIRECTION_LABELS[sector]
|
|
|
|
# Keys of the results dict that hold bookkeeping or in-memory audio and are
# therefore left out of JSON exports.
_EXPORT_EXCLUDED_KEYS = ('output_dir', 'sources_audio', 'original_audio', 'sr')


def _format_approach(name: str) -> str:
    """Return the display label of an approach id (``ica_deeplearning`` -> ``ICA+DEEPLEARNING``)."""
    return name.replace("_", "+").upper()


def _discard_results() -> None:
    """Forget cached results and delete the temporary directory backing them."""
    if 'results' not in st.session_state:
        return
    stale = st.session_state['results']
    del st.session_state['results']
    output_dir = stale.get('output_dir')
    if output_dir:
        shutil.rmtree(output_dir, ignore_errors=True)


def _render_header() -> None:
    """Render the centred page title block."""
    st.markdown("""
    <div style="text-align: center; padding: 40px 0 20px 0;">
        <p style="color: #9a1b5a; font-size: 0.9rem; font-weight: 600; letter-spacing: 1px; margin-bottom: 8px;">
            OTICON Audio Explorers 2026
        </p>
        <h1 style="font-size: 2.4rem; margin: 0 0 12px 0; color: #1a1a2e;">
            Audio Source Separator
        </h1>
        <p style="color: #6c757d; font-size: 1.05rem; max-width: 500px; margin: 0 auto;">
            Separate and analyze individual speakers from multi-channel hearing aid recordings
        </p>
    </div>
    """, unsafe_allow_html=True)


def _render_info_card(container, title: str, value: str, unit: str = "") -> None:
    """Render one ``.info-card`` (small caption, large value, optional unit) into ``container``."""
    unit_html = f'<span class="unit">{unit}</span>' if unit else ''
    container.markdown(f"""
    <div class="info-card">
        <h4>{title}</h4>
        <p class="value">{value}{unit_html}</p>
    </div>
    """, unsafe_allow_html=True)


def _render_recording_details(audio: np.ndarray, sr: int, size_bytes: int) -> None:
    """Render the four summary cards and a mono preview player for the upload."""
    n_channels = audio.shape[1]
    duration = len(audio) / sr

    st.markdown("### Recording Details")

    cards = [
        ("Duration", f"{duration:.1f}", "sec"),
        # ':g' keeps fractional rates readable (44100 -> "44.1") without
        # adding a trailing ".0" to round ones (16000 -> "16").
        ("Sample Rate", f"{sr / 1000:g}", "kHz"),
        ("Channels", f"{n_channels}", ""),
        ("File Size", f"{size_bytes / (1024*1024):.1f}", "MB"),
    ]
    for column, (title, value, unit) in zip(st.columns(4), cards):
        _render_info_card(column, title, value, unit)

    # The preview is the channel average, re-encoded as an in-memory WAV.
    st.markdown("#### Preview")
    mono = np.mean(audio, axis=1)
    audio_bytes = io.BytesIO()
    sf.write(audio_bytes, mono, sr, format='WAV')
    st.audio(audio_bytes.getvalue(), format='audio/wav')


def _run_analysis(audio_path: str, approach: str, hf_token: str | None) -> bool:
    """Run the pipeline behind a progress bar and cache the outcome.

    On success the results are stored in ``st.session_state['results']``. In
    both cases the ``processing`` flag is cleared.

    Returns:
        ``True`` on success; ``False`` after an error has been shown to the user.
    """
    st.markdown("""
    <div class="progress-section">
        <h3 style="margin: 0 0 16px 0;">Processing Audio</h3>
        <p style="color: #6c757d; margin-bottom: 20px;">
            Separating sources and analyzing speakers...
        </p>
    </div>
    """, unsafe_allow_html=True)

    progress_bar = st.progress(0)
    status_text = st.empty()

    def update_progress(value, text):
        progress_bar.progress(value)
        status_text.markdown(f"<p style='text-align: center; color: #6c757d;'>{text}</p>", unsafe_allow_html=True)

    try:
        results = process_audio(
            audio_path,
            approach=approach,
            hf_token=hf_token,
            progress_callback=update_progress,
        )
    except Exception as e:
        # UI boundary: report any pipeline failure instead of crashing the page.
        st.session_state.processing = False
        st.error(f"Error processing audio: {str(e)}")
        return False

    st.session_state['results'] = results
    st.session_state.processing = False
    return True


def _render_speaker(index: int, info: dict, results: dict, is_selected: bool) -> None:
    """Render one speaker: header card, metrics, player/download, transcript and plots."""
    colors = ['#6366f1', '#f59e0b', '#10b981', '#8b5cf6']
    color = '#9a1b5a' if is_selected else colors[index % len(colors)]
    number = index + 1

    card_class = "speaker-card selected" if is_selected else "speaker-card"
    badge = '<span class="speaker-badge">TARGET</span>' if is_selected else ''

    st.markdown(f"""
    <div class="{card_class}">
        <h3 class="speaker-header" style="margin: 0; color: {color}; display: inline-flex; align-items: center;">
            Speaker {number}{badge}
        </h3>
    </div>
    """, unsafe_allow_html=True)

    c1, c2, c3, c4 = st.columns(4)
    direction = info.get('direction_deg')
    if direction is None:
        c1.metric("Direction", "N/A")
    else:
        c1.metric("Direction", f"{direction:.0f}° ({get_direction_label(direction)})")

    c2.metric("Gender", (info.get('gender') or 'unknown').title())
    c3.metric("Language", (info.get('language') or '?').upper())

    score = info.get('selection_score')
    c4.metric("Score", f"{score:.1f}" if score is not None else "N/A")

    col_audio, col_dl = st.columns([4, 1])
    source_path = os.path.join(results['output_dir'], f"source_{number}.wav")
    source_exists = os.path.exists(source_path)

    with col_audio:
        if source_exists:
            st.audio(source_path, format='audio/wav')

    with col_dl:
        if source_exists:
            with open(source_path, 'rb') as f:
                st.download_button(
                    "Download",
                    data=f.read(),
                    file_name=f"speaker_{number}.wav",
                    mime="audio/wav",
                    key=f"dl_{index}"
                )

    transcription = info.get('transcription') or info.get('transcript') or ''
    if transcription:
        with st.expander("View Transcription"):
            # The text originates from the uploaded audio and is rendered as
            # raw HTML, so it is escaped to prevent markup injection.
            safe_text = html.escape(str(transcription))
            st.write(f"<p style='color: #1a1a2e;'>{safe_text}</p>", unsafe_allow_html=True)

    if index < len(results.get('sources_audio', [])):
        with st.expander("View Waveform & Spectrogram"):
            wf = create_waveform_plot(results['sources_audio'][index], results['sr'], color)
            st.plotly_chart(wf, use_container_width=True)

            spec = create_spectrogram(results['sources_audio'][index], results['sr'])
            st.plotly_chart(spec, use_container_width=True)


def _render_export(results: dict) -> None:
    """Render the ZIP / target-speaker / JSON download buttons."""
    st.divider()
    st.markdown("## Export")
    st.markdown("<p style='color: #6c757d;'>Download separated audio files and analysis data</p>", unsafe_allow_html=True)

    c1, c2, c3 = st.columns(3)

    with c1:
        zip_data = create_download_zip(results)
        st.download_button(
            "Download All (ZIP)",
            data=zip_data,
            file_name="separated_audio.zip",
            mime="application/zip",
            use_container_width=True
        )

    with c2:
        output_path = os.path.join(results['output_dir'], "output.wav")
        if os.path.exists(output_path):
            with open(output_path, 'rb') as f:
                st.download_button(
                    "Download Target Speaker",
                    data=f.read(),
                    file_name="target_speaker.wav",
                    mime="audio/wav",
                    use_container_width=True
                )

    with c3:
        results_json = {k: v for k, v in results.items()
                        if k not in _EXPORT_EXCLUDED_KEYS}
        st.download_button(
            "Download Analysis (JSON)",
            data=json.dumps(results_json, indent=2),
            file_name="analysis.json",
            mime="application/json",
            use_container_width=True
        )


def _render_results(results: dict, fallback_approach: str) -> None:
    """Render the whole results area: overview charts, speakers, export, raw data, reset."""
    sources_info = results['sources']
    # 'talker_of_interest' is 1-based; the charts and lists below are 0-based.
    selected_idx = results['talker_of_interest'] - 1

    st.divider()
    st.markdown("## Analysis Results")
    st.caption(f"Approach: {_format_approach(results.get('approach') or fallback_approach)}")

    col_left, col_right = st.columns([1, 1])

    with col_left:
        st.markdown("### Speaker Positions")
        st.markdown("<p style='color: #6c757d; font-size: 0.9rem;'>Spatial location of detected speakers relative to the listener</p>", unsafe_allow_html=True)
        radar_fig = create_speaker_radar(sources_info, selected_idx)
        st.plotly_chart(radar_fig, use_container_width=True)

    with col_right:
        st.markdown("#### Speaker Comparison")
        st.markdown("<p style='color: #6c757d; font-size: 0.9rem;'>Key metrics used for target speaker selection</p>", unsafe_allow_html=True)
        comparison_fig = create_comparison_bars(sources_info, selected_idx)
        st.plotly_chart(comparison_fig, use_container_width=True)

        st.markdown("#### Activity Timeline")
        st.markdown("<p style='color: #6c757d; font-size: 0.9rem;'>Speaker presence throughout the recording</p>", unsafe_allow_html=True)
        timeline_fig = create_timeline(sources_info, results['duration_seconds'], selected_idx)
        st.plotly_chart(timeline_fig, use_container_width=True)

    st.divider()
    st.markdown("## Separated Speakers")
    st.markdown("<p style='color: #6c757d;'>Individual audio streams extracted from the recording</p>", unsafe_allow_html=True)

    for i, info in enumerate(sources_info):
        _render_speaker(i, info, results, is_selected=(i == selected_idx))

    _render_export(results)

    with st.expander("View Raw Analysis Data"):
        hidden = ("input_file",) + _EXPORT_EXCLUDED_KEYS
        display_results = {k: v for k, v in results.items() if k not in hidden}
        st.json(display_results)

    st.markdown("<br>", unsafe_allow_html=True)
    if st.button("Analyze Another Recording"):
        # Also removes the pipeline's temporary output directory, which would
        # otherwise be left behind on disk.
        _discard_results()
        if 'processing' in st.session_state:
            del st.session_state['processing']
        st.rerun()


def main():
    """Main application.

    Renders the page top to bottom on every Streamlit run:

    1. header, approach selector (plus an optional Hugging Face token field
       for ``ica_deeplearning``) and the file uploader;
    2. once a file is uploaded: validation (exactly 4 channels), recording
       details and the "Analyze Audio" button;
    3. while ``st.session_state.processing`` is set and no results are cached:
       the pipeline run, whose output is stored in
       ``st.session_state['results']``;
    4. when results are cached: charts, per-speaker sections and downloads.

    The upload is spooled to a temporary ``.wav`` file for the duration of the
    run and always deleted afterwards.
    """
    _render_header()

    st.markdown("<br>", unsafe_allow_html=True)

    st.markdown("### Separation Approach")
    approach_options = ["ica", "frankenstein", "ica_deeplearning"]
    selected_approach = st.selectbox(
        "Choose approach",
        options=approach_options,
        index=0,
        format_func=_format_approach,
        help="Select which pipeline variant to run. Default is ICA."
    )

    hf_token = None
    if selected_approach == "ica_deeplearning":
        hf_token_input = st.text_input(
            "Hugging Face Token (optional)",
            type="password",
            help="Needed only if your ICA+DeepLearning run uses Pyannote diarization.",
            placeholder="hf_..."
        )
        hf_token = hf_token_input.strip() or None

    st.markdown("""
    <div class="info-card">
        <h4>Upload Recording</h4>
        <p style="color: #6c757d; font-size: 0.9rem; margin: 0 0 16px 0;">
            Select a 4-channel WAV file from your hearing aid microphone array
        </p>
    </div>
    """, unsafe_allow_html=True)

    uploaded_file = st.file_uploader(
        "Choose audio file",
        type=['wav'],
        help="4-channel WAV format (Left Front, Left Rear, Right Front, Right Rear)",
        label_visibility="collapsed"
    )

    if uploaded_file is None:
        return

    with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as tmp:
        # getvalue() returns the full payload regardless of the stream position.
        tmp.write(uploaded_file.getvalue())
        tmp_path = tmp.name

    try:
        audio, sr = sf.read(tmp_path, always_2d=True)
        n_channels = audio.shape[1]

        if n_channels != 4:
            st.error(f"Expected 4 channels, got {n_channels}. Please upload a valid hearing aid recording.")
            return

        st.markdown("<br>", unsafe_allow_html=True)
        _render_recording_details(audio, sr, uploaded_file.size)
        st.markdown("<br>", unsafe_allow_html=True)

        if 'processing' not in st.session_state:
            st.session_state.processing = False

        col_btn, _ = st.columns([1, 2])
        with col_btn:
            analyze_clicked = st.button(
                "Analyze Audio" if not st.session_state.processing else "Processing...",
                type="primary",
                disabled=st.session_state.processing,
                use_container_width=True
            )

        if analyze_clicked and not st.session_state.processing:
            # Drop results of a previous analysis first: the processing branch
            # below only runs when none are cached, so keeping them would leave
            # the button stuck on "Processing..." without ever re-running.
            _discard_results()
            st.session_state.processing = True
            st.rerun()

        if st.session_state.processing and 'results' not in st.session_state:
            if not _run_analysis(tmp_path, selected_approach, hf_token):
                return
            # Re-render from the top so the button is re-enabled and the
            # progress widgets disappear.
            st.rerun()

        if 'results' in st.session_state:
            _render_results(st.session_state['results'], selected_approach)
    finally:
        if os.path.exists(tmp_path):
            os.unlink(tmp_path)
|
|
|
|
# Build the page only when this file is executed as a script, not when it is
# imported as a module.
if __name__ == "__main__":
    main()
|
|