import streamlit as st
import os
import torch
import numpy as np
import pandas as pd
import plotly.express as px
import scipy.io.wavfile
import io

BASE_PATH = 'Models'

st.set_page_config(layout="wide", page_title="Audio Source Separation Inspector")


def process_audio(file_path, gain_factor):
    try:
        # 1. FIX: Check if the file is actually a Git LFS pointer (text file)
        with open(file_path, 'rb') as f:
            header = f.read(50)
        if header.startswith(b'version https://git-lfs'):
            st.error(f"❌ **LFS Error:** `{os.path.basename(file_path)}` is a Git LFS pointer, not a WAV file. Run `git lfs pull` in your terminal.")
            return None

        sample_rate, data = scipy.io.wavfile.read(file_path)

        # Normalize integer PCM to float32 in [-1, 1] (32768 = int16 full scale).
        if data.dtype == np.int16:
            data = data.astype(np.float32) / 32768.0
        elif data.dtype == np.int32:
            data = data.astype(np.float32) / 2147483648.0

        # Apply gain, then clip before re-quantizing to avoid integer wrap-around.
        data = data * gain_factor
        data = np.clip(data, -1.0, 1.0)
        data = (data * 32767).astype(np.int16)

        virtual_file = io.BytesIO()
        scipy.io.wavfile.write(virtual_file, sample_rate, data)
        virtual_file.seek(0)  # rewind so downstream readers start at byte 0
        return virtual_file
    except Exception as e:
        st.error(f"Error processing audio: {e}")
        return file_path  # fall back to the raw file so playback still works
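
# A minimal sketch of an alternative to the manual gain slider (hypothetical,
# not wired into the UI below): scale to a target peak instead of applying a
# fixed boost. Assumes float32 audio in [-1, 1], as produced in process_audio.
def peak_normalize(data, target_peak=0.95):
    peak = float(np.max(np.abs(data)))
    if peak == 0.0:
        return data  # silent clip: nothing to scale
    return data * (target_peak / peak)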
def get_subdirs(path):
    if not os.path.exists(path):
        return []
    return [d for d in os.listdir(path) if os.path.isdir(os.path.join(path, d))]


def load_spectrogram_interactive(pt_path, title="Spectrogram"):
    try:
        # 2. FIX: Added weights_only=False to fix PyTorch 2.6+ error
        spec_tensor = torch.load(pt_path, map_location='cpu', weights_only=False)
        if spec_tensor.dim() == 4:
            spec_tensor = spec_tensor[0]  # drop the batch dimension
        if spec_tensor.dim() == 3:
            spec_data = spec_tensor.mean(dim=0).numpy()  # average over channels
        else:
            spec_data = spec_tensor.numpy()

        # Magnitude spectrograms are non-negative; log1p compresses dynamic range.
        if spec_data.min() >= 0:
            spec_data = np.log1p(spec_data)

        fig = px.imshow(
            spec_data,
            origin='lower',
            aspect='auto',
            color_continuous_scale='Viridis',
            labels=dict(x="Time Frame", y="Frequency Bin", color="Log Magnitude"),
            title=title
        )
        fig.update_layout(margin=dict(l=0, r=0, t=30, b=0), height=300)
        return fig
    except Exception as e:
        st.error(f"Error loading spectrogram: {e}")
        return None


def load_feature_map_interactive(pt_path):
    try:
        # 3. FIX: Added weights_only=False here as well
        feat_tensor = torch.load(pt_path, map_location='cpu', weights_only=False)
        if feat_tensor.dim() == 4:
            feat_tensor = feat_tensor[0]  # drop the batch dimension
        mean_activation = feat_tensor.mean(dim=0).numpy()  # average over channels
        fig = px.imshow(
            mean_activation,
            origin='lower',
            aspect='auto',
            color_continuous_scale='Viridis',
            labels=dict(x="Time", y="Freq/Feature", color="Activation"),
            title=f"Mean Activation (Shape: {list(feat_tensor.shape)})"
        )
        fig.update_layout(margin=dict(l=0, r=0, t=40, b=0))
        return fig
    except Exception:
        return None
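
# Sketch of a safer loading pattern (assumption: the .pt artifacts are plain
# tensors). weights_only=False runs the full pickle machinery, so one option
# is to try the restricted weights-only unpickler first and fall back only
# when it rejects the file. Not wired into the loaders above.
def load_tensor_safely(pt_path):
    try:
        return torch.load(pt_path, map_location='cpu', weights_only=True)
    except Exception:
        return torch.load(pt_path, map_location='cpu', weights_only=False)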
st.title("🎵 Audio Source Separation Inspector")

st.markdown("""
### Model Interpretation Guide
This tool helps you evaluate how well the model separates audio sources.

* **Audio Quality:** Listen for "artifacts" (robotic sounds or clicking) in the Prediction compared to the Target.
* **Spectrogram Clarity:** In the visuals below, distinct horizontal lines represent clear tones. Vertical smear usually indicates percussion or noise.
* **Error Analysis:** If the Prediction looks "blurry" compared to the Target, the model is losing high-frequency detail.
""")

if not os.path.exists(BASE_PATH):
    st.error(f"Models directory not found at {BASE_PATH}. Please ensure your data was uploaded correctly.")
    st.stop()

models = get_subdirs(BASE_PATH)
selected_model = st.sidebar.selectbox("Select Model", models)

st.sidebar.markdown("### Audio Settings")
volume_boost = st.sidebar.slider(
    "Volume Boost (Gain)",
    min_value=1.0,
    max_value=20.0,
    value=1.0,
    step=0.5,
    help="Digitally increases the amplitude of the audio signal."
)

if selected_model:
    model_path = os.path.join(BASE_PATH, selected_model)
    artifacts_path = os.path.join(model_path, "test_artifacts")

    if os.path.exists(artifacts_path):
        samples = get_subdirs(artifacts_path)
        # Sort by trailing numeric ID when present (e.g. "sample_12"); non-numeric names sort first.
        samples.sort(key=lambda x: int(x.split('_')[-1]) if x.split('_')[-1].isdigit() else 0)
        selected_sample = st.sidebar.selectbox("Select Sample ID", samples)

        if selected_sample:
            sample_path = os.path.join(artifacts_path, selected_sample)
            audio_dir = os.path.join(sample_path, "audio")
            specs_dir = os.path.join(sample_path, "specs")
            feats_dir = os.path.join(sample_path, "feats")

            all_files = os.listdir(audio_dir)
            target_files = [f for f in all_files if f.startswith("target_") and f.endswith(".wav")]
            classes = [f.replace("target_", "").replace(".wav", "") for f in target_files]
            selected_class = st.sidebar.selectbox("Focus Class", classes)

            tab1, tab2, tab3 = st.tabs(["🎧 Audio & Spectrograms", "🧠 Internal Activations", "📊 Model Metadata"])

            with tab1:
                st.header(f"Sample {selected_sample} | Focus: {selected_class.capitalize()}")

                st.subheader("1. Mixture (Input)")
                st.markdown("The raw input containing all sound sources mixed together.")
                mix_audio = os.path.join(audio_dir, "mixture.wav")
                mix_spec = os.path.join(specs_dir, "mixture.pt")
                c1, c2 = st.columns([1, 3])
                with c1:
                    if os.path.exists(mix_audio):
                        st.markdown("**Audio:**")
                        processed_mix = process_audio(mix_audio, volume_boost)
                        if processed_mix:
                            st.audio(processed_mix, format='audio/wav')
                with c2:
                    if os.path.exists(mix_spec):
                        fig = load_spectrogram_interactive(mix_spec, title="Mixture Mel-Spectrogram")
                        if fig:
                            st.plotly_chart(fig, width='stretch')

                st.divider()

                st.subheader(f"2. Target: {selected_class}")
                st.markdown("**Interpretation:** This is the ground truth. Study the spectrogram structure here—this is the ideal output.")
                tgt_audio = os.path.join(audio_dir, f"target_{selected_class}.wav")
                tgt_spec = os.path.join(specs_dir, f"target_{selected_class}.pt")
                c1, c2 = st.columns([1, 3])
                with c1:
                    if os.path.exists(tgt_audio):
                        st.markdown("**Audio:**")
                        processed_tgt = process_audio(tgt_audio, volume_boost)
                        if processed_tgt:
                            st.audio(processed_tgt, format='audio/wav')
                with c2:
                    if os.path.exists(tgt_spec):
                        fig = load_spectrogram_interactive(tgt_spec, title=f"Target Mel-Spectrogram ({selected_class})")
                        if fig:
                            st.plotly_chart(fig, width='stretch')

                st.divider()

                st.subheader(f"3. Prediction: {selected_class}")
                st.markdown("**Interpretation:** Compare this to the Target above. If you see 'fuzziness' in the dark areas, the model is not silencing background noise correctly.")
                pred_audio = os.path.join(audio_dir, f"pred_{selected_class}.wav")
                pred_spec = os.path.join(specs_dir, f"pred_{selected_class}.pt")
                c1, c2 = st.columns([1, 3])
                with c1:
                    if os.path.exists(pred_audio):
                        st.markdown("**Audio:**")
                        processed_pred = process_audio(pred_audio, volume_boost)
                        if processed_pred:
                            st.audio(processed_pred, format='audio/wav')
                with c2:
                    if os.path.exists(pred_spec):
                        fig = load_spectrogram_interactive(pred_spec, title=f"Predicted Mel-Spectrogram ({selected_class})")
                        if fig:
                            st.plotly_chart(fig, width='stretch')

            with tab2:
                st.header("Internal Feature Maps")
                st.markdown("These heatmaps visualize the neural network's internal state. Bright spots indicate features the model considers important for separation.")
                if os.path.exists(feats_dir):
                    feat_files = sorted(os.listdir(feats_dir))
                    if feat_files:
                        selected_layer = st.selectbox("Select Probed Layer", feat_files)
                        if selected_layer:
                            st.write(f"Layer: **{selected_layer.replace('.pt', '')}**")
                            fig = load_feature_map_interactive(os.path.join(feats_dir, selected_layer))
                            if fig:
                                st.plotly_chart(fig, width='stretch')
                    else:
                        st.warning("No feature maps found for this sample.")
                else:
                    st.error("Features directory not found.")

            with tab3:
                st.header("Training and Testing Logs")
                st.markdown("Use these graphs to check for **overfitting**: if the training loss keeps decreasing while the test metrics stagnate or drop, the model is memorizing data rather than learning general features.")
                c1, c2 = st.columns(2)
                with c1:
                    results_csv = os.path.join(model_path, "test_results.csv")
                    if os.path.exists(results_csv):
                        st.subheader("Test Metrics")
                        df = pd.read_csv(results_csv)
                        x_axis = 'Batch_Index' if 'Batch_Index' in df.columns else df.index
                        numeric_cols = df.select_dtypes(include=np.number).columns
                        fig = px.line(df, title="Test Metrics", x=x_axis, y=numeric_cols)
                        st.plotly_chart(fig, width='stretch')
                        st.dataframe(df, width='stretch')
                    else:
                        st.info("No `test_results.csv` found.")
                with c2:
                    loss_csv = os.path.join(model_path, "loss.csv")
                    if os.path.exists(loss_csv):
                        st.subheader("Training Loss")
                        try:
                            df_loss = pd.read_csv(loss_csv)
                            x_axis = 'epoch' if 'epoch' in df_loss.columns else df_loss.index
                            numeric_cols = df_loss.select_dtypes(include=np.number).columns
                            fig = px.line(df_loss, x=x_axis, y=numeric_cols, title="Loss Curves")
                            st.plotly_chart(fig, width='stretch')
                            st.dataframe(df_loss, width='stretch')
                        except Exception as e:
                            st.write("Could not parse `loss.csv`.", e)
                    else:
                        st.info("No `loss.csv` found.")
    else:
        st.warning(f"No 'test_artifacts' folder found in {selected_model}")
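
# Optional caching sketch (assumption: artifact files do not change between
# reruns). Streamlit re-executes this script on every widget interaction, so
# each rerun re-reads tensors from disk; a memoized loader like this one could
# be swapped into load_spectrogram_interactive / load_feature_map_interactive.
@st.cache_data
def load_tensor_cached(pt_path):
    return torch.load(pt_path, map_location='cpu', weights_only=False)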