Kiuyha committed on
Commit
7d79d81
·
verified ·
1 Parent(s): 039abcb

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +37 -37
app.py CHANGED
@@ -7,7 +7,7 @@ import pandas as pd
7
  import plotly.express as px
8
  import plotly.graph_objects as go
9
 
10
- BASE_PATH = 'Models'
11
 
12
  st.set_page_config(layout="wide", page_title="Audio Source Separation Inspector")
13
 
@@ -21,7 +21,7 @@ def load_spectrogram_interactive(pt_path, title="Spectrogram"):
21
  """Loads a .pt spectrogram and returns a Plotly figure."""
22
  try:
23
  spec_tensor = torch.load(pt_path, map_location='cpu')
24
-
25
  # Handle dimensions: [Channels, Freq, Time]
26
  if spec_tensor.dim() == 4: # [Batch, C, F, T]
27
  spec_tensor = spec_tensor[0]
@@ -33,13 +33,13 @@ def load_spectrogram_interactive(pt_path, title="Spectrogram"):
33
  # Log scaling for better visibility
34
  if spec_data.min() >= 0:
35
  spec_data = np.log1p(spec_data)
36
-
37
  # Create interactive heatmap
38
  fig = px.imshow(
39
- spec_data,
40
- origin='lower',
41
- aspect='auto',
42
- color_continuous_scale='Magma',
43
  labels=dict(x="Time Frame", y="Frequency Bin", color="Log Magnitude"),
44
  title=title
45
  )
@@ -53,21 +53,21 @@ def load_feature_map_interactive(pt_path):
53
  """Loads an internal feature map and visualizes its mean activation interactively."""
54
  try:
55
  feat_tensor = torch.load(pt_path, map_location='cpu')
56
-
57
  # Squeeze batch if present
58
- if feat_tensor.dim() == 4:
59
  feat_tensor = feat_tensor[0]
60
-
61
  # feat_tensor is likely [Channels, Freq, Time]
62
  mean_activation = feat_tensor.mean(dim=0).numpy()
63
-
64
  fig = px.imshow(
65
- mean_activation,
66
- origin='lower',
67
- aspect='auto',
68
- color_continuous_scale='Viridis',
69
- labels=dict(x="Time", y="Freq/Feature", color="Activation"),
70
- title=f"Mean Activation (Shape: {list(feat_tensor.shape)})"
71
  )
72
  fig.update_layout(margin=dict(l=0, r=0, t=40, b=0))
73
  return fig
@@ -88,26 +88,26 @@ selected_model = st.sidebar.selectbox("Select Model", models)
88
  if selected_model:
89
  model_path = os.path.join(BASE_PATH, selected_model)
90
  artifacts_path = os.path.join(model_path, "test_artifacts")
91
-
92
  # 2. Select Sample
93
  if os.path.exists(artifacts_path):
94
  samples = get_subdirs(artifacts_path)
95
  # Sort samples numerically
96
  samples.sort(key=lambda x: int(x.split('_')[-1]) if '_' in x else 0)
97
-
98
  selected_sample = st.sidebar.selectbox("Select Sample ID", samples)
99
-
100
  if selected_sample:
101
  sample_path = os.path.join(artifacts_path, selected_sample)
102
  audio_dir = os.path.join(sample_path, "audio")
103
  specs_dir = os.path.join(sample_path, "specs")
104
  feats_dir = os.path.join(sample_path, "feats")
105
-
106
  # 3. Detect Classes
107
  all_files = os.listdir(audio_dir)
108
  target_files = [f for f in all_files if f.startswith("target_") and f.endswith(".wav")]
109
  classes = [f.replace("target_", "").replace(".wav", "") for f in target_files]
110
-
111
  # Sidebar Class Filter
112
  selected_class = st.sidebar.selectbox("Focus Class", classes)
113
 
@@ -116,12 +116,12 @@ if selected_model:
116
 
117
  with tab1:
118
  st.header(f"Sample {selected_sample} | Focus: {selected_class.capitalize()}")
119
-
120
  # --- Mixture (Input) ---
121
  st.subheader("1. Mixture (Input)")
122
  mix_audio = os.path.join(audio_dir, "mixture.wav")
123
  mix_spec = os.path.join(specs_dir, "mixture.pt")
124
-
125
  c1, c2 = st.columns([1, 3]) # Audio on left, Graph on right (wider)
126
  with c1:
127
  if os.path.exists(mix_audio):
@@ -129,16 +129,16 @@ if selected_model:
129
  st.audio(mix_audio)
130
  with c2:
131
  if os.path.exists(mix_spec):
132
- fig = load_spectrogram_interactive(mix_spec, title="Mixture Spectrogram")
133
  if fig: st.plotly_chart(fig, width='stretch')
134
-
135
  st.divider()
136
 
137
  # --- Target (Ground Truth) ---
138
  st.subheader(f"2. Target: {selected_class}")
