import streamlit as st
import os
import torch
import numpy as np
import pandas as pd
import plotly.express as px
import scipy.io.wavfile
import io
BASE_PATH = 'Models'
st.set_page_config(layout="wide", page_title="Audio Source Separation Inspector")
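# Expected on-disk layout (derived from the access patterns below):
#   Models/
#     <model_name>/
#       loss.csv
#       test_results.csv
#       test_artifacts/
#         <sample_id>/
#           audio/   (mixture.wav, target_<class>.wav, pred_<class>.wav)
#           specs/   (mixture.pt, target_<class>.pt, pred_<class>.pt)
#           feats/   (<layer_name>.pt activation dumps)
# Run with: streamlit run app.py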
def process_audio(file_path, gain_factor):
    try:
        # Guard against Git LFS pointers: an un-pulled LFS file is a small
        # text stub beginning with "version https://git-lfs.github.com/spec/v1"
        # rather than RIFF WAV data.
        with open(file_path, 'rb') as f:
            header = f.read(50)
        if header.startswith(b'version https://git-lfs'):
            st.error(f"❌ **LFS Error:** `{os.path.basename(file_path)}` is a Git LFS pointer, not a WAV file. Run `git lfs pull` in your terminal.")
            return None
        sample_rate, data = scipy.io.wavfile.read(file_path)
        # Normalize integer PCM to float32 in [-1.0, 1.0)
        if data.dtype == np.int16:
            data = data.astype(np.float32) / 32768.0
        elif data.dtype == np.int32:
            data = data.astype(np.float32) / 2147483648.0
        # Apply gain, clip to prevent wrap-around, and convert back to 16-bit PCM
        data = data * gain_factor
        data = np.clip(data, -1.0, 1.0)
        data = (data * 32767).astype(np.int16)
        virtual_file = io.BytesIO()
        scipy.io.wavfile.write(virtual_file, sample_rate, data)
        virtual_file.seek(0)  # Rewind so readers start at the beginning
        return virtual_file
    except Exception as e:
        st.error(f"Error processing audio: {e}")
        return file_path  # Fall back to the raw file; st.audio also accepts a path
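
# Sketch of an alternative to the manual gain slider: automatic peak
# normalization. `_peak_normalize` is a hypothetical helper, not used by the
# app; it assumes a float waveform already scaled to [-1.0, 1.0].
def _peak_normalize(data, headroom=0.95):
    """Scale a float waveform so its absolute peak sits at `headroom`."""
    peak = np.max(np.abs(data))
    if peak == 0:
        return data  # All-silent signal: nothing to scale
    return data * (headroom / peak)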
def get_subdirs(path):
    # Return the immediate subdirectory names of `path`, or [] if it is missing
    if not os.path.exists(path):
        return []
    return [d for d in os.listdir(path) if os.path.isdir(os.path.join(path, d))]
def load_spectrogram_interactive(pt_path, title="Spectrogram"):
    try:
        # PyTorch 2.6 changed torch.load's default to weights_only=True;
        # these are trusted local artifacts, so full unpickling is allowed.
        spec_tensor = torch.load(pt_path, map_location='cpu', weights_only=False)
        # Drop the batch dimension if present: (B, C, F, T) -> (C, F, T)
        if spec_tensor.dim() == 4:
            spec_tensor = spec_tensor[0]
        # Average over channels to get a single 2-D image
        if spec_tensor.dim() == 3:
            spec_data = spec_tensor.mean(dim=0).numpy()
        else:
            spec_data = spec_tensor.numpy()
        # Log-compress linear magnitudes (skip data with negatives, e.g. already log-scaled)
        if spec_data.min() >= 0:
            spec_data = np.log1p(spec_data)
        fig = px.imshow(
            spec_data,
            origin='lower',
            aspect='auto',
            color_continuous_scale='Viridis',
            labels=dict(x="Time Frame", y="Frequency Bin", color="Log Magnitude"),
            title=title
        )
        fig.update_layout(margin=dict(l=0, r=0, t=30, b=0), height=300)
        return fig
    except Exception as e:
        st.error(f"Error loading spectrogram: {e}")
        return None
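
# Sketch: if the saved .pt files are plain tensors, the safer PyTorch default
# can be kept by trying weights_only=True first and only falling back to full
# unpickling. Hypothetical alternative, not wired into the loaders here.
def _load_tensor_safely(pt_path):
    try:
        return torch.load(pt_path, map_location='cpu', weights_only=True)
    except Exception:
        # Artifacts pickled as arbitrary Python objects need full unpickling
        return torch.load(pt_path, map_location='cpu', weights_only=False)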
def load_feature_map_interactive(pt_path):
    try:
        # weights_only=False for the same PyTorch 2.6 reason as above
        feat_tensor = torch.load(pt_path, map_location='cpu', weights_only=False)
        # Drop the batch dimension: (B, C, H, W) -> (C, H, W)
        if feat_tensor.dim() == 4:
            feat_tensor = feat_tensor[0]
        # Collapse channels into a single mean-activation heatmap
        mean_activation = feat_tensor.mean(dim=0).numpy()
        fig = px.imshow(
            mean_activation,
            origin='lower',
            aspect='auto',
            color_continuous_scale='Viridis',
            labels=dict(x="Time", y="Freq/Feature", color="Activation"),
            title=f"Mean Activation (Shape: {list(feat_tensor.shape)})"
        )
        fig.update_layout(margin=dict(l=0, r=0, t=40, b=0))
        return fig
    except Exception as e:
        # Surface the error instead of failing silently
        st.error(f"Error loading feature map: {e}")
        return None
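
# Sketch: averaging over channels can hide specialized filters. A hypothetical
# variant that renders one selected channel of a (C, H, W) activation tensor:
def _single_channel_figure(feat_tensor, channel=0):
    """Heatmap of a single activation channel instead of the channel mean."""
    return px.imshow(
        feat_tensor[channel].numpy(),
        origin='lower',
        aspect='auto',
        color_continuous_scale='Viridis',
        title=f"Channel {channel} Activation"
    )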
st.title("🎵 Audio Source Separation Inspector")
st.markdown("""
### Model Interpretation Guide
This tool helps you evaluate how well the model separates audio sources.
* **Audio Quality:** Listen for "artifacts" (robotic sounds or clicking) in the Prediction compared to the Target.
* **Spectrogram Clarity:** In the visuals below, distinct horizontal lines indicate steady harmonic tones, while vertical smearing usually indicates percussion or broadband noise.
* **Error Analysis:** If the Prediction looks "blurry" compared to the Target, the model is losing high-frequency details.
""")
if not os.path.exists(BASE_PATH):
    st.error(f"Models directory not found at {BASE_PATH}. Please ensure your data was uploaded correctly.")
    st.stop()
models = get_subdirs(BASE_PATH)
selected_model = st.sidebar.selectbox("Select Model", models)
st.sidebar.markdown("### Audio Settings")
volume_boost = st.sidebar.slider(
    "Volume Boost (Gain)",
    min_value=1.0,
    max_value=20.0,
    value=1.0,
    step=0.5,
    help="Digitally increases the amplitude of the audio signal."
)
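
# Note: the slider above applies a linear amplitude factor. If a decibel
# control is preferred, the conversion is factor = 10 ** (db / 20); e.g.
# +6 dB is roughly 2.0x. Hypothetical variant, not wired in:
#   db_boost = st.sidebar.slider("Volume Boost (dB)", 0.0, 26.0, 0.0, 0.5)
#   volume_boost = 10 ** (db_boost / 20)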
if selected_model:
    model_path = os.path.join(BASE_PATH, selected_model)
    artifacts_path = os.path.join(model_path, "test_artifacts")
    if os.path.exists(artifacts_path):
        samples = get_subdirs(artifacts_path)
        # Sort numerically by the trailing index (e.g. "sample_10" after "sample_2")
        samples.sort(key=lambda x: int(x.split('_')[-1]) if x.split('_')[-1].isdigit() else 0)
        selected_sample = st.sidebar.selectbox("Select Sample ID", samples)
        if selected_sample:
            sample_path = os.path.join(artifacts_path, selected_sample)
            audio_dir = os.path.join(sample_path, "audio")
            specs_dir = os.path.join(sample_path, "specs")
            feats_dir = os.path.join(sample_path, "feats")
            all_files = os.listdir(audio_dir) if os.path.exists(audio_dir) else []
            target_files = [f for f in all_files if f.startswith("target_") and f.endswith(".wav")]
            classes = [f.replace("target_", "").replace(".wav", "") for f in target_files]
            selected_class = st.sidebar.selectbox("Focus Class", classes)
            tab1, tab2, tab3 = st.tabs(["🎧 Audio & Spectrograms", "🧠 Internal Activations", "📊 Model Metadata"])
            with tab1:
                st.header(f"Sample {selected_sample} | Focus: {selected_class.capitalize()}")
                st.subheader("1. Mixture (Input)")
                st.markdown("The raw input containing all sound sources mixed together.")
                mix_audio = os.path.join(audio_dir, "mixture.wav")
                mix_spec = os.path.join(specs_dir, "mixture.pt")
                c1, c2 = st.columns([1, 3])
                with c1:
                    if os.path.exists(mix_audio):
                        st.markdown("**Audio:**")
                        processed_mix = process_audio(mix_audio, volume_boost)
                        if processed_mix:
                            st.audio(processed_mix, format='audio/wav')
                with c2:
                    if os.path.exists(mix_spec):
                        fig = load_spectrogram_interactive(mix_spec, title="Mixture Mel-Spectrogram")
                        if fig:
                            st.plotly_chart(fig, width='stretch')
                st.divider()
                st.subheader(f"2. Target: {selected_class}")
                st.markdown("**Interpretation:** This is the ground truth. Study the spectrogram structure here; it is the ideal output the model should reproduce.")
                tgt_audio = os.path.join(audio_dir, f"target_{selected_class}.wav")
                tgt_spec = os.path.join(specs_dir, f"target_{selected_class}.pt")
                c1, c2 = st.columns([1, 3])
                with c1:
                    if os.path.exists(tgt_audio):
                        st.markdown("**Audio:**")
                        processed_tgt = process_audio(tgt_audio, volume_boost)
                        if processed_tgt:
                            st.audio(processed_tgt, format='audio/wav')
                with c2:
                    if os.path.exists(tgt_spec):
                        fig = load_spectrogram_interactive(tgt_spec, title=f"Target Mel-Spectrogram ({selected_class})")
                        if fig:
                            st.plotly_chart(fig, width='stretch')
st.subheader(f"3. Prediction: {selected_class}")
st.markdown(f"**Interpretation:** Compare this to the Target above. If you see 'fuzziness' in the dark areas, the model is not silencing background noise correctly.")
pred_audio = os.path.join(audio_dir, f"pred_{selected_class}.wav")
pred_spec = os.path.join(specs_dir, f"pred_{selected_class}.pt")
c1, c2 = st.columns([1, 3])
with c1:
if os.path.exists(pred_audio):
st.markdown("**Audio:**")
processed_pred = process_audio(pred_audio, volume_boost)
if processed_pred:
st.audio(processed_pred, format='audio/wav')
with c2:
if os.path.exists(pred_spec):
fig = load_spectrogram_interactive(pred_spec, title=f"Predicted Mel-Spectrogram ({selected_class})")
if fig: st.plotly_chart(fig, width='stretch')
            with tab2:
                st.header("Internal Feature Maps")
                st.markdown("These heatmaps visualize the neural network's internal state. Bright spots indicate features the model considers important for separation.")
                if os.path.exists(feats_dir):
                    feat_files = sorted(os.listdir(feats_dir))
                    if feat_files:
                        selected_layer = st.selectbox("Select Probed Layer", feat_files)
                        if selected_layer:
                            st.write(f"Layer: **{selected_layer.replace('.pt', '')}**")
                            fig = load_feature_map_interactive(os.path.join(feats_dir, selected_layer))
                            if fig:
                                st.plotly_chart(fig, width='stretch')
                    else:
                        st.warning("No feature maps found for this sample.")
                else:
                    st.error("Features directory not found.")
            with tab3:
                st.header("Training and Testing Logs")
                st.markdown("Use these graphs to check for **overfitting**: if training loss keeps decreasing while test metrics stagnate or drop, the model is memorizing the training data rather than learning general features.")
                c1, c2 = st.columns(2)
                with c1:
                    results_csv = os.path.join(model_path, "test_results.csv")
                    if os.path.exists(results_csv):
                        st.subheader("Test Metrics")
                        df = pd.read_csv(results_csv)
                        x_axis = 'Batch_Index' if 'Batch_Index' in df.columns else df.index
                        # Exclude the x-axis column so it is not plotted against itself
                        numeric_cols = [c for c in df.select_dtypes(include=np.number).columns if c != 'Batch_Index']
                        fig = px.line(df, title="Test Metrics", x=x_axis, y=numeric_cols)
                        st.plotly_chart(fig, width='stretch')
                        st.dataframe(df, width='stretch')
                    else:
                        st.info("No `test_results.csv` found.")
                with c2:
                    loss_csv = os.path.join(model_path, "loss.csv")
                    if os.path.exists(loss_csv):
                        st.subheader("Training Loss")
                        try:
                            df_loss = pd.read_csv(loss_csv)
                            x_axis = 'epoch' if 'epoch' in df_loss.columns else df_loss.index
                            numeric_cols = [c for c in df_loss.select_dtypes(include=np.number).columns if c != 'epoch']
                            fig = px.line(df_loss, x=x_axis, y=numeric_cols, title="Loss Curves")
                            st.plotly_chart(fig, width='stretch')
                            st.dataframe(df_loss, width='stretch')
                        except Exception as e:
                            st.write("Could not parse `loss.csv`.", e)
                    else:
                        st.info("No `loss.csv` found.")
    else:
        st.warning(f"No 'test_artifacts' folder found in {selected_model}")