import io
import os

import numpy as np
import pandas as pd
import plotly.express as px
import scipy.io.wavfile
import streamlit as st
import torch

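# Root folder that contains one sub-directory per trained model run.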
BASE_PATH = 'Models'

st.set_page_config(layout="wide", page_title="Audio Source Separation Inspector")


def process_audio(file_path, gain_factor):
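    """Read a WAV file, apply a gain factor, and return it as an in-memory WAV.

    Returns a BytesIO buffer on success, None if the file is an un-pulled
    Git LFS pointer, or the original path as a last-resort fallback on error.
    """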
    try:
        # Refuse to play Git LFS pointer files that were never pulled.
        with open(file_path, 'rb') as f:
            header = f.read(50)
        if header.startswith(b'version https://git-lfs'):
            st.error(f"❌ **LFS Error:** `{os.path.basename(file_path)}` is a Git LFS pointer, not a WAV file. Run `git lfs pull` in your terminal.")
            return None

        sample_rate, data = scipy.io.wavfile.read(file_path)

        # Convert integer PCM to float32 in [-1.0, 1.0].
        if data.dtype == np.int16:
            data = data.astype(np.float32) / 32768.0
        elif data.dtype == np.int32:
            data = data.astype(np.float32) / 2147483648.0

        # Apply the gain, then clip to avoid wrap-around distortion.
        data = data * gain_factor
        data = np.clip(data, -1.0, 1.0)
        data = (data * 32767).astype(np.int16)

        # Write the boosted audio to an in-memory WAV file.
        virtual_file = io.BytesIO()
        scipy.io.wavfile.write(virtual_file, sample_rate, data)
        virtual_file.seek(0)  # Rewind so the buffer can be read from the start.
        return virtual_file
    except Exception as e:
        st.error(f"Error processing audio: {e}")
        # Fall back to the raw file so playback still works.
        return file_path


def get_subdirs(path):
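    """Return the names of the immediate sub-directories of path, or [] if it does not exist."""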
    if not os.path.exists(path):
        return []
    return [d for d in os.listdir(path) if os.path.isdir(os.path.join(path, d))]


def load_spectrogram_interactive(pt_path, title="Spectrogram"):
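    """Load a saved spectrogram tensor and return an interactive Plotly heatmap (None on error)."""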
    try:
        spec_tensor = torch.load(pt_path, map_location='cpu', weights_only=False)

        # Collapse batch and channel dimensions down to a 2-D (freq x time) array.
        if spec_tensor.dim() == 4:
            spec_tensor = spec_tensor[0]
        if spec_tensor.dim() == 3:
            spec_data = spec_tensor.mean(dim=0).numpy()
        else:
            spec_data = spec_tensor.numpy()

        # Apply log compression if the values are still linear magnitudes.
        if spec_data.min() >= 0:
            spec_data = np.log1p(spec_data)

        fig = px.imshow(
            spec_data,
            origin='lower',
            aspect='auto',
            color_continuous_scale='Viridis',
            labels=dict(x="Time Frame", y="Frequency Bin", color="Log Magnitude"),
            title=title
        )
        fig.update_layout(margin=dict(l=0, r=0, t=30, b=0), height=300)
        return fig
    except Exception as e:
        st.error(f"Error loading spectrogram: {e}")
        return None


def load_feature_map_interactive(pt_path):
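    """Load a saved feature-map tensor and plot its channel-averaged activation (None on error)."""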
    try:
        feat_tensor = torch.load(pt_path, map_location='cpu', weights_only=False)

        # Drop the batch dimension if present.
        if feat_tensor.dim() == 4:
            feat_tensor = feat_tensor[0]

        # Average across channels to get a single 2-D activation map.
        mean_activation = feat_tensor.mean(dim=0).numpy()

        fig = px.imshow(
            mean_activation,
            origin='lower',
            aspect='auto',
            color_continuous_scale='Viridis',
            labels=dict(x="Time", y="Freq/Feature", color="Activation"),
            title=f"Mean Activation (Shape: {list(feat_tensor.shape)})"
        )
        fig.update_layout(margin=dict(l=0, r=0, t=40, b=0))
        return fig
    except Exception:
        return None


st.title("🎵 Audio Source Separation Inspector")

st.markdown("""
### Model Interpretation Guide
This tool helps you evaluate how well the model separates audio sources.
* **Audio Quality:** Listen for "artifacts" (robotic sounds or clicking) in the Prediction compared to the Target.
* **Spectrogram Clarity:** In the visuals below, distinct horizontal lines represent clear tones. Vertical smear usually indicates percussion or noise.
* **Error Analysis:** If the Prediction looks "blurry" compared to the Target, the model is losing high-frequency details.
""")

if not os.path.exists(BASE_PATH):
    st.error(f"Models directory not found at {BASE_PATH}. Please ensure your data was uploaded correctly.")
    st.stop()

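# Sidebar: pick which trained model to inspect.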
models = get_subdirs(BASE_PATH)
selected_model = st.sidebar.selectbox("Select Model", models)

st.sidebar.markdown("### Audio Settings")
volume_boost = st.sidebar.slider(
    "Volume Boost (Gain)",
    min_value=1.0,
    max_value=20.0,
    value=1.0,
    step=0.5,
    help="Digitally increases the amplitude of the audio signal."
)

if selected_model:
    model_path = os.path.join(BASE_PATH, selected_model)
    artifacts_path = os.path.join(model_path, "test_artifacts")

    if os.path.exists(artifacts_path):
        samples = get_subdirs(artifacts_path)
        # Sort numerically by trailing index so "sample_10" follows "sample_2".
        samples.sort(key=lambda x: int(x.split('_')[-1]) if '_' in x else 0)

        selected_sample = st.sidebar.selectbox("Select Sample ID", samples)

        if selected_sample:
            sample_path = os.path.join(artifacts_path, selected_sample)
            audio_dir = os.path.join(sample_path, "audio")
            specs_dir = os.path.join(sample_path, "specs")
            feats_dir = os.path.join(sample_path, "feats")

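            # Derive the available source classes from the target_*.wav filenames.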
            all_files = os.listdir(audio_dir)
            target_files = [f for f in all_files if f.startswith("target_") and f.endswith(".wav")]
            classes = [f.replace("target_", "").replace(".wav", "") for f in target_files]

            selected_class = st.sidebar.selectbox("Focus Class", classes)

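            # One tab each for listening/spectrograms, internal activations, and training logs.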
            tab1, tab2, tab3 = st.tabs(["🎧 Audio & Spectrograms", "🧠 Internal Activations", "📊 Model Metadata"])

            with tab1:
                st.header(f"Sample {selected_sample} | Focus: {selected_class.capitalize()}")

                st.subheader("1. Mixture (Input)")
                st.markdown("The raw input containing all sound sources mixed together.")
                mix_audio = os.path.join(audio_dir, "mixture.wav")
                mix_spec = os.path.join(specs_dir, "mixture.pt")

                c1, c2 = st.columns([1, 3])
                with c1:
                    if os.path.exists(mix_audio):
                        st.markdown("**Audio:**")
                        processed_mix = process_audio(mix_audio, volume_boost)
                        if processed_mix:
                            st.audio(processed_mix, format='audio/wav')
                with c2:
                    if os.path.exists(mix_spec):
                        fig = load_spectrogram_interactive(mix_spec, title="Mixture Mel-Spectrogram")
                        if fig:
                            st.plotly_chart(fig, width='stretch')

                st.divider()

st.subheader(f"2. Target: {selected_class}") |
|
|
st.markdown(f"**Interpretation:** This is the 'Ground Truth'. Look at the spectrogram structure here—this is the ideal output.") |
|
|
tgt_audio = os.path.join(audio_dir, f"target_{selected_class}.wav") |
|
|
tgt_spec = os.path.join(specs_dir, f"target_{selected_class}.pt") |
|
|
|
|
|
c1, c2 = st.columns([1, 3]) |
|
|
with c1: |
|
|
if os.path.exists(tgt_audio): |
|
|
st.markdown("**Audio:**") |
|
|
processed_tgt = process_audio(tgt_audio, volume_boost) |
|
|
if processed_tgt: |
|
|
st.audio(processed_tgt, format='audio/wav') |
|
|
with c2: |
|
|
if os.path.exists(tgt_spec): |
|
|
fig = load_spectrogram_interactive(tgt_spec, title=f"Target Mel-Spectrogram ({selected_class})") |
|
|
if fig: st.plotly_chart(fig, width='stretch') |
|
|
|
|
|
st.divider() |
|
|
|
|
|
st.subheader(f"3. Prediction: {selected_class}") |
|
|
st.markdown(f"**Interpretation:** Compare this to the Target above. If you see 'fuzziness' in the dark areas, the model is not silencing background noise correctly.") |
|
|
pred_audio = os.path.join(audio_dir, f"pred_{selected_class}.wav") |
|
|
pred_spec = os.path.join(specs_dir, f"pred_{selected_class}.pt") |
|
|
|
|
|
c1, c2 = st.columns([1, 3]) |
|
|
with c1: |
|
|
if os.path.exists(pred_audio): |
|
|
st.markdown("**Audio:**") |
|
|
processed_pred = process_audio(pred_audio, volume_boost) |
|
|
if processed_pred: |
|
|
st.audio(processed_pred, format='audio/wav') |
|
|
with c2: |
|
|
if os.path.exists(pred_spec): |
|
|
fig = load_spectrogram_interactive(pred_spec, title=f"Predicted Mel-Spectrogram ({selected_class})") |
|
|
if fig: st.plotly_chart(fig, width='stretch') |
|
|
|
|
|
            with tab2:
                st.header("Internal Feature Maps")

                st.markdown("These heatmaps visualize the neural network's internal state. Bright spots indicate features the model considers important for separation.")

                if os.path.exists(feats_dir):
                    feat_files = sorted(os.listdir(feats_dir))

                    if feat_files:
                        selected_layer = st.selectbox("Select Probed Layer", feat_files)
                        if selected_layer:
                            st.write(f"Layer: **{selected_layer.replace('.pt', '')}**")
                            fig = load_feature_map_interactive(os.path.join(feats_dir, selected_layer))
                            if fig:
                                st.plotly_chart(fig, width='stretch')
                    else:
                        st.warning("No feature maps found for this sample.")
                else:
                    st.error("Features directory not found.")

            with tab3:
                st.header("Training and Testing Logs")

                st.markdown("Use these graphs to check for **overfitting**: if the training loss keeps decreasing while the test metrics stagnate or worsen, the model is memorizing the training data rather than learning general features.")

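                # Left column: per-batch test metrics; right column: training loss curves.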
                c1, c2 = st.columns(2)
                with c1:
                    results_csv = os.path.join(model_path, "test_results.csv")
                    if os.path.exists(results_csv):
                        st.subheader("Test Metrics")
                        df = pd.read_csv(results_csv)
                        x_axis = 'Batch_Index' if 'Batch_Index' in df.columns else df.index
                        numeric_cols = df.select_dtypes(include=np.number).columns
                        fig = px.line(df, title="Test Metrics", x=x_axis, y=numeric_cols)
                        st.plotly_chart(fig, width='stretch')
                        st.dataframe(df, width='stretch')
                    else:
                        st.info("No `test_results.csv` found.")

                with c2:
                    loss_csv = os.path.join(model_path, "loss.csv")
                    if os.path.exists(loss_csv):
                        st.subheader("Training Loss")
                        try:
                            df_loss = pd.read_csv(loss_csv)
                            x_axis = 'epoch' if 'epoch' in df_loss.columns else df_loss.index
                            numeric_cols = df_loss.select_dtypes(include=np.number).columns
                            fig = px.line(df_loss, x=x_axis, y=numeric_cols, title="Loss Curves")
                            st.plotly_chart(fig, width='stretch')
                            st.dataframe(df_loss, width='stretch')
                        except Exception as e:
                            st.write("Could not parse `loss.csv`.", e)
                    else:
                        st.info("No `loss.csv` found.")

    else:
        st.warning(f"No 'test_artifacts' folder found in {selected_model}")