AdityaK007 committed on
Commit
16d47ae
·
verified ·
1 Parent(s): 7528c16

Update app_good.py

Browse files
Files changed (1) hide show
  1. app_good.py +406 -0
app_good.py CHANGED
@@ -0,0 +1,406 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import librosa
3
+ import numpy as np
4
+ import pandas as pd
5
+ from sklearn.cluster import KMeans, AgglomerativeClustering, DBSCAN
6
+ from sklearn.preprocessing import StandardScaler
7
+ from sklearn.metrics.pairwise import cosine_similarity
8
+ from scipy import signal
9
+ from scipy.signal import get_window as scipy_get_window
10
+ import plotly.express as px
11
+ import plotly.graph_objects as go
12
+ import os
13
+ import tempfile
14
+
15
+ # ----------------------------
16
+ # 1. Signal Alignment & Preprocessing
17
+ # ----------------------------
18
def align_signals(ref, target):
    """Time-align `target` against `ref` using FFT cross-correlation.

    Returns equal-length slices of both signals with the leading offset of
    whichever signal starts late trimmed away.
    """
    # Normalize both signals before correlating so amplitude differences
    # (near vs far field) do not bias the peak location.
    ref_n = librosa.util.normalize(ref)
    tgt_n = librosa.util.normalize(target)

    # Cross-correlation via fftconvolve with a time-reversed reference.
    xcorr = signal.fftconvolve(tgt_n, ref_n[::-1], mode='full')
    lag_axis = signal.correlation_lags(len(tgt_n), len(ref_n), mode='full')
    best_lag = lag_axis[np.argmax(xcorr)]

    # Positive lag: target starts late -> drop its leading samples.
    # Non-positive lag: reference starts late -> drop its leading samples.
    if best_lag > 0:
        trimmed_ref, trimmed_tgt = ref, target[best_lag:]
    else:
        trimmed_ref, trimmed_tgt = ref[abs(best_lag):], target

    common = min(len(trimmed_ref), len(trimmed_tgt))
    return trimmed_ref[:common], trimmed_tgt[:common]
36
+
37
+ # ----------------------------
38
+ # 2. Segment Audio
39
+ # ----------------------------
40
def segment_audio(y, sr, frame_length_ms, hop_length_ms, window_type="hann"):
    """Split `y` into overlapping, windowed frames.

    Args:
        y: 1-D audio signal.
        sr: sample rate in Hz.
        frame_length_ms: frame length in milliseconds.
        hop_length_ms: hop between frame starts in milliseconds.
        window_type: scipy window name; "rectangular" (the UI label) is
            mapped to scipy's "boxcar".

    Returns:
        (frames, frame_length) where `frames` is a 2-D array of shape
        (frame_length, n_frames). A signal shorter than one frame yields a
        single all-zero frame so callers always receive a 2-D array.
    """
    frame_length = int(frame_length_ms * sr / 1000)
    hop_length = int(hop_length_ms * sr / 1000)
    window = scipy_get_window(
        window_type if window_type != "rectangular" else "boxcar", frame_length)

    # BUGFIX: the original computed a zero-padded copy (`y_padded`) but never
    # used it -- the loop always framed the unpadded signal. The dead code is
    # removed; framing behavior is unchanged.
    frames = [y[start:start + frame_length] * window
              for start in range(0, len(y) - frame_length + 1, hop_length)]

    if frames:
        return np.array(frames).T, frame_length
    return np.zeros((frame_length, 1)), frame_length
