Spaces:
Sleeping
Sleeping
import os
import tempfile

import gradio as gr
import librosa
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from scipy import signal
from scipy.signal import get_window as scipy_get_window
from scipy.stats import pearsonr
from sklearn.cluster import KMeans, AgglomerativeClustering, DBSCAN
from sklearn.metrics import pairwise_distances_argmin_min
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.preprocessing import StandardScaler
# ----------------------------
# 1. Signal Alignment & Preprocessing
# ----------------------------
def align_signals(ref, target):
    """Trim two signals so they line up in time and have equal length.

    The lag is taken from the peak of the cross-correlation of the two
    signals after ``librosa.util.normalize``. A positive lag means ``target``
    starts late, so its head is dropped; otherwise the head of ``ref`` is
    dropped. Both results are then cut to the shorter of the two lengths.

    Returns:
        (aligned_ref, aligned_target)
    """
    ref_n = librosa.util.normalize(ref)
    tgt_n = librosa.util.normalize(target)
    # Convolving with the time-reversed reference is cross-correlation (via FFT).
    xcorr = signal.fftconvolve(tgt_n, ref_n[::-1], mode='full')
    lag_axis = signal.correlation_lags(len(tgt_n), len(ref_n), mode='full')
    shift = lag_axis[np.argmax(xcorr)]

    if shift > 0:
        ref_part, tgt_part = ref, target[shift:]
    else:
        ref_part, tgt_part = ref[-shift:], target

    n = min(len(ref_part), len(tgt_part))
    return ref_part[:n], tgt_part[:n]
# ----------------------------
# 2. Segment Audio
# ----------------------------
def segment_audio(y, sr, frame_length_ms, hop_length_ms, window_type="hann"):
    """Split a 1-D signal into overlapping, windowed frames.

    Args:
        y: 1-D audio signal.
        sr: sample rate in Hz.
        frame_length_ms: frame size in milliseconds.
        hop_length_ms: hop between frame starts in milliseconds.
            Both sizes are converted to samples by truncation and clamped to
            at least 1 sample.
        window_type: any window name accepted by ``scipy.signal.get_window``,
            or ``"rectangular"`` (mapped to scipy's ``"boxcar"``).

    Returns:
        (frames, frame_length): ``frames`` is a ``(frame_length, n_frames)``
        array with one windowed frame per column. Only complete frames are
        kept; a trailing partial frame is dropped. If ``y`` is shorter than
        one frame, a single all-zero column is returned.
    """
    frame_length = max(1, int(frame_length_ms * sr / 1000))
    # A hop that truncates to 0 samples would make range() raise.
    hop_length = max(1, int(hop_length_ms * sr / 1000))
    window = scipy_get_window("boxcar" if window_type == "rectangular" else window_type, frame_length)

    starts = range(0, len(y) - frame_length + 1, hop_length)
    if len(starts) == 0:
        return np.zeros((frame_length, 1)), frame_length
    frames = np.array([y[s:s + frame_length] * window for s in starts]).T
    return frames, frame_length
