AdityaK007 commited on
Commit
d1022e8
·
verified ·
1 Parent(s): 12eadee

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +192 -71
app.py CHANGED
@@ -12,6 +12,35 @@ import plotly.graph_objects as go
12
  import os
13
  import tempfile
14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  # ----------------------------
16
  # Enhanced Feature Extraction (with spectral bins)
17
  # ----------------------------
@@ -19,48 +48,88 @@ import tempfile
19
  def extract_features_with_spectrum(frames, sr):
20
  features = []
21
  n_mfcc = 13
22
- n_fft = 2048
23
-
24
  for i in range(frames.shape[1]):
25
  frame = frames[:, i]
 
 
 
 
 
26
  feat = {}
27
 
28
- # Basic features
29
- rms = np.mean(librosa.feature.rms(y=frame)[0])
30
- feat["rms"] = float(rms)
31
-
32
- sc = np.mean(librosa.feature.spectral_centroid(y=frame, sr=sr)[0])
33
- feat["spectral_centroid"] = float(sc)
34
-
35
- zcr = np.mean(librosa.feature.zero_crossing_rate(frame)[0])
36
- feat["zcr"] = float(zcr)
37
-
38
- mfccs = librosa.feature.mfcc(y=frame, sr=sr, n_mfcc=n_mfcc)
39
- for j in range(n_mfcc):
40
- feat[f"mfcc_{j+1}"] = float(np.mean(mfccs[j]))
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
  # Spectral bins for lost frequencies
43
- S = np.abs(librosa.stft(frame, n_fft=n_fft))
44
- S_db = librosa.amplitude_to_db(S, ref=np.max)
45
- freqs = librosa.fft_frequencies(sr=sr, n_fft=n_fft)
46
-
47
- # Split spectrum: low (<2kHz), mid (2-4kHz), high (>4kHz)
48
- low_mask = freqs <= 2000
49
- mid_mask = (freqs > 2000) & (freqs <= 4000)
50
- high_mask = freqs > 4000
51
-
52
- feat["low_freq_energy"] = float(np.mean(S_db[low_mask]))
53
- feat["mid_freq_energy"] = float(np.mean(S_db[mid_mask]))
54
- feat["high_freq_energy"] = float(np.mean(S_db[high_mask]))
55
-
56
- # Store full spectrum for later (optional)
57
- feat["spectrum"] = S_db # will be used for heatmap
 
 
 
 
 
 
58
 
59
  features.append(feat)
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  return features
61
 
62
  def compare_frames_enhanced(near_feats, far_feats, metrics):
63
  min_len = min(len(near_feats), len(far_feats))
 
 
 
64
  results = {"frame_index": list(range(min_len))}
65
 
66
  # Prepare vectors
@@ -80,58 +149,79 @@ def compare_frames_enhanced(near_feats, far_feats, metrics):
80
  cos_vals = []
81
  for i in range(min_len):
82
  a, b = near_vec[i].reshape(1, -1), far_vec[i].reshape(1, -1)
83
- cos_vals.append(float(cosine_similarity(a, b)[0][0]))
 
 
 
 
 
 
84
  results["cosine_similarity"] = cos_vals
85
 
86
  # 3. Pearson Correlation
87
  if "Pearson Correlation" in metrics:
88
  corr_vals = []
89
  for i in range(min_len):
90
- corr, _ = pearsonr(near_vec[i], far_vec[i])
91
- corr_vals.append(float(corr) if not np.isnan(corr) else 0.0)
 
 
 
92
  results["pearson_corr"] = corr_vals
93
 
94
  # 4. KL Divergence (on normalized features)
95
  if "KL Divergence" in metrics:
96
  kl_vals = []
97
  for i in range(min_len):
98
- p = near_vec[i] - near_vec[i].min() + 1e-8
99
- q = far_vec[i] - far_vec[i].min() + 1e-8
100
- p /= p.sum()
101
- q /= q.sum()
102
- kl = np.sum(p * np.log(p / q))
103
- kl_vals.append(float(kl))
 
 
 
104
  results["kl_divergence"] = kl_vals
105
 
106
  # 5. Jensen-Shannon Divergence (symmetric, safer)
107
  if "Jensen-Shannon Divergence" in metrics:
108
  js_vals = []
109
  for i in range(min_len):
110
- p = near_vec[i] - near_vec[i].min() + 1e-8
111
- q = far_vec[i] - far_vec[i].min() + 1e-8
112
- p /= p.sum()
113
- q /= q.sum()
114
- js = jensenshannon(p, q)
115
- js_vals.append(float(js))
 
 
 
116
  results["js_divergence"] = js_vals