139
  tgt_audio = os.path.join(audio_dir, f"target_{selected_class}.wav")
140
  tgt_spec = os.path.join(specs_dir, f"target_{selected_class}.pt")
141
-
142
  c1, c2 = st.columns([1, 3])
143
  with c1:
144
  if os.path.exists(tgt_audio):
@@ -146,16 +146,16 @@ if selected_model:
146
  st.audio(tgt_audio)
147
  with c2:
148
  if os.path.exists(tgt_spec):
149
- fig = load_spectrogram_interactive(tgt_spec, title=f"Target Spectrogram ({selected_class})")
150
  if fig: st.plotly_chart(fig, width='stretch')
151
-
152
  st.divider()
153
 
154
  # --- Prediction (Output) ---
155
  st.subheader(f"3. Prediction: {selected_class}")
156
  pred_audio = os.path.join(audio_dir, f"pred_{selected_class}.wav")
157
  pred_spec = os.path.join(specs_dir, f"pred_{selected_class}.pt")
158
-
159
  c1, c2 = st.columns([1, 3])
160
  with c1:
161
  if os.path.exists(pred_audio):
@@ -163,15 +163,15 @@ if selected_model:
163
  st.audio(pred_audio)
164
  with c2:
165
  if os.path.exists(pred_spec):
166
- fig = load_spectrogram_interactive(pred_spec, title=f"Predicted Spectrogram ({selected_class})")
167
  if fig: st.plotly_chart(fig, width='stretch')
168
-
169
  with tab2:
170
  st.header("Internal Feature Maps")
171
-
172
  if os.path.exists(feats_dir):
173
  feat_files = sorted(os.listdir(feats_dir))
174
-
175
  if feat_files:
176
  selected_layer = st.selectbox("Select Probed Layer", feat_files)
177
  if selected_layer:
@@ -186,7 +186,7 @@ if selected_model:
186
 
187
  with tab3:
188
  st.header("Training and Testing Logs")
189
-
190
  c1, c2 = st.columns(2)
191
  with c1:
192
  results_csv = os.path.join(model_path, "test_results.csv")
@@ -199,7 +199,7 @@ if selected_model:
199
  st.dataframe(df, width='stretch')
200
  else:
201
  st.info("No `test_results.csv` found.")
202
-
203
  with c2:
204
  loss_csv = os.path.join(model_path, "loss.csv")
205
  if os.path.exists(loss_csv):
@@ -208,7 +208,7 @@ if selected_model:
208
  df_loss = pd.read_csv(loss_csv)
209
  # Try to find an epoch column, otherwise use index
210
  x_axis = 'epoch' if 'epoch' in df_loss.columns else df_loss.index
211
-
212
  # Melt if multiple loss columns exist for better visualization
213
  numeric_cols = df_loss.select_dtypes(include=np.number).columns
214
  fig = px.line(df_loss, x=x_axis, y=numeric_cols, title="Loss Curves")
 
7
  import plotly.express as px
8
  import plotly.graph_objects as go
9
 
10
+ BASE_PATH = 'Models'
11
 
12
  st.set_page_config(layout="wide", page_title="Audio Source Separation Inspector")
13
 
 
21
  """Loads a .pt spectrogram and returns a Plotly figure."""
22
  try:
23
  spec_tensor = torch.load(pt_path, map_location='cpu')
24
+
25
  # Handle dimensions: [Channels, Freq, Time]
26
  if spec_tensor.dim() == 4: # [Batch, C, F, T]
27
  spec_tensor = spec_tensor[0]
 
33
  # Log scaling for better visibility
34
  if spec_data.min() >= 0:
35
  spec_data = np.log1p(spec_data)
36
+
37
  # Create interactive heatmap
38
  fig = px.imshow(
39
+ spec_data,
40
+ origin='lower',
41
+ aspect='auto',
42
+ color_continuous_scale='Viridis',
43
  labels=dict(x="Time Frame", y="Frequency Bin", color="Log Magnitude"),
44
  title=title
45
  )
 
53
  """Loads an internal feature map and visualizes its mean activation interactively."""
54
  try:
55
  feat_tensor = torch.load(pt_path, map_location='cpu')
56
+
57
  # Squeeze batch if present
58
+ if feat_tensor.dim() == 4:
59
  feat_tensor = feat_tensor[0]
60
+
61
  # feat_tensor is likely [Channels, Freq, Time]
62
  mean_activation = feat_tensor.mean(dim=0).numpy()
63
+
64
  fig = px.imshow(
65
+ mean_activation,
66
+ origin='lower',
67
+ aspect='auto',
68
+ color_continuous_scale='Viridis',
69
+ labels=dict(x="Time", y="Freq/Feature", color="Activation"),
70
+ title=f"Mean Activation (Shape: {list(feat_tensor.shape)})"
71
  )
72
  fig.update_layout(margin=dict(l=0, r=0, t=40, b=0))
73
  return fig
 