# ----------------------------
# 3. Feature Extraction
# ----------------------------
def _silent_frame_features(n_mfcc, n_fft):
    """Return the placeholder feature dict used for a frame with no usable signal."""
    # Keys are in the same order as _frame_features produces, so DataFrames built
    # from a mix of silent and non-silent frames get a consistent column order.
    feat = {"rms": 0.0, "zcr": 0.0, "spectral_centroid": 0.0, "spectral_flatness": 0.0}
    for j in range(n_mfcc):
        feat[f"mfcc_{j + 1}"] = 0.0
    # Band energies are in dB relative to the frame peak; 0 dB would mean "as loud
    # as the loudest bin", so silence gets the same -80 dB floor as the STFT
    # failure fallback.
    feat["low_freq_energy"] = feat["mid_freq_energy"] = feat["high_freq_energy"] = -80.0
    # All-zero spectrum: compare_frames_enhanced scores such spectra as 0 overlap.
    feat["spectrum"] = np.zeros((n_fft // 2 + 1, 1))
    return feat


def _frame_features(frame, sr, n_fft, n_mfcc, band_masks):
    """Compute scalar descriptors, MFCC means, band energies and the dB spectrum of one frame."""
    # Analyse exactly this frame: n_fft-sample windows with no centre padding.
    # librosa's defaults (2048-sample windows, center=True) would mostly measure
    # zero padding when the frame is short.
    hop = max(1, n_fft // 4)
    feat = {}
    feat["rms"] = float(np.mean(
        librosa.feature.rms(y=frame, frame_length=n_fft, hop_length=hop, center=False)[0]))
    feat["zcr"] = float(np.mean(
        librosa.feature.zero_crossing_rate(frame, frame_length=n_fft, hop_length=hop, center=False)[0]))
    try:
        feat["spectral_centroid"] = float(np.mean(librosa.feature.spectral_centroid(
            y=frame, sr=sr, n_fft=n_fft, hop_length=hop, center=False)[0]))
    except Exception:
        feat["spectral_centroid"] = 0.0
    try:
        feat["spectral_flatness"] = float(np.mean(librosa.feature.spectral_flatness(
            y=frame, n_fft=n_fft, hop_length=hop, center=False)[0]))
    except Exception:
        feat["spectral_flatness"] = 0.0
    try:
        mfccs = librosa.feature.mfcc(y=frame, sr=sr, n_mfcc=n_mfcc, n_fft=n_fft,
                                     hop_length=hop, center=False)
        mfcc_means = [float(np.mean(mfccs[j])) for j in range(n_mfcc)]
    except Exception:
        mfcc_means = [0.0] * n_mfcc
    for j, value in enumerate(mfcc_means):
        feat[f"mfcc_{j + 1}"] = value

    # Band energies and the stored spectrum come from a (centred) STFT in dB
    # relative to the frame's loudest bin, so values are <= 0 dB.
    try:
        S_db = librosa.amplitude_to_db(np.abs(librosa.stft(frame, n_fft=n_fft)), ref=np.max)
    except Exception:
        feat["low_freq_energy"] = feat["mid_freq_energy"] = feat["high_freq_energy"] = -80.0
        feat["spectrum"] = np.zeros((n_fft // 2 + 1, 1))
        return feat
    for name, mask in zip(("low_freq_energy", "mid_freq_energy", "high_freq_energy"), band_masks):
        feat[name] = float(np.mean(S_db[mask])) if np.any(mask) else -80.0
    feat["spectrum"] = S_db
    return feat


def extract_features_with_spectrum(frames, sr):
    """Extract per-frame audio features plus a dB spectrum.

    Args:
        frames: 2-D array with one frame per column (as built by ``segment_audio``).
        sr: sample rate in Hz.

    Returns:
        A list with one dict per column. Each dict has these entries:

        * ``rms``, ``zcr``, ``spectral_centroid`` and ``spectral_flatness``;
        * ``mfcc_1`` .. ``mfcc_13`` (means over the frame);
        * ``low_freq_energy`` (<= 2 kHz), ``mid_freq_energy`` (2-4 kHz) and
          ``high_freq_energy`` (> 4 kHz), as mean dB relative to the frame peak;
        * ``spectrum``, the dB STFT magnitude.

        Near-silent frames and failed computations get fixed fallback values
        (0.0, or -80 dB for band energies).
    """
    n_mfcc = 13
    # Analysis size covers the whole frame, capped at 2048 samples (so it never
    # exceeds the frame length).
    n_fft = min(2048, frames.shape[0])
    freqs = librosa.fft_frequencies(sr=sr, n_fft=n_fft)
    band_masks = (freqs <= 2000, (freqs > 2000) & (freqs <= 4000), freqs > 4000)

    features = []
    for i in range(frames.shape[1]):
        frame = frames[:, i]
        # Frames whose peak is below 1e-10 are treated as silence.
        if np.max(np.abs(frame)) < 1e-10:
            features.append(_silent_frame_features(n_mfcc, n_fft))
        else:
            features.append(_frame_features(frame, sr, n_fft, n_mfcc, band_masks))
    return features
# ----------------------------
# 4. Frame Comparison
# ----------------------------
def _cosine(a, b):
    """Return the cosine similarity of two arrays (flattened); 0.0 if either is all zeros."""
    a = np.ravel(a).astype(float)
    b = np.ravel(b).astype(float)
    norm_a = np.linalg.norm(a)
    norm_b = np.linalg.norm(b)
    if norm_a == 0 or norm_b == 0:
        return 0.0
    return float(np.dot(a / norm_a, b / norm_b))


def compare_frames_enhanced(near_feats, far_feats, metrics):
    """Compare near- and far-field feature dicts frame by frame.

    Args:
        near_feats, far_feats: lists of per-frame feature dicts (numeric
            features plus a ``"spectrum"`` array, as produced by
            ``extract_features_with_spectrum``). Only the first
            ``min(len(near_feats), len(far_feats))`` frames are compared.
        metrics: names of optional metrics to compute. The recognised values
            are ``"Euclidean Distance"``, ``"Cosine Similarity"`` and
            ``"High-Freq Loss Ratio"``.

    Returns:
        A DataFrame containing these columns:

        * ``frame_index``;
        * one column per requested metric (``euclidean_dist``,
          ``cosine_similarity``, ``high_freq_loss_db``);
        * ``spectral_overlap``, the cosine similarity of the dB spectra;
        * ``combined_match_score``, the mean of ``spectral_overlap`` and
          ``cosine_similarity`` when the latter is computed.

        The same columns are present (empty) when there are no frames.
    """
    n_frames = min(len(near_feats), len(far_feats))
    near_pairs = near_feats[:n_frames]
    far_pairs = far_feats[:n_frames]

    near_num = pd.DataFrame(near_pairs).drop(columns=["spectrum"], errors="ignore").select_dtypes(include=[np.number])
    far_num = pd.DataFrame(far_pairs).drop(columns=["spectrum"], errors="ignore").select_dtypes(include=[np.number])
    # Match features by name, not position: a DataFrame's column order follows
    # the key order of its first dict, which can differ between the two lists.
    shared_cols = [c for c in near_num.columns if c in far_num.columns]
    near_vec = near_num[shared_cols].to_numpy(dtype=float)
    far_vec = far_num[shared_cols].to_numpy(dtype=float)

    results = {"frame_index": list(range(n_frames))}
    if "Euclidean Distance" in metrics:
        results["euclidean_dist"] = np.linalg.norm(near_vec - far_vec, axis=1).tolist()
    if "Cosine Similarity" in metrics:
        results["cosine_similarity"] = [_cosine(a, b) for a, b in zip(near_vec, far_vec)]
    if "High-Freq Loss Ratio" in metrics:
        # Positive values mean the far-field frame has less high-band energy (dB).
        results["high_freq_loss_db"] = [
            float(near["high_freq_energy"] - far["high_freq_energy"])
            for near, far in zip(near_pairs, far_pairs)
        ]

    results["spectral_overlap"] = [
        _cosine(near["spectrum"], far["spectrum"]) for near, far in zip(near_pairs, far_pairs)
    ]
    # Average the available similarity components so the score stays on the
    # same scale whichever metrics are selected.
    components = [results["spectral_overlap"]]
    if "cosine_similarity" in results:
        components.append(results["cosine_similarity"])
    results["combined_match_score"] = [float(np.mean(vals)) for vals in zip(*components)]
    return pd.DataFrame(results)
# ----------------------------
# 5. Dual Clustering & Feature Relations
# ----------------------------
def _nearest_centroid_labels(X, X_ref, ref_labels):
    """Label each row of X with the reference cluster whose centroid (mean of its X_ref rows) is nearest."""
    cluster_ids = np.unique(ref_labels)
    centroids = np.stack([X_ref[ref_labels == cid].mean(axis=0) for cid in cluster_ids])
    sq_dist = ((X[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
    return cluster_ids[np.argmin(sq_dist, axis=1)]


def _dbscan_core_labels(model, X, eps):
    """Give each row of X the label of its nearest core sample in a fitted DBSCAN model if within eps, else -1."""
    if len(model.core_sample_indices_) == 0:
        return np.full(len(X), -1)
    nearest, dist = pairwise_distances_argmin_min(X, model.components_)
    core_labels = model.labels_[model.core_sample_indices_]
    return np.where(dist <= eps, core_labels[nearest], -1)


def perform_dual_clustering(near_df, far_df, cluster_features, algo, n_clusters, eps):
    """Cluster near-field frames and label far-field frames in the same label space.

    Features are standardised with a scaler fitted on the near field only, so
    both fields share one coordinate system. Far-field labels are derived from
    the near-field result, so a given cluster id means the same region in both
    tables:

    * ``"KMeans"``: ``predict`` on the far frames.
    * ``"Agglomerative"`` (no ``predict``): nearest near-field cluster centroid.
    * ``"DBSCAN"`` (no ``predict``): label of the nearest near-field core
      sample within ``eps``, otherwise -1 (noise).
    * any other ``algo``: every frame gets label 0.

    Returns:
        ``(near_df, far_df)`` copies with a string ``"cluster"`` column, or the
        inputs unchanged when none of ``cluster_features`` is present in both.
    """
    if not cluster_features:
        return near_df, far_df
    valid_features = [f for f in cluster_features if f in near_df.columns and f in far_df.columns]
    if not valid_features:
        return near_df, far_df

    X_near = np.nan_to_num(near_df[valid_features].to_numpy(dtype=float))
    X_far = np.nan_to_num(far_df[valid_features].to_numpy(dtype=float))
    near_labels = np.zeros(len(X_near), dtype=int)
    far_labels = np.zeros(len(X_far), dtype=int)

    if len(X_near) > 0 and algo in ("KMeans", "Agglomerative", "DBSCAN"):
        scaler = StandardScaler()
        X_near_scaled = scaler.fit_transform(X_near)
        # scikit-learn rejects zero-sample inputs, so only transform a non-empty far set.
        X_far_scaled = scaler.transform(X_far) if len(X_far) else X_far
        # The slider value may arrive as a float; scikit-learn validates n_clusters as an int.
        k = max(1, min(int(n_clusters), len(X_near)))

        if algo == "KMeans":
            model = KMeans(n_clusters=k, random_state=42, n_init=10)
            near_labels = model.fit_predict(X_near_scaled)
            if len(X_far):
                far_labels = model.predict(X_far_scaled)
        elif algo == "Agglomerative":
            if len(X_near) >= 2:  # AgglomerativeClustering needs at least two samples
                near_labels = AgglomerativeClustering(n_clusters=k).fit_predict(X_near_scaled)
            if len(X_far):
                far_labels = _nearest_centroid_labels(X_far_scaled, X_near_scaled, near_labels)
        else:  # DBSCAN
            model = DBSCAN(eps=eps, min_samples=3)
            near_labels = model.fit_predict(X_near_scaled)
            if len(X_far):
                far_labels = _dbscan_core_labels(model, X_far_scaled, eps)

    near_df = near_df.copy()
    near_df["cluster"] = np.asarray(near_labels).astype(str)
    far_df = far_df.copy()
    far_df["cluster"] = np.asarray(far_labels).astype(str)
    return near_df, far_df
def compute_feature_correlations(near_df, far_df, quality_scores):
    """Correlate per-frame features between the two fields and with match quality.

    For every numeric column present in both ``near_df`` and ``far_df`` this
    computes two plain (unweighted) Pearson coefficients across frames:

    * ``"Near-Far Correlation"``: the near column vs the same far column
      (do the two recordings move together for this feature?).
    * ``"Correlation with Quality"``: the near column vs ``quality_scores``
      (does this feature track the match score?).

    Inputs are truncated to their shortest common length. Coefficients that
    are undefined (fewer than two frames, a constant series, non-finite
    values) are reported as 0.0.

    Returns:
        DataFrame indexed by feature name with the two columns above, for plotting.
    """
    near_num = near_df.select_dtypes(include=[np.number])
    far_num = far_df.select_dtypes(include=[np.number])
    quality = np.asarray(quality_scores, dtype=float)
    n = min(len(near_num), len(far_num), len(quality))
    # "cluster" labels are normally strings; skip the column explicitly in case it is numeric.
    common_cols = [c for c in near_num.columns if c in far_num.columns and c != "cluster"]

    def safe_pearson(x, y):
        # pearsonr returns NaN (with a warning) for constant input; report 0.0 instead.
        if n < 2 or np.ptp(x) == 0 or np.ptp(y) == 0:
            return 0.0
        try:
            r, _ = pearsonr(x, y)
        except ValueError:
            return 0.0
        r = float(r)
        return r if np.isfinite(r) else 0.0

    correlations = {}
    quality_corr = {}
    for col in common_cols:
        near_col = near_num[col].to_numpy(dtype=float)[:n]
        far_col = far_num[col].to_numpy(dtype=float)[:n]
        correlations[col] = safe_pearson(near_col, far_col)
        quality_corr[col] = safe_pearson(near_col, quality[:n])
    return pd.DataFrame({"Near-Far Correlation": correlations, "Correlation with Quality": quality_corr})
# ----------------------------
# 6. Plotting Helpers
# ----------------------------
def generate_cluster_plot(df, x_attr, y_attr, title_suffix):
    """Scatter two feature columns of ``df``, coloured by its ``"cluster"`` column.

    Returns a placeholder figure titled "No Data" when ``df`` has no rows or
    lacks either of the requested columns.
    """
    if len(df) == 0:
        return px.scatter(title="No Data")
    columns = df.columns
    if x_attr not in columns or y_attr not in columns:
        return px.scatter(title="No Data")

    plot_title = f"Clustering Analysis ({title_suffix}): {x_attr} vs {y_attr}"
    return px.scatter(
        df,
        x=x_attr,
        y=y_attr,
        color="cluster",
        title=plot_title,
        color_discrete_sequence=px.colors.qualitative.Bold,
    )
def update_cluster_view(view_mode, near_df, far_df, cluster_features):
    """Build the cluster scatter for the selected field.

    The first two selected features are used as the x and y axes. Any
    ``view_mode`` other than "Near Field" shows the far-field table. A
    placeholder figure is returned before analysis has run, or when fewer
    than two features are selected.
    """
    if any(table is None for table in (near_df, far_df)):
        return px.scatter(title="Run Analysis First")
    if len(cluster_features) < 2:
        return px.scatter(title="Select at least 2 features")

    x_attr, y_attr = cluster_features[:2]
    if view_mode == "Near Field":
        table, label = near_df, "Near Field"
    else:
        table, label = far_df, "Far Field"
    return generate_cluster_plot(table, x_attr, y_attr, label)
# ----------------------------
# 7. Main Analysis
# ----------------------------
def _upload_path(upload):
    """Return the filesystem path of a Gradio upload (a path string or an object with ``.name``)."""
    return getattr(upload, "name", upload)


def _comparison_figure(comparison_df):
    """Plot similarity metrics on the left axis and high-frequency loss (dB) on a right-hand axis."""
    fig = go.Figure()
    for col in ["cosine_similarity", "spectral_overlap", "combined_match_score"]:
        if col in comparison_df.columns:
            fig.add_trace(go.Scatter(x=comparison_df["frame_index"], y=comparison_df[col], name=col, yaxis="y1"))
    if "high_freq_loss_db" in comparison_df.columns:
        fig.add_trace(go.Scatter(x=comparison_df["frame_index"], y=comparison_df["high_freq_loss_db"],
                                 name="High Freq Loss (dB)", line=dict(color="red", width=1), yaxis="y2"))
    fig.update_layout(
        title="Comparison Metrics", yaxis=dict(title="Similarity"), yaxis2=dict(title="dB Loss", overlaying="y", side="right")
    )
    return fig


def _spectral_diff_figure(near_feats, far_feats):
    """Heatmap of the near-minus-far dB spectrum for the middle frame present in both lists."""
    n_shared = min(len(near_feats), len(far_feats))
    if n_shared == 0:
        return go.Figure().update_layout(title="Spectral Diff (no frames)", height=350)
    mid = n_shared // 2
    diff = near_feats[mid]["spectrum"] - far_feats[mid]["spectrum"]
    fig = go.Figure(data=go.Heatmap(z=diff, colorscale='RdBu', zmid=0))
    fig.update_layout(title=f"Spectral Diff (Frame {mid})", height=350)
    return fig


def analyze_audio_pair(
    near_file, far_file,
    frame_length_ms, hop_length_ms, window_type,
    comparison_metrics, cluster_features, clustering_algo, n_clusters, dbscan_eps
):
    """Run the full near- vs far-field analysis for the "Analyze" button.

    The pipeline has these steps:

    1. Load both uploads; the far file is resampled to the near file's native
       sample rate.
    2. Peak-normalise and time-align the two signals.
    3. Frame both signals and extract per-frame features.
    4. Compare the fields frame by frame and cluster the frames.

    Returns:
        A 10-tuple, in order:

        * comparison figure and comparison table;
        * initial cluster figure and near-field cluster table;
        * spectral-difference heatmap and cluster-vs-quality overlay;
        * feature-correlation heatmap and scatter matrix;
        * the near and far clustered tables (kept as session state).

    Raises:
        gr.Error: if either file is missing.
    """
    if not near_file or not far_file:
        raise gr.Error("Upload both files.")
    cluster_features = list(cluster_features or [])

    # Load & align
    y_near, sr = librosa.load(_upload_path(near_file), sr=None)
    y_far, _ = librosa.load(_upload_path(far_file), sr=sr)
    y_near = librosa.util.normalize(y_near)
    y_far = librosa.util.normalize(y_far)
    y_near, y_far = align_signals(y_near, y_far)

    # Frame & extract features
    frames_near, _ = segment_audio(y_near, sr, frame_length_ms, hop_length_ms, window_type)
    frames_far, _ = segment_audio(y_far, sr, frame_length_ms, hop_length_ms, window_type)
    near_feats = extract_features_with_spectrum(frames_near, sr)
    far_feats = extract_features_with_spectrum(frames_far, sr)

    # Compare & cluster
    comparison_df = compare_frames_enhanced(near_feats, far_feats, comparison_metrics)
    near_df_raw = pd.DataFrame(near_feats).drop(columns=["spectrum"], errors="ignore")
    far_df_raw = pd.DataFrame(far_feats).drop(columns=["spectrum"], errors="ignore")
    near_clustered, far_clustered = perform_dual_clustering(
        near_df_raw, far_df_raw, cluster_features, clustering_algo, n_clusters, dbscan_eps
    )

    plot_comparison = _comparison_figure(comparison_df)
    init_cluster_plot = update_cluster_view("Near Field", near_clustered, far_clustered, cluster_features)
    spec_heatmap = _spectral_diff_figure(near_feats, far_feats)

    # Overlay: clustering feature vs per-frame match quality
    near_clustered["match_quality"] = comparison_df["combined_match_score"]
    # Clustering is skipped (no "cluster" column) when no features are selected;
    # plotly rejects a colour column that does not exist.
    color_col = "cluster" if "cluster" in near_clustered.columns else None
    if cluster_features:
        overlay_fig = px.scatter(near_clustered, x=cluster_features[0], y="match_quality", color=color_col,
                                 title="Cluster vs Quality (Near Field)")
    else:
        overlay_fig = px.scatter(title="No features")

    # Feature relation heatmap
    corr_df = compute_feature_correlations(near_clustered, far_clustered, comparison_df["combined_match_score"])
    corr_fig = px.imshow(corr_df.T, text_auto=True, aspect="auto", color_continuous_scale="RdBu", zmin=-1, zmax=1,
                         title="Feature Correlation Analysis")

    # Scatter matrix: up to three clustering features plus match quality
    top_cols = cluster_features[:3] + ["match_quality"]
    scatter_matrix_fig = px.scatter_matrix(near_clustered, dimensions=top_cols, color=color_col,
                                           title="Inter-Feature Scatter Matrix (Near Field)")

    return (plot_comparison, comparison_df,
            init_cluster_plot, near_clustered,
            spec_heatmap, overlay_fig,
            corr_fig, scatter_matrix_fig,
            near_clustered, far_clustered)
def export_results(comparison_df, near_df, far_df):
    """Write the three result tables to CSV files in a fresh temporary directory.

    Args:
        comparison_df: per-frame comparison table.
        near_df, far_df: per-frame near/far feature tables with cluster labels.

    Returns:
        List of the three file paths, in this order: ``comparison.csv``,
        ``near_clusters.csv``, ``far_clusters.csv``.

    Raises:
        gr.Error: if any table is missing (the analysis has not been run yet).
    """
    if comparison_df is None or near_df is None or far_df is None:
        raise gr.Error("Run the analysis before exporting.")
    # The directory is intentionally not removed here: the returned paths are
    # handed to the Files output component after this function returns.
    temp_dir = tempfile.mkdtemp()
    tables = (
        ("comparison.csv", comparison_df),
        ("near_clusters.csv", near_df),
        ("far_clusters.csv", far_df),
    )
    paths = []
    for filename, table in tables:
        path = os.path.join(temp_dir, filename)
        table.to_csv(path, index=False)
        paths.append(path)
    return paths
# ----------------------------
# 8. Gradio UI
# ----------------------------
# Every scalar feature name produced by extract_features_with_spectrum
# ("spectrum" excluded); offered as clustering-feature choices below.
feature_list = ["rms", "spectral_centroid", "zcr", "spectral_flatness",
                "low_freq_energy", "mid_freq_energy", "high_freq_energy"] + [f"mfcc_{i}" for i in range(1, 14)]

# Component creation order below defines the page layout.
# NOTE(review): several label strings (Markdown title, Accordion, Button and
# Tab labels) appear to contain mis-encoded emoji (mojibake); confirm the
# intended characters before changing them.
with gr.Blocks(title="Audio Field Analyzer", theme=gr.themes.Soft()) as demo:
    # Session holders for the clustered near/far tables: written by the
    # Analyze button, read by the view toggle and the export button.
    state_near_df = gr.State()
    state_far_df = gr.State()
    gr.Markdown("# ποΈ Near vs Far Field Analyzer (Dual-Clustering)")
    with gr.Row():
        near_file = gr.File(label="Near-Field (Ref)", file_types=[".wav"])
        far_file = gr.File(label="Far-Field (Target)", file_types=[".wav"])
    with gr.Accordion("βοΈ Settings", open=False):
        frame_length_ms = gr.Slider(10, 200, value=30, label="Frame Length (ms)")
        hop_length_ms = gr.Slider(5, 100, value=15, label="Hop Length (ms)")
        window_type = gr.Dropdown(["hann", "hamming"], value="hann", label="Window")
        comparison_metrics = gr.CheckboxGroup(["Cosine Similarity", "High-Freq Loss Ratio"], value=["Cosine Similarity", "High-Freq Loss Ratio"], label="Metrics")
        cluster_features = gr.CheckboxGroup(feature_list, value=["spectral_centroid", "spectral_flatness", "rms"], label="Clustering Features")
        # "DBSCAN" is not offered in this dropdown; its eps slider stays hidden
        # but still supplies a value to analyze_audio_pair.
        clustering_algo = gr.Dropdown(["KMeans", "Agglomerative"], value="KMeans", label="Algorithm")
        n_clusters = gr.Slider(2, 10, value=4, step=1, label="Clusters")
        dbscan_eps = gr.Slider(0.1, 5.0, value=0.5, visible=False)
    btn = gr.Button("π Analyze", variant="primary")
    with gr.Tabs():
        with gr.Tab("π Comparison"):
            comp_plot = gr.Plot()
            comp_table = gr.Dataframe()
        with gr.Tab("π§© Phoneme Clustering"):
            view_toggle = gr.Radio(["Near Field", "Far Field"], value="Near Field", label="View Mode")
            cluster_plot = gr.Plot()
            cluster_table = gr.Dataframe()
        with gr.Tab("π Spectral"):
            spec_heatmap = gr.Plot()
        with gr.Tab("π§ Overlay"):
            overlay_plot = gr.Plot()
        with gr.Tab("π Feature Relations"):
            gr.Markdown("### Correlation Heatmap & Scatter Matrix")
            corr_plot = gr.Plot(label="Correlation Heatmap")
            scatter_matrix_plot = gr.Plot(label="Scatter Matrix")
        with gr.Tab("π€ Export"):
            export_btn = gr.Button("Download CSVs")
            export_files = gr.Files()
    # The outputs list must match, element for element, the 10-tuple returned
    # by analyze_audio_pair.
    btn.click(
        fn=analyze_audio_pair,
        inputs=[near_file, far_file, frame_length_ms, hop_length_ms, window_type,
                comparison_metrics, cluster_features, clustering_algo, n_clusters, dbscan_eps],
        outputs=[comp_plot, comp_table,
                 cluster_plot, cluster_table,
                 spec_heatmap, overlay_plot,
                 corr_plot, scatter_matrix_plot,
                 state_near_df, state_far_df]
    )
    # Re-plots from the stored tables using the *current* feature checkboxes.
    # Only toggling the view triggers this; changing features alone does not.
    view_toggle.change(fn=update_cluster_view, inputs=[view_toggle, state_near_df, state_far_df, cluster_features], outputs=[cluster_plot])
    # Exports the table currently shown in the comparison tab plus both stored cluster tables.
    export_btn.click(fn=export_results, inputs=[comp_table, state_near_df, state_far_df], outputs=export_files)

# Launch the app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()