117
 
118
  # 6. Lost High Frequencies Ratio
119
  if "High-Freq Loss Ratio" in metrics:
120
  loss_ratios = []
121
  for i in range(min_len):
122
- near_high = near_feats[i]["high_freq_energy"]
123
- far_high = far_feats[i]["high_freq_energy"]
124
- # Ratio: how much high-freq energy is lost (positive = loss)
125
- ratio = near_high - far_high # in dB
126
- loss_ratios.append(float(ratio))
 
 
 
127
  results["high_freq_loss_db"] = loss_ratios
128
 
129
  # 7. Spectral Centroid Shift
130
  if "Spectral Centroid Shift" in metrics:
131
  shifts = []
132
  for i in range(min_len):
133
- shift = near_feats[i]["spectral_centroid"] - far_feats[i]["spectral_centroid"]
134
- shifts.append(float(shift))
 
 
 
135
  results["centroid_shift"] = shifts
136
 
137
  return pd.DataFrame(results)
@@ -140,39 +230,58 @@ def cluster_frames_custom(features_df, cluster_features, algo, n_clusters=5, eps
140
  if not cluster_features:
141
  raise gr.Error("Please select at least one feature for clustering.")
142
 
 
 
 
 
143
  X = features_df[cluster_features].values
144
 
145
  if algo == "KMeans":
 
146
  model = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)
 
147
  elif algo == "Agglomerative":
 
148
  model = AgglomerativeClustering(n_clusters=n_clusters)
 
149
  elif algo == "DBSCAN":
150
- model = DBSCAN(eps=eps, min_samples=3)
 
 
151
  else:
152
  raise ValueError("Unknown clustering algorithm")
153
 
154
- labels = model.fit_predict(X)
155
  features_df = features_df.copy()
156
  features_df["cluster"] = labels
157
  return features_df
158
 
159
  def plot_spectral_difference(near_feats, far_feats, frame_idx=0):
160
- if frame_idx >= len(near_feats):
161
- frame_idx = 0
 
 
 
 
162
  near_spec = near_feats[frame_idx]["spectrum"]
163
  far_spec = far_feats[frame_idx]["spectrum"]
 
 
 
 
 
 
164
  diff = near_spec - far_spec # positive = energy lost in far-field
165
 
166
  fig = go.Figure(data=go.Heatmap(
167
- z=[diff],
168
  colorscale='RdBu',
169
  zmid=0,
170
  colorbar=dict(title="dB Difference")
171
  ))
172
  fig.update_layout(
173
  title=f"Spectral Difference (Frame {frame_idx}): Near - Far",
174
- xaxis_title="Frequency Bins",
175
- yaxis_title="",
176
  height=300
177
  )
178
  return fig