88
  if selected_model:
89
  model_path = os.path.join(BASE_PATH, selected_model)
90
  artifacts_path = os.path.join(model_path, "test_artifacts")
91
+
92
  # 2. Select Sample
93
  if os.path.exists(artifacts_path):
94
  samples = get_subdirs(artifacts_path)
95
  # Sort samples numerically
96
  samples.sort(key=lambda x: int(x.split('_')[-1]) if '_' in x else 0)
97
+
98
  selected_sample = st.sidebar.selectbox("Select Sample ID", samples)
99
+
100
  if selected_sample:
101
  sample_path = os.path.join(artifacts_path, selected_sample)
102
  audio_dir = os.path.join(sample_path, "audio")
103
  specs_dir = os.path.join(sample_path, "specs")
104
  feats_dir = os.path.join(sample_path, "feats")
105
+
106
  # 3. Detect Classes
107
  all_files = os.listdir(audio_dir)
108
  target_files = [f for f in all_files if f.startswith("target_") and f.endswith(".wav")]
109
  classes = [f.replace("target_", "").replace(".wav", "") for f in target_files]
110
+
111
  # Sidebar Class Filter
112
  selected_class = st.sidebar.selectbox("Focus Class", classes)
113
 
 
116
 
117
  with tab1:
118
  st.header(f"Sample {selected_sample} | Focus: {selected_class.capitalize()}")
119
+
120
  # --- Mixture (Input) ---
121
  st.subheader("1. Mixture (Input)")
122
  mix_audio = os.path.join(audio_dir, "mixture.wav")
123
  mix_spec = os.path.join(specs_dir, "mixture.pt")
124
+
125
  c1, c2 = st.columns([1, 3]) # Audio on left, Graph on right (wider)
126
  with c1:
127
  if os.path.exists(mix_audio):
 
129
  st.audio(mix_audio)
130
  with c2:
131
  if os.path.exists(mix_spec):
132
+ fig = load_spectrogram_interactive(mix_spec, title="Mixture Mel-Spectrogram")
133
  if fig: st.plotly_chart(fig, width='stretch')
134
+
135
  st.divider()
136
 
137
  # --- Target (Ground Truth) ---
138
  st.subheader(f"2. Target: {selected_class}")
139
  tgt_audio = os.path.join(audio_dir, f"target_{selected_class}.wav")
140
  tgt_spec = os.path.join(specs_dir, f"target_{selected_class}.pt")
141
+
142
  c1, c2 = st.columns([1, 3])
143
  with c1:
144
  if os.path.exists(tgt_audio):
 
146
  st.audio(tgt_audio)
147
  with c2:
148
  if os.path.exists(tgt_spec):
149
+ fig = load_spectrogram_interactive(tgt_spec, title=f"Target Mel-Spectrogram ({selected_class})")
150
  if fig: st.plotly_chart(fig, width='stretch')
151
+
152
  st.divider()
153
 
154
  # --- Prediction (Output) ---
155
  st.subheader(f"3. Prediction: {selected_class}")
156
  pred_audio = os.path.join(audio_dir, f"pred_{selected_class}.wav")
157
  pred_spec = os.path.join(specs_dir, f"pred_{selected_class}.pt")
158
+
159
  c1, c2 = st.columns([1, 3])
160
  with c1:
161
  if os.path.exists(pred_audio):
 
163
  st.audio(pred_audio)
164
  with c2:
165
  if os.path.exists(pred_spec):
166
+ fig = load_spectrogram_interactive(pred_spec, title=f"Predicted Mel-Spectrogram ({selected_class})")
167
  if fig: st.plotly_chart(fig, width='stretch')
168
+
169
  with tab2:
170
  st.header("Internal Feature Maps")
171
+
172
  if os.path.exists(feats_dir):
173
  feat_files = sorted(os.listdir(feats_dir))
174
+
175
  if feat_files:
176
  selected_layer = st.selectbox("Select Probed Layer", feat_files)
177
  if selected_layer:
 
186
 
187
  with tab3:
188
  st.header("Training and Testing Logs")
189
+
190
  c1, c2 = st.columns(2)
191
  with c1:
192
  results_csv = os.path.join(model_path, "test_results.csv")
 
199
  st.dataframe(df, width='stretch')
200
  else:
201
  st.info("No `test_results.csv` found.")
202
+
203
  with c2:
204
  loss_csv = os.path.join(model_path, "loss.csv")
205
  if os.path.exists(loss_csv):
 
208
  df_loss = pd.read_csv(loss_csv)
209
  # Try to find an epoch column, otherwise use index
210
  x_axis = 'epoch' if 'epoch' in df_loss.columns else df_loss.index
211
+
212
  # Melt if multiple loss columns exist for better visualization
213
  numeric_cols = df_loss.select_dtypes(include=np.number).columns
214
  fig = px.line(df_loss, x=x_axis, y=numeric_cols, title="Loss Curves")