AdityaK007 commited on
Commit
b43ea87
·
verified ·
1 Parent(s): 16d47ae

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +84 -82
app.py CHANGED
@@ -7,6 +7,7 @@ from sklearn.preprocessing import StandardScaler
7
  from sklearn.metrics.pairwise import cosine_similarity
8
  from scipy import signal
9
  from scipy.signal import get_window as scipy_get_window
 
10
  import plotly.express as px
11
  import plotly.graph_objects as go
12
  import os
@@ -156,44 +157,31 @@ def compare_frames_enhanced(near_feats, far_feats, metrics):
156
  return pd.DataFrame(results)
157
 
158
  # ----------------------------
159
- # 5. Dual Clustering Logic
160
  # ----------------------------
161
  def perform_dual_clustering(near_df, far_df, cluster_features, algo, n_clusters, eps):
162
- """
163
- Fits clustering on Near Field (clean), then predicts on Far Field (noisy).
164
- This ensures Cluster 0 in Near corresponds to the same physical sound in Far.
165
- """
166
- if not cluster_features:
167
- return near_df, far_df
168
 
169
  valid_features = [f for f in cluster_features if f in near_df.columns]
170
- if not valid_features:
171
- return near_df, far_df
172
 
173
- X_near = near_df[valid_features].values
174
- X_near = np.nan_to_num(X_near)
175
-
176
- X_far = far_df[valid_features].values
177
- X_far = np.nan_to_num(X_far)
178
 
179
- # We use a Scaler to ensure features are comparable
180
  scaler = StandardScaler()
181
  X_near_scaled = scaler.fit_transform(X_near)
182
- X_far_scaled = scaler.transform(X_far) # Use same scaler for Far
183
 
184
  if algo == "KMeans":
185
  model = KMeans(n_clusters=min(n_clusters, len(X_near)), random_state=42, n_init=10)
186
  near_labels = model.fit_predict(X_near_scaled)
187
- far_labels = model.predict(X_far_scaled) # Predict using Near model
188
  elif algo == "Agglomerative":
189
- # Agglomerative cannot "predict" on new data easily, so we cluster independently
190
- # This is a limitation, but acceptable fallback
191
  model = AgglomerativeClustering(n_clusters=min(n_clusters, len(X_near)))
192
  near_labels = model.fit_predict(X_near_scaled)
193
  far_model = AgglomerativeClustering(n_clusters=min(n_clusters, len(X_far)))
194
  far_labels = far_model.fit_predict(X_far_scaled)
195
  elif algo == "DBSCAN":
196
- # DBSCAN also cannot "predict", must fit_predict.
197
  model = DBSCAN(eps=eps, min_samples=3)
198
  near_labels = model.fit_predict(X_near_scaled)
199
  far_labels = model.fit_predict(X_far_scaled)
