Kiuyha committed on
Commit
01d4d83
·
verified ·
1 Parent(s): 9fab65c

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +223 -0
app.py ADDED
@@ -0,0 +1,223 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import streamlit as st
3
+ import os
4
+ import torch
5
+ import numpy as np
6
+ import pandas as pd
7
+ import plotly.express as px
8
+ import plotly.graph_objects as go
9
+
10
# Root folder scanned for model sub-folders.
BASE_PATH = "Models"

# Page setup; kept ahead of every other Streamlit call in the script.
st.set_page_config(
    page_title="Audio Source Separation Inspector",
    layout="wide",
)
13
+
14
def get_subdirs(path):
    """Return the names of the immediate subdirectories of ``path``, sorted.

    Args:
        path: Directory to scan.

    Returns:
        A name-sorted list of subdirectory names (not full paths). The list is
        empty when ``path`` does not exist or is not a directory.
    """
    # isdir (rather than exists) also rejects a regular file at ``path``,
    # which os.listdir would otherwise fail on with NotADirectoryError.
    if not os.path.isdir(path):
        return []
    # os.listdir returns entries in arbitrary order; sort so callers get a
    # stable, reproducible listing.
    return sorted(
        name
        for name in os.listdir(path)
        if os.path.isdir(os.path.join(path, name))
    )
19
+
20
def load_spectrogram_interactive(pt_path, title="Spectrogram"):
    """Load a saved spectrogram tensor and render it as a Plotly heatmap.

    Args:
        pt_path: Path to a file readable by ``torch.load`` that holds a tensor.
        title: Title shown above the figure.

    Returns:
        A Plotly figure, or ``None`` if loading or plotting failed (the error
        is reported in the app via ``st.error``).
    """
    try:
        # NOTE(review): torch.load unpickles the file, so it must only be
        # pointed at trusted artifacts. Consider passing weights_only=True if
        # the installed torch version supports it.
        spec_tensor = torch.load(pt_path, map_location='cpu')

        # A tensor saved with requires_grad=True cannot be converted to NumPy
        # until it is detached from the autograd graph.
        spec_tensor = spec_tensor.detach()

        # In case the file holds a complex STFT, plot its magnitude.
        if spec_tensor.is_complex():
            spec_tensor = spec_tensor.abs()
        # Integer tensors cannot be averaged and bfloat16 has no NumPy
        # equivalent, so promote anything that is not float32/float64.
        if spec_tensor.dtype not in (torch.float32, torch.float64):
            spec_tensor = spec_tensor.to(torch.float32)

        # Handle dimensions: [Channels, Freq, Time]
        if spec_tensor.dim() == 4:  # [Batch, C, F, T] -> keep the first item
            spec_tensor = spec_tensor[0]
        if spec_tensor.dim() == 3:  # [C, F, T] -> average across channels
            spec_tensor = spec_tensor.mean(dim=0)
        if spec_tensor.dim() != 2:
            raise ValueError(
                f"expected a 2-D, 3-D or 4-D tensor, got shape {list(spec_tensor.shape)}"
            )
        spec_data = spec_tensor.numpy()

        # Log scaling for better visibility; only valid for non-negative data.
        use_log = bool(spec_data.min() >= 0)
        if use_log:
            spec_data = np.log1p(spec_data)
        # Label the colour bar according to what is actually plotted.
        color_label = "Log Magnitude" if use_log else "Magnitude"

        # Create interactive heatmap
        fig = px.imshow(
            spec_data,
            origin='lower',
            aspect='auto',
            color_continuous_scale='Magma',
            labels=dict(x="Time Frame", y="Frequency Bin", color=color_label),
            title=title,
        )
        fig.update_layout(margin=dict(l=0, r=0, t=30, b=0), height=300)
        return fig
    except Exception as e:
        # Best-effort UI helper: report the problem in the page instead of
        # aborting the whole script run.
        st.error(f"Error loading spectrogram: {e}")
        return None