@@ -196,8 +305,12 @@ def analyze_audio_pair(
196
  if not near_file or not far_file:
197
  raise gr.Error("Upload both audio files.")
198
 
199
- y_near, sr_near = librosa.load_audio(near_file)
200
- y_far, sr_far = librosa.load_audio(far_file)
 
 
 
 
201
 
202
  if sr_near != sr_far:
203
  y_far = librosa.resample(y_far, orig_sr=sr_far, target_sr=sr_near)
@@ -205,7 +318,7 @@ def analyze_audio_pair(
205
  else:
206
  sr = sr_near
207
 
208
- frames_near, _ = segment_audio(y_near, sr, frame_length_ms, hop_length_ms, window_type)
209
  frames_far, _ = segment_audio(y_far, sr, frame_length_ms, hop_length_ms, window_type)
210
 
211
  near_feats = extract_features_with_spectrum(frames_near, sr)
@@ -221,18 +334,24 @@ def analyze_audio_pair(
221
 
222
  # Plots
223
  plot_comparison = None
224
- if comparison_df.shape[1] > 1:
225
- metric_to_plot = [col for col in comparison_df.columns if col != "frame_index"][0]
226
- plot_comparison = px.line(
227
- comparison_df,
228
- x="frame_index",
229
- y=metric_to_plot,
230
- title=f"{metric_to_plot.replace('_', ' ').title()} Over Time"
231
- )
 
 
 
 
 
 
232
 
233
  # Scatter: user-selected features
234
  plot_scatter = None
235
- if len(cluster_features) >= 2:
236
  x_feat, y_feat = cluster_features[0], cluster_features[1]
237
  if x_feat in clustered_df.columns and y_feat in clustered_df.columns:
238
  plot_scatter = px.scatter(
@@ -243,6 +362,8 @@ def analyze_audio_pair(
243
  title=f"Clustering: {x_feat} vs {y_feat}",
244
  hover_data=["cluster"]
245
  )
 
 
246
  else:
247
  plot_scatter = px.scatter(title="Select ≥2 features for scatter plot")
248
 
 
12
  import os
13
  import tempfile
14
 
15
+ # ----------------------------
16
+ # Fixed: Added missing segment_audio function
17
+ # ----------------------------
18
+
19
+ def segment_audio(y, sr, frame_length_ms, hop_length_ms, window_type="hann"):
20
+ """Segment audio into frames with specified windowing"""
21
+ frame_length = int(frame_length_ms * sr / 1000)
22
+ hop_length = int(hop_length_ms * sr / 1000)
23
+
24
+ # Get window function
25
+ if window_type == "rectangular":
26
+ window = scipy_get_window('boxcar', frame_length)
27
+ else:
28
+ window = scipy_get_window(window_type, frame_length)
29
+
30
+ frames = []
31
+ for i in range(0, len(y) - frame_length + 1, hop_length):
32
+ frame = y[i:i + frame_length] * window
33
+ frames.append(frame)
34
+
35
+ # Convert to 2D array of shape (samples x frames) — the transpose makes columns the frames
36
+ if frames:
37
+ frames = np.array(frames).T
38
+ else:
39
+ # If audio is shorter than one full frame, return a single all-zero frame
40
+ frames = np.zeros((frame_length, 1))
41
+
42
+ return frames, frame_length
43
+
44
  # ----------------------------
45
  # Enhanced Feature Extraction (with spectral bins)
46
  # ----------------------------
 
48
  def extract_features_with_spectrum(frames, sr):
49
  features = []
50
  n_mfcc = 13
51
+ n_fft = min(2048, frames.shape[0]) # Fixed: Ensure n_fft <= frame length
52
+
53
  for i in range(frames.shape[1]):
54
  frame = frames[:, i]
55
+
56
+ # Skip if frame is too short or silent
57
+ if len(frame) < n_fft or np.max(np.abs(frame)) < 1e-10:
58
+ continue
59
+
60
  feat = {}
61
 
62
+ # Basic features with error handling
63
+ try:
64
+ rms = np.mean(librosa.feature.rms(y=frame)[0])
65
+ feat["rms"] = float(rms)
66
+ except:
67
+ feat["rms"] = 0.0
68
+
69
+ try:
70
+ sc = np.mean(librosa.feature.spectral_centroid(y=frame, sr=sr)[0])
71
+ feat["spectral_centroid"] = float(sc)
72
+ except:
73
+ feat["spectral_centroid"] = 0.0
74
+
75
+ try:
76
+ zcr = np.mean(librosa.feature.zero_crossing_rate(frame)[0])
77
+ feat["zcr"] = float(zcr)
78
+ except:
79
+ feat["zcr"] = 0.0
80
+
81
+ try:
82
+ mfccs = librosa.feature.mfcc(y=frame, sr=sr, n_mfcc=n_mfcc, n_fft=n_fft)
83
+ for j in range(n_mfcc):
84
+ feat[f"mfcc_{j+1}"] = float(np.mean(mfccs[j]))
85
+ except:
86
+ for j in range(n_mfcc):
87
+ feat[f"mfcc_{j+1}"] = 0.0
88
 
89
  # Spectral bins for lost frequencies
90
+ try:
91
+ S = np.abs(librosa.stft(frame, n_fft=n_fft))
92
+ S_db = librosa.amplitude_to_db(S, ref=np.max)
93
+ freqs = librosa.fft_frequencies(sr=sr, n_fft=n_fft)
94
+
95
+ # Split spectrum: low (<2kHz), mid (2-4kHz), high (>4kHz)
96
+ low_mask = freqs <= 2000
97
+ mid_mask = (freqs > 2000) & (freqs <= 4000)
98
+ high_mask = freqs > 4000
99
+
100
+ feat["low_freq_energy"] = float(np.mean(S_db[low_mask])) if np.any(low_mask) else 0.0
101
+ feat["mid_freq_energy"] = float(np.mean(S_db[mid_mask])) if np.any(mid_mask) else 0.0
102
+ feat["high_freq_energy"] = float(np.mean(S_db[high_mask])) if np.any(high_mask) else 0.0
103
+
104
+ # Store full spectrum for later (optional)
105
+ feat["spectrum"] = S_db # will be used for heatmap
106
+ except:
107
+ feat["low_freq_energy"] = 0.0
108
+ feat["mid_freq_energy"] = 0.0
109
+ feat["high_freq_energy"] = 0.0
110
+ feat["spectrum"] = np.zeros((n_fft // 2 + 1, 1))
111
 
112
  features.append(feat)
113
+
114
+ # Handle case where no features were extracted
115
+ if not features:
116
+ # Create one dummy feature set to avoid errors
117
+ feat = {
118
+ "rms": 0.0, "spectral_centroid": 0.0, "zcr": 0.0,
119
+ "low_freq_energy": 0.0, "mid_freq_energy": 0.0, "high_freq_energy": 0.0,
120
+ "spectrum": np.zeros((n_fft // 2 + 1, 1))
121
+ }
122
+ for j in range(n_mfcc):
123
+ feat[f"mfcc_{j+1}"] = 0.0
124
+ features.append(feat)
125
+
126
  return features
127
 
128
  def compare_frames_enhanced(near_feats, far_feats, metrics):
129
  min_len = min(len(near_feats), len(far_feats))
130
+ if min_len == 0:
131
+ return pd.DataFrame({"frame_index": []})
132
+
133
  results = {"frame_index": list(range(min_len))}
134
 
135
  # Prepare vectors
 
149
  cos_vals = []
150
  for i in range(min_len):
151
  a, b = near_vec[i].reshape(1, -1), far_vec[i].reshape(1, -1)
152
+ # Handle zero vectors
153
+ if np.all(a == 0) and np.all(b == 0):
154
+ cos_vals.append(1.0)
155
+ elif np.all(a == 0) or np.all(b == 0):
156
+ cos_vals.append(0.0)
157
+ else:
158
+ cos_vals.append(float(cosine_similarity(a, b)[0][0]))
159
  results["cosine_similarity"] = cos_vals
160
 
161
  # 3. Pearson Correlation
162
  if "Pearson Correlation" in metrics:
163
  corr_vals = []
164
  for i in range(min_len):
165
+ try:
166
+ corr, _ = pearsonr(near_vec[i], far_vec[i])
167
+ corr_vals.append(float(corr) if not np.isnan(corr) else 0.0)
168
+ except:
169
+ corr_vals.append(0.0)
170
  results["pearson_corr"] = corr_vals
171
 
172
  # 4. KL Divergence (on normalized features)
173
  if "KL Divergence" in metrics:
174
  kl_vals = []
175
  for i in range(min_len):
176
+ try:
177
+ p = near_vec[i] - near_vec[i].min() + 1e-8
178
+ q = far_vec[i] - far_vec[i].min() + 1e-8
179
+ p /= p.sum()
180
+ q /= q.sum()
181
+ kl = np.sum(p * np.log(p / q))
182
+ kl_vals.append(float(kl))
183
+ except:
184
+ kl_vals.append(0.0)
185
  results["kl_divergence"] = kl_vals
186
 
187
  # 5. Jensen-Shannon Divergence (symmetric, safer)
188
  if "Jensen-Shannon Divergence" in metrics:
189
  js_vals = []
190
  for i in range(min_len):
191
+ try:
192
+ p = near_vec[i] - near_vec[i].min() + 1e-8
193
+ q = far_vec[i] - far_vec[i].min() + 1e-8
194
+ p /= p.sum()
195
+ q /= q.sum()
196
+ js = jensenshannon(p, q)
197
+ js_vals.append(float(js))
198
+ except:
199
+ js_vals.append(0.0)
200
  results["js_divergence"] = js_vals
201
 
202
  # 6. Lost High Frequencies Ratio
203
  if "High-Freq Loss Ratio" in metrics:
204
  loss_ratios = []
205
  for i in range(min_len):
206
+ try:
207
+ near_high = near_feats[i]["high_freq_energy"]
208
+ far_high = far_feats[i]["high_freq_energy"]
209
+ # Ratio: how much high-freq energy is lost (positive = loss)
210
+ ratio = near_high - far_high # in dB
211
+ loss_ratios.append(float(ratio))
212
+ except:
213
+ loss_ratios.append(0.0)
214
  results["high_freq_loss_db"] = loss_ratios
215
 
216
  # 7. Spectral Centroid Shift
217
  if "Spectral Centroid Shift" in metrics:
218
  shifts = []
219
  for i in range(min_len):
220
+ try:
221
+ shift = near_feats[i]["spectral_centroid"] - far_feats[i]["spectral_centroid"]
222
+ shifts.append(float(shift))
223
+ except:
224
+ shifts.append(0.0)
225
  results["centroid_shift"] = shifts
226
 
227
  return pd.DataFrame(results)
 
230
  if not cluster_features:
231
  raise gr.Error("Please select at least one feature for clustering.")
232
 
233
+ if len(features_df) == 0:
234
+ features_df["cluster"] = []
235
+ return features_df
236
+
237
  X = features_df[cluster_features].values
238
 
239
  if algo == "KMeans":
240
+ n_clusters = min(n_clusters, len(X)) # Fixed: Cannot have more clusters than samples
241
  model = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)
242
+ labels = model.fit_predict(X)
243
  elif algo == "Agglomerative":
244
+ n_clusters = min(n_clusters, len(X))
245
  model = AgglomerativeClustering(n_clusters=n_clusters)
246
+ labels = model.fit_predict(X)
247
  elif algo == "DBSCAN":
248
+ # Fixed: DBSCAN doesn't use n_clusters parameter
249
+ model = DBSCAN(eps=eps, min_samples=min(3, len(X)))
250
+ labels = model.fit_predict(X)
251
  else:
252
  raise ValueError("Unknown clustering algorithm")
253
 
 
254
  features_df = features_df.copy()
255
  features_df["cluster"] = labels
256
  return features_df
257
 
258
  def plot_spectral_difference(near_feats, far_feats, frame_idx=0):
259
+ if not near_feats or not far_feats or frame_idx >= len(near_feats) or frame_idx >= len(far_feats):
260
+ # Return empty plot
261
+ fig = go.Figure()
262
+ fig.update_layout(title="No data available for spectral analysis", height=300)
263
+ return fig
264
+
265
  near_spec = near_feats[frame_idx]["spectrum"]
266
  far_spec = far_feats[frame_idx]["spectrum"]
267
+
268
+ # Ensure both spectrograms have the same shape
269
+ min_freq_bins = min(near_spec.shape[0], far_spec.shape[0])
270
+ near_spec = near_spec[:min_freq_bins]
271
+ far_spec = far_spec[:min_freq_bins]
272
+
273
  diff = near_spec - far_spec # positive = energy lost in far-field
274
 
275
  fig = go.Figure(data=go.Heatmap(
276
+ z=diff, # Fixed: Removed extra list brackets
277
  colorscale='RdBu',
278
  zmid=0,
279
  colorbar=dict(title="dB Difference")
280
  ))
281
  fig.update_layout(
282
  title=f"Spectral Difference (Frame {frame_idx}): Near - Far",
283
+ xaxis_title="Time Frames",
284
+ yaxis_title="Frequency Bins",
285
  height=300
286
  )
287
  return fig
 
305
  if not near_file or not far_file:
306
  raise gr.Error("Upload both audio files.")
307
 
308
+ try:
309
+ # Fixed: Use librosa.load instead of non-existent librosa.load_audio
310
+ y_near, sr_near = librosa.load(near_file.name, sr=None)
311
+ y_far, sr_far = librosa.load(far_file.name, sr=None)
312
+ except Exception as e:
313
+ raise gr.Error(f"Error loading audio files: {str(e)}")
314
 
315
  if sr_near != sr_far:
316
  y_far = librosa.resample(y_far, orig_sr=sr_far, target_sr=sr_near)
 
318
  else:
319
  sr = sr_near
320
 
321
+ frames_near, frame_length = segment_audio(y_near, sr, frame_length_ms, hop_length_ms, window_type)
322
  frames_far, _ = segment_audio(y_far, sr, frame_length_ms, hop_length_ms, window_type)
323
 
324
  near_feats = extract_features_with_spectrum(frames_near, sr)
 
334
 
335
  # Plots
336
  plot_comparison = None
337
+ if comparison_df.shape[1] > 1 and len(comparison_df) > 0:
338
+ metric_cols = [col for col in comparison_df.columns if col != "frame_index"]
339
+ if metric_cols:
340
+ metric_to_plot = metric_cols[0]
341
+ plot_comparison = px.line(
342
+ comparison_df,
343
+ x="frame_index",
344
+ y=metric_to_plot,
345
+ title=f"{metric_to_plot.replace('_', ' ').title()} Over Time"
346
+ )
347
+ else:
348
+ plot_comparison = px.line(title="No comparison metrics available")
349
+ else:
350
+ plot_comparison = px.line(title="No comparison data available")
351
 
352
  # Scatter: user-selected features
353
  plot_scatter = None
354
+ if len(cluster_features) >= 2 and len(clustered_df) > 0:
355
  x_feat, y_feat = cluster_features[0], cluster_features[1]
356
  if x_feat in clustered_df.columns and y_feat in clustered_df.columns:
357
  plot_scatter = px.scatter(
 
362
  title=f"Clustering: {x_feat} vs {y_feat}",
363
  hover_data=["cluster"]
364
  )
365
+ else:
366
+ plot_scatter = px.scatter(title="Selected features not available in data")
367
  else:
368
  plot_scatter = px.scatter(title="Select ≥2 features for scatter plot")
369