@@ -201,43 +189,70 @@ def perform_dual_clustering(near_df, far_df, cluster_features, algo, n_clusters,
201
  near_labels = np.zeros(len(X_near))
202
  far_labels = np.zeros(len(X_far))
203
 
204
- near_df = near_df.copy()
205
- near_df["cluster"] = near_labels
206
- near_df["cluster"] = near_df["cluster"].astype(str) # For categorical coloring
207
-
208
- far_df = far_df.copy()
209
- far_df["cluster"] = far_labels
210
- far_df["cluster"] = far_df["cluster"].astype(str)
211
 
212
  return near_df, far_df
213
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
214
  # ----------------------------
215
  # 6. Plotting Helpers
216
  # ----------------------------
217
  def generate_cluster_plot(df, x_attr, y_attr, title_suffix):
218
  if len(df) == 0 or x_attr not in df.columns or y_attr not in df.columns:
219
  return px.scatter(title="No Data")
220
-
221
  fig = px.scatter(
222
  df, x=x_attr, y=y_attr, color="cluster",
223
  title=f"Clustering Analysis ({title_suffix}): {x_attr} vs {y_attr}",
224
- color_discrete_sequence=px.colors.qualitative.Bold # Consistent colors
225
  )
226
  return fig
227
 
228
  def update_cluster_view(view_mode, near_df, far_df, cluster_features):
229
- if near_df is None or far_df is None:
230
- return px.scatter(title="Run Analysis First")
231
-
232
- if len(cluster_features) < 2:
233
- return px.scatter(title="Select at least 2 features")
234
-
235
  x_attr, y_attr = cluster_features[0], cluster_features[1]
236
-
237
- if view_mode == "Near Field":
238
- return generate_cluster_plot(near_df, x_attr, y_attr, "Near Field")
239
- else:
240
- return generate_cluster_plot(far_df, x_attr, y_attr, "Far Field")
241
 
242
  # ----------------------------
243
  # 7. Main Analysis
@@ -252,7 +267,6 @@ def analyze_audio_pair(
252
  # Load & Align
253
  y_near, sr = librosa.load(near_file.name, sr=None)
254
  y_far, _ = librosa.load(far_file.name, sr=sr)
255
-
256
  y_near = librosa.util.normalize(y_near)
257
  y_far = librosa.util.normalize(y_far)
258
  y_near, y_far = align_signals(y_near, y_far)
@@ -260,41 +274,30 @@ def analyze_audio_pair(
260
  # Process
261
  frames_near, _ = segment_audio(y_near, sr, frame_length_ms, hop_length_ms, window_type)
262
  frames_far, _ = segment_audio(y_far, sr, frame_length_ms, hop_length_ms, window_type)
263
-
264
  near_feats = extract_features_with_spectrum(frames_near, sr)
265
  far_feats = extract_features_with_spectrum(frames_far, sr)
266
 
267
- # Comparison Data
268
  comparison_df = compare_frames_enhanced(near_feats, far_feats, comparison_metrics)
269
-
270
- # Clustering Data
271
  near_df_raw = pd.DataFrame(near_feats).drop(columns=["spectrum"], errors="ignore")
272
  far_df_raw = pd.DataFrame(far_feats).drop(columns=["spectrum"], errors="ignore")
273
-
274
- # Perform Dual Clustering
275
  near_clustered, far_clustered = perform_dual_clustering(
276
  near_df_raw, far_df_raw, cluster_features, clustering_algo, n_clusters, dbscan_eps
277
  )
278
 
279
- # 1. Comparison Plot (Dual Axis)
280
  plot_comparison = go.Figure()
281
- # Axis 1: Similarity (0-1)
282
  for col in ["cosine_similarity", "spectral_overlap", "combined_match_score"]:
283
  if col in comparison_df.columns:
284
  plot_comparison.add_trace(go.Scatter(x=comparison_df["frame_index"], y=comparison_df[col], name=col, yaxis="y1"))
285
- # Axis 2: dB Loss
286
  if "high_freq_loss_db" in comparison_df.columns:
287
  plot_comparison.add_trace(go.Scatter(x=comparison_df["frame_index"], y=comparison_df["high_freq_loss_db"],
288
  name="High Freq Loss (dB)", line=dict(color="red", width=1), yaxis="y2"))
289
-
290
  plot_comparison.update_layout(
291
- title="Comparison Metrics (Dual Axis)",
292
- yaxis=dict(title="Similarity (0-1)", range=[0, 1.1]),
293
- yaxis2=dict(title="Energy Diff (dB)", overlaying="y", side="right"),
294
- legend=dict(x=1.1, y=1)
295
  )
296
 
297
- # 2. Initial Cluster Plot (Near Field)
298
  init_cluster_plot = update_cluster_view("Near Field", near_clustered, far_clustered, cluster_features)
299
 
300
  # 3. Spectral Heatmap
@@ -303,7 +306,7 @@ def analyze_audio_pair(
303
  spec_heatmap = go.Figure(data=go.Heatmap(z=diff, colorscale='RdBu', zmid=0))
304
  spec_heatmap.update_layout(title=f"Spectral Diff (Frame {safe_idx})", height=350)
305
 
306
- # 4. Overlay Plot (Simple)
307
  near_clustered["match_quality"] = comparison_df["combined_match_score"]
308
  if len(cluster_features) > 0:
309
  overlay_fig = px.scatter(near_clustered, x=cluster_features[0], y="match_quality", color="cluster",
@@ -311,11 +314,22 @@ def analyze_audio_pair(
311
  else:
312
  overlay_fig = px.scatter(title="No features")
313
 
314
- # Return: Plots + Dataframes for State + Raw Tables
 
 
 
 
 
 
 
 
 
 
315
  return (plot_comparison, comparison_df,
316
- init_cluster_plot, near_clustered, # Table
317
  spec_heatmap, overlay_fig,
318
- near_clustered, far_clustered) # States
 
319
 
320
  def export_results(comparison_df, near_df, far_df):
321
  temp_dir = tempfile.mkdtemp()
@@ -334,7 +348,6 @@ feature_list = ["rms", "spectral_centroid", "zcr", "spectral_flatness",
334
  "low_freq_energy", "mid_freq_energy", "high_freq_energy"] + [f"mfcc_{i}" for i in range(1, 14)]
335
 
336
  with gr.Blocks(title="Audio Field Analyzer", theme=gr.themes.Soft()) as demo:
337
- # State storage for interactivity
338
  state_near_df = gr.State()
339
  state_far_df = gr.State()
340
 
@@ -348,13 +361,8 @@ with gr.Blocks(title="Audio Field Analyzer", theme=gr.themes.Soft()) as demo:
348
  frame_length_ms = gr.Slider(10, 200, value=30, label="Frame Length (ms)")
349
  hop_length_ms = gr.Slider(5, 100, value=15, label="Hop Length (ms)")
350
  window_type = gr.Dropdown(["hann", "hamming"], value="hann", label="Window")
351
-
352
- comparison_metrics = gr.CheckboxGroup(["Cosine Similarity", "High-Freq Loss Ratio"],
353
- value=["Cosine Similarity", "High-Freq Loss Ratio"], label="Metrics")
354
-
355
- cluster_features = gr.CheckboxGroup(feature_list, value=["spectral_centroid", "spectral_flatness"],
356
- label="Clustering Features")
357
-
358
  clustering_algo = gr.Dropdown(["KMeans", "Agglomerative"], value="KMeans", label="Algorithm")
359
  n_clusters = gr.Slider(2, 10, value=4, step=1, label="Clusters")
360
  dbscan_eps = gr.Slider(0.1, 5.0, value=0.5, visible=False)
@@ -365,24 +373,23 @@ with gr.Blocks(title="Audio Field Analyzer", theme=gr.themes.Soft()) as demo:
365
  with gr.Tab("📈 Comparison"):
366
  comp_plot = gr.Plot()
367
  comp_table = gr.Dataframe()
368
-
369
  with gr.Tab("🧩 Phoneme Clustering"):
370
- with gr.Row():
371
- # TOGGLE SWITCH
372
- view_toggle = gr.Radio(["Near Field", "Far Field"], value="Near Field", label="View Mode")
373
  cluster_plot = gr.Plot()
374
  cluster_table = gr.Dataframe()
375
-
376
  with gr.Tab("🔍 Spectral"):
377
  spec_heatmap = gr.Plot()
378
  with gr.Tab("🧭 Overlay"):
379
  overlay_plot = gr.Plot()
 
 
 
 
380
 
381
  with gr.Tab("📤 Export"):
382
  export_btn = gr.Button("Download CSVs")
383
  export_files = gr.Files()
384
 
385
- # Main Analysis Event
386
  btn.click(
387
  fn=analyze_audio_pair,
388
  inputs=[near_file, far_file, frame_length_ms, hop_length_ms, window_type,
@@ -390,16 +397,11 @@ with gr.Blocks(title="Audio Field Analyzer", theme=gr.themes.Soft()) as demo:
390
  outputs=[comp_plot, comp_table,
391
  cluster_plot, cluster_table,
392
  spec_heatmap, overlay_plot,
393
- state_near_df, state_far_df] # Save to State
394
- )
395
-
396
- # Toggle Event (Updates plot without re-running analysis)
397
- view_toggle.change(
398
- fn=update_cluster_view,
399
- inputs=[view_toggle, state_near_df, state_far_df, cluster_features],
400
- outputs=[cluster_plot]
401
  )
402
 
 
403
  export_btn.click(fn=export_results, inputs=[comp_table, state_near_df, state_far_df], outputs=export_files)
404
 
405
  if __name__ == "__main__":
 
7
  from sklearn.metrics.pairwise import cosine_similarity
8
  from scipy import signal
9
  from scipy.signal import get_window as scipy_get_window
10
+ from scipy.stats import pearsonr
11
  import plotly.express as px
12
  import plotly.graph_objects as go
13
  import os
 
157
  return pd.DataFrame(results)
158
 
159
  # ----------------------------
160
+ # 5. Dual Clustering & Feature Relation (NEW)
161
  # ----------------------------
162
  def perform_dual_clustering(near_df, far_df, cluster_features, algo, n_clusters, eps):
163
+ if not cluster_features: return near_df, far_df
 
 
 
 
 
164
 
165
  valid_features = [f for f in cluster_features if f in near_df.columns]
166
+ if not valid_features: return near_df, far_df
 
167
 
168
+ X_near = np.nan_to_num(near_df[valid_features].values)
169
+ X_far = np.nan_to_num(far_df[valid_features].values)
 
 
 
170
 
 
171
  scaler = StandardScaler()
172
  X_near_scaled = scaler.fit_transform(X_near)
173
+ X_far_scaled = scaler.transform(X_far)
174
 
175
  if algo == "KMeans":
176
  model = KMeans(n_clusters=min(n_clusters, len(X_near)), random_state=42, n_init=10)
177
  near_labels = model.fit_predict(X_near_scaled)
178
+ far_labels = model.predict(X_far_scaled)
179
  elif algo == "Agglomerative":
 
 
180
  model = AgglomerativeClustering(n_clusters=min(n_clusters, len(X_near)))
181
  near_labels = model.fit_predict(X_near_scaled)
182
  far_model = AgglomerativeClustering(n_clusters=min(n_clusters, len(X_far)))
183
  far_labels = far_model.fit_predict(X_far_scaled)
184
  elif algo == "DBSCAN":
 
185
  model = DBSCAN(eps=eps, min_samples=3)
186
  near_labels = model.fit_predict(X_near_scaled)
187
  far_labels = model.fit_predict(X_far_scaled)
 
189
  near_labels = np.zeros(len(X_near))
190
  far_labels = np.zeros(len(X_far))
191
 
192
+ near_df = near_df.copy(); near_df["cluster"] = near_labels.astype(str)
193
+ far_df = far_df.copy(); far_df["cluster"] = far_labels.astype(str)
 
 
 
 
 
194
 
195
  return near_df, far_df
196
 
197
+ def compute_feature_correlations(near_df, far_df, quality_scores):
198
+ """
199
+ Calculates the correlation between Near Features and Far Features
200
+ weighted by the Match Quality.
201
+ Returns a correlation matrix dataframe for plotting.
202
+ """
203
+ # Filter numeric columns only
204
+ near_num = near_df.select_dtypes(include=[np.number])
205
+ far_num = far_df.select_dtypes(include=[np.number])
206
+
207
+ # We want to see: For a high quality frame, how does Near Feature X relate to Far Feature X?
208
+ # Simple approach: Calculate Pearson Correlation of (Near_Col, Far_Col) across all frames.
209
+
210
+ correlations = {}
211
+
212
+ common_cols = [c for c in near_num.columns if c in far_num.columns]
213
+
214
+ for col in common_cols:
215
+ if col == "cluster": continue
216
+ try:
217
+ # Basic Correlation: Do Near and Far move together?
218
+ corr, _ = pearsonr(near_num[col], far_num[col])
219
+ correlations[col] = corr
220
+ except:
221
+ correlations[col] = 0.0
222
+
223
+ # Also calculate correlation with Quality
224
+ quality_corr = {}
225
+ for col in common_cols:
226
+ if col == "cluster": continue
227
+ try:
228
+ # Does this feature predict high quality?
229
+ # e.g., Does high 'rms' usually mean better match score?
230
+ corr, _ = pearsonr(near_num[col], quality_scores)
231
+ quality_corr[col] = corr
232
+ except:
233
+ quality_corr[col] = 0.0
234
+
235
+ return pd.DataFrame({"Near-Far Correlation": correlations, "Correlation with Quality": quality_corr})
236
+
237
  # ----------------------------
238
  # 6. Plotting Helpers
239
  # ----------------------------
240
  def generate_cluster_plot(df, x_attr, y_attr, title_suffix):
241
  if len(df) == 0 or x_attr not in df.columns or y_attr not in df.columns:
242
  return px.scatter(title="No Data")
 
243
  fig = px.scatter(
244
  df, x=x_attr, y=y_attr, color="cluster",
245
  title=f"Clustering Analysis ({title_suffix}): {x_attr} vs {y_attr}",
246
+ color_discrete_sequence=px.colors.qualitative.Bold
247
  )
248
  return fig
249
 
250
  def update_cluster_view(view_mode, near_df, far_df, cluster_features):
251
+ if near_df is None or far_df is None: return px.scatter(title="Run Analysis First")
252
+ if len(cluster_features) < 2: return px.scatter(title="Select at least 2 features")
 
 
 
 
253
  x_attr, y_attr = cluster_features[0], cluster_features[1]
254
+ if view_mode == "Near Field": return generate_cluster_plot(near_df, x_attr, y_attr, "Near Field")
255
+ else: return generate_cluster_plot(far_df, x_attr, y_attr, "Far Field")
 
 
 
256
 
257
  # ----------------------------
258
  # 7. Main Analysis
 
267
  # Load & Align
268
  y_near, sr = librosa.load(near_file.name, sr=None)
269
  y_far, _ = librosa.load(far_file.name, sr=sr)
 
270
  y_near = librosa.util.normalize(y_near)
271
  y_far = librosa.util.normalize(y_far)
272
  y_near, y_far = align_signals(y_near, y_far)
 
274
  # Process
275
  frames_near, _ = segment_audio(y_near, sr, frame_length_ms, hop_length_ms, window_type)
276
  frames_far, _ = segment_audio(y_far, sr, frame_length_ms, hop_length_ms, window_type)
 
277
  near_feats = extract_features_with_spectrum(frames_near, sr)
278
  far_feats = extract_features_with_spectrum(frames_far, sr)
279
 
280
+ # Compare & Cluster
281
  comparison_df = compare_frames_enhanced(near_feats, far_feats, comparison_metrics)
 
 
282
  near_df_raw = pd.DataFrame(near_feats).drop(columns=["spectrum"], errors="ignore")
283
  far_df_raw = pd.DataFrame(far_feats).drop(columns=["spectrum"], errors="ignore")
 
 
284
  near_clustered, far_clustered = perform_dual_clustering(
285
  near_df_raw, far_df_raw, cluster_features, clustering_algo, n_clusters, dbscan_eps
286
  )
287
 
288
+ # 1. Comparison Plot
289
  plot_comparison = go.Figure()
 
290
  for col in ["cosine_similarity", "spectral_overlap", "combined_match_score"]:
291
  if col in comparison_df.columns:
292
  plot_comparison.add_trace(go.Scatter(x=comparison_df["frame_index"], y=comparison_df[col], name=col, yaxis="y1"))
 
293
  if "high_freq_loss_db" in comparison_df.columns:
294
  plot_comparison.add_trace(go.Scatter(x=comparison_df["frame_index"], y=comparison_df["high_freq_loss_db"],
295
  name="High Freq Loss (dB)", line=dict(color="red", width=1), yaxis="y2"))
 
296
  plot_comparison.update_layout(
297
+ title="Comparison Metrics", yaxis=dict(title="Similarity"), yaxis2=dict(title="dB Loss", overlaying="y", side="right")
 
 
 
298
  )
299
 
300
+ # 2. Cluster Plot
301
  init_cluster_plot = update_cluster_view("Near Field", near_clustered, far_clustered, cluster_features)
302
 
303
  # 3. Spectral Heatmap
 
306
  spec_heatmap = go.Figure(data=go.Heatmap(z=diff, colorscale='RdBu', zmid=0))
307
  spec_heatmap.update_layout(title=f"Spectral Diff (Frame {safe_idx})", height=350)
308
 
309
+ # 4. Overlay Plot
310
  near_clustered["match_quality"] = comparison_df["combined_match_score"]
311
  if len(cluster_features) > 0:
312
  overlay_fig = px.scatter(near_clustered, x=cluster_features[0], y="match_quality", color="cluster",
 
314
  else:
315
  overlay_fig = px.scatter(title="No features")
316
 
317
+ # 5. NEW: Feature Relation Heatmap
318
+ corr_df = compute_feature_correlations(near_clustered, far_clustered, comparison_df["combined_match_score"])
319
+ corr_fig = px.imshow(corr_df.T, text_auto=True, aspect="auto", color_continuous_scale="RdBu", zmin=-1, zmax=1,
320
+ title="Feature Correlation Analysis")
321
+
322
+ # 6. Scatter Matrix (Inter-feature)
323
+ # Pick top 3 features and Quality
324
+ top_cols = cluster_features[:3] + ["match_quality"]
325
+ scatter_matrix_fig = px.scatter_matrix(near_clustered, dimensions=top_cols, color="cluster",
326
+ title="Inter-Feature Scatter Matrix (Near Field)")
327
+
328
  return (plot_comparison, comparison_df,
329
+ init_cluster_plot, near_clustered,
330
  spec_heatmap, overlay_fig,
331
+ corr_fig, scatter_matrix_fig,
332
+ near_clustered, far_clustered)
333
 
334
  def export_results(comparison_df, near_df, far_df):
335
  temp_dir = tempfile.mkdtemp()
 
348
  "low_freq_energy", "mid_freq_energy", "high_freq_energy"] + [f"mfcc_{i}" for i in range(1, 14)]
349
 
350
  with gr.Blocks(title="Audio Field Analyzer", theme=gr.themes.Soft()) as demo:
 
351
  state_near_df = gr.State()
352
  state_far_df = gr.State()
353
 
 
361
  frame_length_ms = gr.Slider(10, 200, value=30, label="Frame Length (ms)")
362
  hop_length_ms = gr.Slider(5, 100, value=15, label="Hop Length (ms)")
363
  window_type = gr.Dropdown(["hann", "hamming"], value="hann", label="Window")
364
+ comparison_metrics = gr.CheckboxGroup(["Cosine Similarity", "High-Freq Loss Ratio"], value=["Cosine Similarity", "High-Freq Loss Ratio"], label="Metrics")
365
+ cluster_features = gr.CheckboxGroup(feature_list, value=["spectral_centroid", "spectral_flatness", "rms"], label="Clustering Features")
 
 
 
 
 
366
  clustering_algo = gr.Dropdown(["KMeans", "Agglomerative"], value="KMeans", label="Algorithm")
367
  n_clusters = gr.Slider(2, 10, value=4, step=1, label="Clusters")
368
  dbscan_eps = gr.Slider(0.1, 5.0, value=0.5, visible=False)
 
373
  with gr.Tab("📈 Comparison"):
374
  comp_plot = gr.Plot()
375
  comp_table = gr.Dataframe()
 
376
  with gr.Tab("🧩 Phoneme Clustering"):
377
+ view_toggle = gr.Radio(["Near Field", "Far Field"], value="Near Field", label="View Mode")
 
 
378
  cluster_plot = gr.Plot()
379
  cluster_table = gr.Dataframe()
 
380
  with gr.Tab("🔍 Spectral"):
381
  spec_heatmap = gr.Plot()
382
  with gr.Tab("🧭 Overlay"):
383
  overlay_plot = gr.Plot()
384
+ with gr.Tab("🔗 Feature Relations"):
385
+ gr.Markdown("### Correlation Heatmap & Scatter Matrix")
386
+ corr_plot = gr.Plot(label="Correlation Heatmap")
387
+ scatter_matrix_plot = gr.Plot(label="Scatter Matrix")
388
 
389
  with gr.Tab("📤 Export"):
390
  export_btn = gr.Button("Download CSVs")
391
  export_files = gr.Files()
392
 
 
393
  btn.click(
394
  fn=analyze_audio_pair,
395
  inputs=[near_file, far_file, frame_length_ms, hop_length_ms, window_type,
 
397
  outputs=[comp_plot, comp_table,
398
  cluster_plot, cluster_table,
399
  spec_heatmap, overlay_plot,
400
+ corr_plot, scatter_matrix_plot,
401
+ state_near_df, state_far_df]
 
 
 
 
 
 
402
  )
403
 
404
+ view_toggle.change(fn=update_cluster_view, inputs=[view_toggle, state_near_df, state_far_df, cluster_features], outputs=[cluster_plot])
405
  export_btn.click(fn=export_results, inputs=[comp_table, state_near_df, state_far_df], outputs=export_files)
406
 
407
  if __name__ == "__main__":