51
+
52
def load_feature_map_interactive(pt_path):
    """Load a saved feature-map tensor and plot its mean activation.

    Args:
        pt_path: Path to a file readable by ``torch.load`` that holds a tensor.

    Returns:
        A Plotly heatmap figure, or ``None`` if loading or plotting failed
        (the error is reported in the app via ``st.error``).
    """
    try:
        # NOTE(review): torch.load unpickles the file, so it must only be
        # pointed at trusted artifacts.
        feat_tensor = torch.load(pt_path, map_location='cpu')

        # A tensor saved with requires_grad=True cannot be converted to NumPy
        # until it is detached from the autograd graph.
        feat_tensor = feat_tensor.detach()

        # Squeeze batch if present
        if feat_tensor.dim() == 4:
            feat_tensor = feat_tensor[0]

        # float() lets integer/half/bfloat16 tensors be averaged and exported.
        if feat_tensor.dim() == 3:
            # feat_tensor is likely [Channels, Freq, Time]; average over the
            # first axis to get one 2-D map.
            mean_activation = feat_tensor.float().mean(dim=0).numpy()
        elif feat_tensor.dim() == 2:
            # Already 2-D: averaging would collapse it to a 1-D vector that
            # cannot be drawn as a heatmap, so plot it as-is.
            mean_activation = feat_tensor.float().numpy()
        else:
            raise ValueError(
                f"expected a 2-D, 3-D or 4-D tensor, got shape {list(feat_tensor.shape)}"
            )

        fig = px.imshow(
            mean_activation,
            origin='lower',
            aspect='auto',
            color_continuous_scale='Viridis',
            labels=dict(x="Time", y="Freq/Feature", color="Activation"),
            title=f"Mean Activation (Shape: {list(feat_tensor.shape)})",
        )
        fig.update_layout(margin=dict(l=0, r=0, t=40, b=0))
        return fig
    except Exception as e:
        # Report the failure instead of silently returning None, matching the
        # spectrogram loader's behaviour.
        st.error(f"Error loading feature map: {e}")
        return None
76
+
77
def _sample_sort_key(name):
    """Sort key for sample folders: the integer after the last '_', else 0."""
    if '_' not in name:
        return 0
    try:
        return int(name.rsplit('_', 1)[-1])
    except ValueError:
        # A non-numeric suffix (e.g. "sample_final") must not crash the page.
        return 0


def _render_audio_and_spec(audio_path, spec_path, spec_title):
    """Show an audio player (narrow column) beside its spectrogram (wide column).

    Either half is skipped when its file does not exist.
    """
    audio_col, spec_col = st.columns([1, 3])  # Audio on left, graph on right (wider)
    with audio_col:
        if os.path.exists(audio_path):
            st.markdown("**Audio:**")
            st.audio(audio_path)
    with spec_col:
        if os.path.exists(spec_path):
            fig = load_spectrogram_interactive(spec_path, title=spec_title)
            if fig is not None:
                st.plotly_chart(fig, width='stretch')


st.title("🎵 Audio Source Separation Inspector")

# Check if data uploaded correctly
if not os.path.exists(BASE_PATH):
    st.error(f"Models directory not found at {BASE_PATH}. Please ensure your data was uploaded correctly.")
    st.stop()

# 1. Select Model
models = get_subdirs(BASE_PATH)
selected_model = st.sidebar.selectbox("Select Model", models)

if not selected_model:
    st.info(f"No model folders found in {BASE_PATH}.")