56
+
57
+ # ----------------------------
58
+ # 3. Feature Extraction
59
+ # ----------------------------
60
def extract_features_with_spectrum(frames, sr):
    """Compute per-frame scalar features plus the full dB magnitude spectrum.

    Args:
        frames: 2-D array (frame_length, n_frames) from segment_audio.
        sr: sample rate in Hz.

    Returns:
        A list with one dict per frame. Scalar keys: rms, zcr,
        spectral_centroid, spectral_flatness, mfcc_1..mfcc_13, and
        low/mid/high band energies in dB. Key "spectrum" holds the dB
        magnitude STFT array for that frame.
    """
    features = []
    n_mfcc = 13
    n_fft = min(2048, frames.shape[0])
    scalar_keys = ["rms", "spectral_centroid", "zcr", "spectral_flatness",
                   "low_freq_energy", "mid_freq_energy", "high_freq_energy"]

    for i in range(frames.shape[1]):
        frame = frames[:, i]

        # Short or near-silent frames get a neutral all-zero feature set so
        # frame indices stay aligned between the near and far channels.
        if len(frame) < n_fft or np.max(np.abs(frame)) < 1e-10:
            feat = {k: 0.0 for k in scalar_keys}
            for j in range(n_mfcc):
                feat[f"mfcc_{j+1}"] = 0.0
            feat["spectrum"] = np.zeros((n_fft // 2 + 1, 1))
            features.append(feat)
            continue

        feat = {}
        feat["rms"] = float(np.mean(librosa.feature.rms(y=frame)[0]))
        feat["zcr"] = float(np.mean(librosa.feature.zero_crossing_rate(frame)[0]))

        # librosa can raise on degenerate frames; fall back to neutral values
        # so one bad frame does not abort the whole analysis.
        # BUGFIX: the original bare `except:` clauses also swallowed
        # KeyboardInterrupt/SystemExit; narrowed to `except Exception`.
        try:
            feat["spectral_centroid"] = float(
                np.mean(librosa.feature.spectral_centroid(y=frame, sr=sr)[0]))
        except Exception:
            feat["spectral_centroid"] = 0.0

        try:
            feat["spectral_flatness"] = float(
                np.mean(librosa.feature.spectral_flatness(y=frame)[0]))
        except Exception:
            feat["spectral_flatness"] = 0.0

        try:
            mfccs = librosa.feature.mfcc(y=frame, sr=sr, n_mfcc=n_mfcc, n_fft=n_fft)
            for j in range(n_mfcc):
                feat[f"mfcc_{j+1}"] = float(np.mean(mfccs[j]))
        except Exception:
            for j in range(n_mfcc):
                feat[f"mfcc_{j+1}"] = 0.0

        try:
            S = np.abs(librosa.stft(frame, n_fft=n_fft))
            S_db = librosa.amplitude_to_db(S, ref=np.max)
            freqs = librosa.fft_frequencies(sr=sr, n_fft=n_fft)
            # Band split: <=2 kHz low, 2-4 kHz mid, >4 kHz high.
            low_mask = freqs <= 2000
            mid_mask = (freqs > 2000) & (freqs <= 4000)
            high_mask = freqs > 4000
            # -80 dB stands in for "no bins in band" (possible at low sr).
            feat["low_freq_energy"] = float(np.mean(S_db[low_mask])) if np.any(low_mask) else -80.0
            feat["mid_freq_energy"] = float(np.mean(S_db[mid_mask])) if np.any(mid_mask) else -80.0
            feat["high_freq_energy"] = float(np.mean(S_db[high_mask])) if np.any(high_mask) else -80.0
            feat["spectrum"] = S_db
        except Exception:
            feat["low_freq_energy"] = feat["mid_freq_energy"] = feat["high_freq_energy"] = -80.0
            feat["spectrum"] = np.zeros((n_fft // 2 + 1, 1))

        features.append(feat)
    return features
108
+
109
+ # ----------------------------
110
+ # 4. Frame Comparison
111
+ # ----------------------------
112
def compare_frames_enhanced(near_feats, far_feats, metrics):
    """Build a per-frame comparison table between near- and far-field features.

    Args:
        near_feats / far_feats: per-frame feature dicts (with "spectrum").
        metrics: selected metric names ("Euclidean Distance",
            "Cosine Similarity", "High-Freq Loss Ratio").

    Returns:
        DataFrame with frame_index, the requested metric columns, plus
        spectral_overlap and combined_match_score (always computed).
    """
    min_len = min(len(near_feats), len(far_feats))
    if min_len == 0:
        return pd.DataFrame({"frame_index": []})

    results = {"frame_index": list(range(min_len))}

    # Numeric feature matrices; "spectrum" is an ndarray, not a scalar column.
    near_vec = (pd.DataFrame(near_feats[:min_len])
                .drop(columns=["spectrum"], errors="ignore")
                .select_dtypes(include=[np.number]).values)
    far_vec = (pd.DataFrame(far_feats[:min_len])
               .drop(columns=["spectrum"], errors="ignore")
               .select_dtypes(include=[np.number]).values)

    if "Euclidean Distance" in metrics:
        results["euclidean_dist"] = np.linalg.norm(near_vec - far_vec, axis=1).tolist()

    if "Cosine Similarity" in metrics:
        def _safe_cos(a, b):
            # Cosine is undefined for all-zero vectors; report 0 similarity.
            if np.all(a == 0) or np.all(b == 0):
                return 0.0
            return float(cosine_similarity(a.reshape(1, -1), b.reshape(1, -1))[0][0])

        results["cosine_similarity"] = [_safe_cos(near_vec[i], far_vec[i])
                                        for i in range(min_len)]

    if "High-Freq Loss Ratio" in metrics:
        results["high_freq_loss_db"] = [
            float(near_feats[i]["high_freq_energy"] - far_feats[i]["high_freq_energy"])
            for i in range(min_len)
        ]

    # Spectral overlap (cosine similarity of flattened dB spectra) is always
    # computed -- it feeds the combined score below.
    overlap = []
    for i in range(min_len):
        a = near_feats[i]["spectrum"].flatten()
        b = far_feats[i]["spectrum"].flatten()
        if np.all(a == 0) or np.all(b == 0):
            overlap.append(0.0)
        else:
            overlap.append(float(
                cosine_similarity(a.reshape(1, -1), b.reshape(1, -1))[0][0]))
    results["spectral_overlap"] = overlap

    # Combined score: half spectral overlap, plus half cosine similarity when
    # that metric was selected.
    has_cos = "cosine_similarity" in results
    results["combined_match_score"] = [
        overlap[i] * 0.5 + (results["cosine_similarity"][i] * 0.5 if has_cos else 0.0)
        for i in range(min_len)
    ]

    return pd.DataFrame(results)
157
+
158
+ # ----------------------------
159
+ # 5. Dual Clustering Logic
160
+ # ----------------------------
161
def perform_dual_clustering(near_df, far_df, cluster_features, algo, n_clusters, eps):
    """Cluster near-field frames and label far-field frames consistently.

    Fits clustering on the (clean) near field, then labels the (noisy) far
    field so that cluster IDs refer to the same physical sound in both views.
    Only KMeans can truly `predict` on new data; the other algorithms fall
    back to an independent fit on the far field.

    Args:
        near_df / far_df: per-frame feature DataFrames.
        cluster_features: feature column names to cluster on.
        algo: "KMeans", "Agglomerative" or "DBSCAN".
        n_clusters: cluster count for KMeans/Agglomerative.
        eps: DBSCAN neighbourhood radius.

    Returns:
        Copies of the inputs with a string "cluster" column (string so that
        plotly colours it categorically). The inputs are returned untouched
        if no usable clustering features are selected.
    """
    if not cluster_features:
        return near_df, far_df

    valid_features = [f for f in cluster_features if f in near_df.columns]
    if not valid_features:
        return near_df, far_df

    X_near = np.nan_to_num(near_df[valid_features].values)
    X_far = np.nan_to_num(far_df[valid_features].values)

    # Scale on near-field statistics only, then apply the same transform to
    # the far field so both live in a comparable feature space.
    scaler = StandardScaler()
    X_near_scaled = scaler.fit_transform(X_near)
    X_far_scaled = scaler.transform(X_far)

    if algo == "KMeans":
        model = KMeans(n_clusters=min(n_clusters, len(X_near)), random_state=42, n_init=10)
        near_labels = model.fit_predict(X_near_scaled)
        far_labels = model.predict(X_far_scaled)  # label far frames with the near model
    elif algo == "Agglomerative":
        # Agglomerative cannot "predict" on new data, so the far field is
        # clustered independently -- a known limitation / acceptable fallback.
        model = AgglomerativeClustering(n_clusters=min(n_clusters, len(X_near)))
        near_labels = model.fit_predict(X_near_scaled)
        far_model = AgglomerativeClustering(n_clusters=min(n_clusters, len(X_far)))
        far_labels = far_model.fit_predict(X_far_scaled)
    elif algo == "DBSCAN":
        # DBSCAN also has no predict; fit_predict each field separately.
        model = DBSCAN(eps=eps, min_samples=3)
        near_labels = model.fit_predict(X_near_scaled)
        far_labels = model.fit_predict(X_far_scaled)
    else:
        # BUGFIX: use integer zeros. The original float zeros stringified to
        # "0.0" below, which does not match the "0" labels produced by the
        # real algorithms, so colours/legends were inconsistent.
        near_labels = np.zeros(len(X_near), dtype=int)
        far_labels = np.zeros(len(X_far), dtype=int)

    near_df = near_df.copy()
    near_df["cluster"] = near_labels
    near_df["cluster"] = near_df["cluster"].astype(str)  # categorical colouring

    far_df = far_df.copy()
    far_df["cluster"] = far_labels
    far_df["cluster"] = far_df["cluster"].astype(str)

    return near_df, far_df
213
+
214
+ # ----------------------------
215
+ # 6. Plotting Helpers
216
+ # ----------------------------
217
def generate_cluster_plot(df, x_attr, y_attr, title_suffix):
    """Scatter plot of two feature columns coloured by cluster label."""
    usable = len(df) > 0 and x_attr in df.columns and y_attr in df.columns
    if not usable:
        return px.scatter(title="No Data")

    return px.scatter(
        df,
        x=x_attr,
        y=y_attr,
        color="cluster",
        title=f"Clustering Analysis ({title_suffix}): {x_attr} vs {y_attr}",
        # Fixed qualitative palette keeps cluster colours consistent across views.
        color_discrete_sequence=px.colors.qualitative.Bold,
    )
227
+
228
def update_cluster_view(view_mode, near_df, far_df, cluster_features):
    """Re-render the cluster scatter for the selected field.

    Called from the view toggle so the plot switches between Near/Far without
    re-running the full analysis. Needs at least two selected features (used
    as the X and Y axes).
    """
    if near_df is None or far_df is None:
        return px.scatter(title="Run Analysis First")
    if len(cluster_features) < 2:
        return px.scatter(title="Select at least 2 features")

    x_attr, y_attr = cluster_features[0], cluster_features[1]
    if view_mode == "Near Field":
        return generate_cluster_plot(near_df, x_attr, y_attr, "Near Field")
    return generate_cluster_plot(far_df, x_attr, y_attr, "Far Field")
241
+
242
+ # ----------------------------
243
+ # 7. Main Analysis
244
+ # ----------------------------
245
def analyze_audio_pair(
    near_file, far_file,
    frame_length_ms, hop_length_ms, window_type,
    comparison_metrics, cluster_features, clustering_algo, n_clusters, dbscan_eps
):
    """Full pipeline: load, align, frame, featurize, compare and cluster a
    near/far recording pair.

    Returns the tuple consumed by the Gradio wiring:
    (comparison plot, comparison table, initial cluster plot, cluster table,
     spectral-diff heatmap, overlay plot, near-df state, far-df state).

    Raises:
        gr.Error: if either file input is missing.
    """
    if not near_file or not far_file:
        raise gr.Error("Upload both files.")

    # Load at the near-field's native rate; resample far field to match.
    y_near, sr = librosa.load(near_file.name, sr=None)
    y_far, _ = librosa.load(far_file.name, sr=sr)

    y_near = librosa.util.normalize(y_near)
    y_far = librosa.util.normalize(y_far)
    y_near, y_far = align_signals(y_near, y_far)

    # Frame both signals with identical settings so frame indices correspond.
    frames_near, _ = segment_audio(y_near, sr, frame_length_ms, hop_length_ms, window_type)
    frames_far, _ = segment_audio(y_far, sr, frame_length_ms, hop_length_ms, window_type)

    near_feats = extract_features_with_spectrum(frames_near, sr)
    far_feats = extract_features_with_spectrum(frames_far, sr)

    comparison_df = compare_frames_enhanced(near_feats, far_feats, comparison_metrics)

    # Scalar-only tables for clustering ("spectrum" holds arrays).
    near_df_raw = pd.DataFrame(near_feats).drop(columns=["spectrum"], errors="ignore")
    far_df_raw = pd.DataFrame(far_feats).drop(columns=["spectrum"], errors="ignore")

    near_clustered, far_clustered = perform_dual_clustering(
        near_df_raw, far_df_raw, cluster_features, clustering_algo, n_clusters, dbscan_eps
    )

    # 1. Comparison plot: similarity metrics on y1 (0-1), dB loss on y2.
    plot_comparison = go.Figure()
    for col in ["cosine_similarity", "spectral_overlap", "combined_match_score"]:
        if col in comparison_df.columns:
            plot_comparison.add_trace(go.Scatter(
                x=comparison_df["frame_index"], y=comparison_df[col],
                name=col, yaxis="y1"))
    if "high_freq_loss_db" in comparison_df.columns:
        plot_comparison.add_trace(go.Scatter(
            x=comparison_df["frame_index"], y=comparison_df["high_freq_loss_db"],
            name="High Freq Loss (dB)", line=dict(color="red", width=1), yaxis="y2"))

    plot_comparison.update_layout(
        title="Comparison Metrics (Dual Axis)",
        yaxis=dict(title="Similarity (0-1)", range=[0, 1.1]),
        yaxis2=dict(title="Energy Diff (dB)", overlaying="y", side="right"),
        legend=dict(x=1.1, y=1)
    )

    # 2. Initial cluster plot (Near Field view).
    init_cluster_plot = update_cluster_view("Near Field", near_clustered, far_clustered, cluster_features)

    # 3. Spectral heatmap of a mid-file frame.
    # BUGFIX: index into the *shorter* feature list. The two signals can yield
    # different frame counts; the old int(len(near_feats)/2) could index past
    # the end of far_feats and raise IndexError.
    safe_idx = min(len(near_feats), len(far_feats)) // 2
    diff = near_feats[safe_idx]["spectrum"] - far_feats[safe_idx]["spectrum"]
    spec_heatmap = go.Figure(data=go.Heatmap(z=diff, colorscale='RdBu', zmid=0))
    spec_heatmap.update_layout(title=f"Spectral Diff (Frame {safe_idx})", height=350)

    # 4. Overlay: first clustering feature vs combined match quality.
    # NOTE(review): pandas aligns this assignment by index; if near_clustered
    # has more rows than comparison_df the tail gets NaN match_quality.
    near_clustered["match_quality"] = comparison_df["combined_match_score"]
    if len(cluster_features) > 0:
        overlay_fig = px.scatter(near_clustered, x=cluster_features[0], y="match_quality",
                                 color="cluster", title="Cluster vs Quality (Near Field)")
    else:
        overlay_fig = px.scatter(title="No features")

    # Plots + tables for display, plus the clustered frames stored as state
    # so the view toggle can redraw without re-analysing.
    return (plot_comparison, comparison_df,
            init_cluster_plot, near_clustered,
            spec_heatmap, overlay_fig,
            near_clustered, far_clustered)
319
+
320
def export_results(comparison_df, near_df, far_df):
    """Write the three result tables to CSV files in a fresh temp directory.

    Returns the file paths in order: comparison, near clusters, far clusters
    (the order the Gradio Files component displays them in).
    """
    out_dir = tempfile.mkdtemp()
    exported = []
    for table, filename in ((comparison_df, "comparison.csv"),
                            (near_df, "near_clusters.csv"),
                            (far_df, "far_clusters.csv")):
        target = os.path.join(out_dir, filename)
        table.to_csv(target, index=False)
        exported.append(target)
    return exported
329
+
330
+ # ----------------------------
331
+ # 8. Gradio UI
332
+ # ----------------------------
333
# Feature names offered for clustering: scalar spectral/energy features plus
# 13 MFCC coefficients (must match keys produced by extract_features_with_spectrum).
feature_list = ["rms", "spectral_centroid", "zcr", "spectral_flatness",
                "low_freq_energy", "mid_freq_energy", "high_freq_energy"] + [f"mfcc_{i}" for i in range(1, 14)]

with gr.Blocks(title="Audio Field Analyzer", theme=gr.themes.Soft()) as demo:
    # State storage for interactivity: holds the clustered near/far DataFrames
    # so the view toggle can redraw the cluster plot without re-analysing.
    state_near_df = gr.State()
    state_far_df = gr.State()

    gr.Markdown("# 🎙️ Near vs Far Field Analyzer (Dual-Clustering)")

    with gr.Row():
        near_file = gr.File(label="Near-Field (Ref)", file_types=[".wav"])
        far_file = gr.File(label="Far-Field (Target)", file_types=[".wav"])

    with gr.Accordion("⚙️ Settings", open=False):
        frame_length_ms = gr.Slider(10, 200, value=30, label="Frame Length (ms)")
        hop_length_ms = gr.Slider(5, 100, value=15, label="Hop Length (ms)")
        window_type = gr.Dropdown(["hann", "hamming"], value="hann", label="Window")

        comparison_metrics = gr.CheckboxGroup(["Cosine Similarity", "High-Freq Loss Ratio"],
                                              value=["Cosine Similarity", "High-Freq Loss Ratio"], label="Metrics")

        cluster_features = gr.CheckboxGroup(feature_list, value=["spectral_centroid", "spectral_flatness"],
                                            label="Clustering Features")

        # NOTE(review): perform_dual_clustering also supports "DBSCAN", but it
        # is not offered here, and the dbscan_eps slider below is permanently
        # hidden (no visibility toggle is wired) -- confirm whether DBSCAN was
        # meant to be exposed.
        clustering_algo = gr.Dropdown(["KMeans", "Agglomerative"], value="KMeans", label="Algorithm")
        n_clusters = gr.Slider(2, 10, value=4, step=1, label="Clusters")
        dbscan_eps = gr.Slider(0.1, 5.0, value=0.5, visible=False)

    btn = gr.Button("🚀 Analyze", variant="primary")

    with gr.Tabs():
        with gr.Tab("📈 Comparison"):
            comp_plot = gr.Plot()
            comp_table = gr.Dataframe()

        with gr.Tab("🧩 Phoneme Clustering"):
            with gr.Row():
                # TOGGLE SWITCH between near/far cluster views
                view_toggle = gr.Radio(["Near Field", "Far Field"], value="Near Field", label="View Mode")
            cluster_plot = gr.Plot()
            cluster_table = gr.Dataframe()

        with gr.Tab("🔍 Spectral"):
            spec_heatmap = gr.Plot()
        with gr.Tab("🧭 Overlay"):
            overlay_plot = gr.Plot()

        with gr.Tab("📤 Export"):
            export_btn = gr.Button("Download CSVs")
            export_files = gr.Files()

    # Main Analysis Event: outputs must stay in lockstep with the tuple
    # returned by analyze_audio_pair.
    btn.click(
        fn=analyze_audio_pair,
        inputs=[near_file, far_file, frame_length_ms, hop_length_ms, window_type,
                comparison_metrics, cluster_features, clustering_algo, n_clusters, dbscan_eps],
        outputs=[comp_plot, comp_table,
                 cluster_plot, cluster_table,
                 spec_heatmap, overlay_plot,
                 state_near_df, state_far_df]  # Save to State
    )

    # Toggle Event (updates the cluster plot without re-running analysis)
    view_toggle.change(
        fn=update_cluster_view,
        inputs=[view_toggle, state_near_df, state_far_df, cluster_features],
        outputs=[cluster_plot]
    )

    export_btn.click(fn=export_results, inputs=[comp_table, state_near_df, state_far_df], outputs=export_files)

if __name__ == "__main__":
    demo.launch()