else:
    model_path = os.path.join(BASE_PATH, selected_model)
    artifacts_path = os.path.join(model_path, "test_artifacts")

    # 2. Select Sample
    if os.path.exists(artifacts_path):
        samples = get_subdirs(artifacts_path)
        # Sort samples numerically
        samples.sort(key=_sample_sort_key)

        selected_sample = st.sidebar.selectbox("Select Sample ID", samples)

        if not selected_sample:
            st.info(f"No samples found in {artifacts_path}.")
        else:
            sample_path = os.path.join(artifacts_path, selected_sample)
            audio_dir = os.path.join(sample_path, "audio")
            specs_dir = os.path.join(sample_path, "specs")
            feats_dir = os.path.join(sample_path, "feats")

            # 3. Detect Classes
            # A sample without an audio folder simply has no classes.
            all_files = os.listdir(audio_dir) if os.path.isdir(audio_dir) else []
            prefix, suffix = "target_", ".wav"
            # Slice the prefix/suffix off instead of str.replace, which would
            # also remove those substrings from the middle of a class name.
            classes = sorted(
                f[len(prefix):-len(suffix)]
                for f in all_files
                if f.startswith(prefix) and f.endswith(suffix)
            )

            # Sidebar Class Filter (returns None when there are no classes)
            selected_class = st.sidebar.selectbox("Focus Class", classes)

            # --- MAIN CONTENT TABS ---
            tab1, tab2, tab3 = st.tabs(["🎧 Audio & Spectrograms", "🧠 Internal Activations", "📊 Model Metadata"])

            with tab1:
                focus_label = selected_class.capitalize() if selected_class else "n/a"
                st.header(f"Sample {selected_sample} | Focus: {focus_label}")

                # --- Mixture (Input) ---
                st.subheader("1. Mixture (Input)")
                _render_audio_and_spec(
                    os.path.join(audio_dir, "mixture.wav"),
                    os.path.join(specs_dir, "mixture.pt"),
                    "Mixture Spectrogram",
                )

                if selected_class:
                    st.divider()

                    # --- Target (Ground Truth) ---
                    st.subheader(f"2. Target: {selected_class}")
                    _render_audio_and_spec(
                        os.path.join(audio_dir, f"target_{selected_class}.wav"),
                        os.path.join(specs_dir, f"target_{selected_class}.pt"),
                        f"Target Spectrogram ({selected_class})",
                    )

                    st.divider()

                    # --- Prediction (Output) ---
                    st.subheader(f"3. Prediction: {selected_class}")
                    _render_audio_and_spec(
                        os.path.join(audio_dir, f"pred_{selected_class}.wav"),
                        os.path.join(specs_dir, f"pred_{selected_class}.pt"),
                        f"Predicted Spectrogram ({selected_class})",
                    )
                else:
                    st.warning("No `target_*.wav` files found for this sample.")

            with tab2:
                st.header("Internal Feature Maps")

                if os.path.exists(feats_dir):
                    # Only regular files can be loaded as feature maps.
                    feat_files = sorted(
                        f for f in os.listdir(feats_dir)
                        if os.path.isfile(os.path.join(feats_dir, f))
                    )

                    if feat_files:
                        selected_layer = st.selectbox("Select Probed Layer", feat_files)
                        if selected_layer:
                            # Strip only a trailing ".pt", not every occurrence.
                            layer_name = selected_layer[:-3] if selected_layer.endswith('.pt') else selected_layer
                            st.write(f"Layer: **{layer_name}**")
                            fig = load_feature_map_interactive(os.path.join(feats_dir, selected_layer))
                            if fig is not None:
                                st.plotly_chart(fig, width='stretch')
                            else:
                                st.warning("Could not render the selected feature map.")
                    else:
                        st.warning("No feature maps found for this sample.")
                else:
                    st.error("Features directory not found.")

            with tab3:
                st.header("Training and Testing Logs")

                c1, c2 = st.columns(2)
                with c1:
                    results_csv = os.path.join(model_path, "test_results.csv")
                    if os.path.exists(results_csv):
                        st.subheader("Test Results")
                        try:
                            df = pd.read_csv(results_csv)
                            # Plot numeric columns only: wide-form Plotly
                            # Express input must have columns of one type.
                            metric_cols = list(df.select_dtypes(include=np.number).columns)
                            if metric_cols:
                                fig = px.line(df, y=metric_cols, title="Test Metrics")
                                st.plotly_chart(fig, width='stretch')
                            st.dataframe(df, width='stretch')
                        except Exception as e:
                            # Best effort, same as the loss panel: a bad CSV
                            # should not take down the whole page.
                            st.write("Could not parse `test_results.csv`.", e)
                    else:
                        st.info("No `test_results.csv` found.")

                with c2:
                    loss_csv = os.path.join(model_path, "loss.csv")
                    if os.path.exists(loss_csv):
                        st.subheader("Training Loss")
                        try:
                            df_loss = pd.read_csv(loss_csv)
                            # Try to find an epoch column, otherwise use index
                            x_axis = 'epoch' if 'epoch' in df_loss.columns else df_loss.index

                            # Plot every numeric column except the x axis
                            # itself, so 'epoch' is not drawn against 'epoch'.
                            loss_cols = [
                                col
                                for col in df_loss.select_dtypes(include=np.number).columns
                                if col != 'epoch'
                            ]
                            if loss_cols:
                                fig = px.line(df_loss, x=x_axis, y=loss_cols, title="Loss Curves")
                                st.plotly_chart(fig, width='stretch')
                            st.dataframe(df_loss, width='stretch')
                        except Exception as e:
                            st.write("Could not parse `loss.csv`.", e)
                    else:
                        st.info("No `loss.csv` found.")

    else:
        st.warning(f"No 'test_artifacts' folder found in {selected_model}")