# MSD / app.py
# Last commit by AdityaK007: "Update app.py" (85f67e6, verified)
import gradio as gr
import librosa
import numpy as np
import pandas as pd
from sklearn.cluster import KMeans, AgglomerativeClustering, DBSCAN
from sklearn.preprocessing import StandardScaler
from sklearn.metrics.pairwise import cosine_similarity
from scipy import signal
from scipy.signal import get_window as scipy_get_window
from scipy.stats import pearsonr
import plotly.express as px
import plotly.graph_objects as go
import os
import tempfile
# ----------------------------
# 1. Signal Alignment & Preprocessing
# ----------------------------
def align_signals(ref, target):
    """Align ``target`` to ``ref`` using the peak of their cross-correlation.

    The lag that maximises the (signed) cross-correlation decides which signal
    starts later; that signal's leading samples are trimmed and both outputs are
    truncated to a common length.

    Args:
        ref: 1-D reference signal (the near-field recording in this app).
        target: 1-D signal to align against ``ref``.

    Returns:
        ``(aligned_ref, aligned_target)``: two 1-D arrays of equal length.
        If either input is empty, both outputs are empty. If either input is
        entirely zero there is no correlation peak, so no shift is applied.
    """
    ref = np.asarray(ref)
    target = np.asarray(target)
    if ref.size == 0 or target.size == 0:
        return ref[:0], target[:0]

    if np.any(ref) and np.any(target):
        # Peak-normalising either signal only multiplies the correlation by a
        # positive constant, which cannot move its argmax, so the raw signals
        # are correlated directly.
        correlation = signal.fftconvolve(target, ref[::-1], mode="full")
        lags = signal.correlation_lags(len(target), len(ref), mode="full")
        lag = int(lags[np.argmax(correlation)])
    else:
        # An all-zero signal gives an all-zero correlation; argmax would then
        # pick the most negative lag and discard almost all of ``ref``.
        lag = 0

    if lag > 0:
        # Positive lag: ``target`` is delayed relative to ``ref``.
        aligned_ref, aligned_target = ref, target[lag:]
    else:
        aligned_ref, aligned_target = ref[-lag:], target

    min_len = min(len(aligned_ref), len(aligned_target))
    return aligned_ref[:min_len], aligned_target[:min_len]
# ----------------------------
# 2. Segment Audio
# ----------------------------
def segment_audio(y, sr, frame_length_ms, hop_length_ms, window_type="hann"):
    """Split ``y`` into overlapping, windowed frames.

    Args:
        y: 1-D audio signal.
        sr: Sample rate in Hz.
        frame_length_ms: Frame length in milliseconds.
        hop_length_ms: Hop between consecutive frame starts in milliseconds.
        window_type: Window name passed to ``scipy.signal.get_window``;
            ``"rectangular"`` is accepted as an alias for ``"boxcar"``.

    Returns:
        ``(frames, frame_length)``. ``frames`` has shape
        ``(frame_length, n_frames)``, with one windowed frame per column, and
        ``frame_length`` is the frame size in samples. Trailing samples that do
        not fill a whole frame are dropped. If ``y`` is shorter than one frame,
        a single all-zero frame is returned.
    """
    # Clamp both sizes to at least one sample. A zero hop would make ``range``
    # raise, and a zero-length frame cannot be windowed.
    frame_length = max(1, int(frame_length_ms * sr / 1000))
    hop_length = max(1, int(hop_length_ms * sr / 1000))
    window_name = "boxcar" if window_type == "rectangular" else window_type
    window = scipy_get_window(window_name, frame_length)

    y = np.asarray(y)
    starts = range(0, len(y) - frame_length + 1, hop_length)
    if len(starts) == 0:
        return np.zeros((frame_length, 1)), frame_length
    frames = np.stack([y[s:s + frame_length] * window for s in starts], axis=1)
    return frames, frame_length
# ----------------------------
# 3. Feature Extraction
# ----------------------------
# Number of MFCC coefficients extracted per frame.
_N_MFCC = 13
# Band-energy value (dB relative to the frame's spectral peak) used for silent
# frames and for frames whose spectrum cannot be computed.
_DB_FLOOR = -80.0


def _blank_features(n_fft):
    """Return a feature dict with every key, in canonical order, at its fallback value."""
    feat = {name: 0.0 for name in ("rms", "spectral_centroid", "zcr", "spectral_flatness")}
    for band in ("low_freq_energy", "mid_freq_energy", "high_freq_energy"):
        feat[band] = _DB_FLOOR
    for j in range(_N_MFCC):
        feat[f"mfcc_{j + 1}"] = 0.0
    # compare_frames_enhanced treats an all-zero spectrum as "no spectrum".
    feat["spectrum"] = np.zeros((n_fft // 2 + 1, 1))
    return feat


def extract_features_with_spectrum(frames, sr):
    """Compute scalar features and a dB spectrogram for every frame.

    Args:
        frames: 2-D array with one frame per column, as returned by
            ``segment_audio``.
        sr: Sample rate in Hz.

    Returns:
        A list with one dict per frame. Every dict has the same keys in the
        same order, so DataFrames built from different lists line up:

        * ``rms``, ``spectral_centroid``, ``zcr``, ``spectral_flatness``
        * ``low_freq_energy`` / ``mid_freq_energy`` / ``high_freq_energy``:
          mean dB (relative to the frame's own spectral peak) of the bins at
          <= 2 kHz, 2-4 kHz and > 4 kHz
        * ``mfcc_1`` .. ``mfcc_13``
        * ``spectrum``: the dB STFT magnitude

        Silent frames (peak below 1e-10) keep the fallback values: 0.0 for
        scalar features, -80.0 dB for band energies and an all-zero spectrum.
        The same fallbacks are used for any individual feature librosa fails to
        compute.
    """
    n_fft = min(2048, frames.shape[0])
    # Use one hop (a quarter of the FFT size) for every librosa call so all
    # features are computed on the same analysis grid.
    hop_length = max(1, n_fft // 4)
    # Bin centre frequencies of an n_fft-point real FFT at rate sr.
    freqs = np.fft.rfftfreq(n_fft, d=1.0 / sr)
    band_masks = {
        "low_freq_energy": freqs <= 2000,
        "mid_freq_energy": (freqs > 2000) & (freqs <= 4000),
        "high_freq_energy": freqs > 4000,
    }

    features = []
    for frame in frames.T:
        feat = _blank_features(n_fft)
        if np.max(np.abs(frame)) < 1e-10:
            # Silent frame: keep every fallback value.
            features.append(feat)
            continue

        feat["rms"] = float(np.mean(
            librosa.feature.rms(y=frame, frame_length=n_fft, hop_length=hop_length)[0]))
        feat["zcr"] = float(np.mean(
            librosa.feature.zero_crossing_rate(frame, frame_length=n_fft, hop_length=hop_length)[0]))

        # The remaining features are best-effort. If a computation fails, the
        # fallback value from _blank_features is kept.
        try:
            feat["spectral_centroid"] = float(np.mean(librosa.feature.spectral_centroid(
                y=frame, sr=sr, n_fft=n_fft, hop_length=hop_length)[0]))
        except Exception:
            pass
        try:
            feat["spectral_flatness"] = float(np.mean(librosa.feature.spectral_flatness(
                y=frame, n_fft=n_fft, hop_length=hop_length)[0]))
        except Exception:
            pass

        try:
            mfccs = librosa.feature.mfcc(
                y=frame, sr=sr, n_mfcc=_N_MFCC, n_fft=n_fft, hop_length=hop_length)
        except Exception:
            mfccs = None
        if mfccs is not None:
            for j, coeffs in enumerate(mfccs[:_N_MFCC]):
                feat[f"mfcc_{j + 1}"] = float(np.mean(coeffs))

        try:
            spectrum_db = librosa.amplitude_to_db(
                np.abs(librosa.stft(frame, n_fft=n_fft, hop_length=hop_length)), ref=np.max)
        except Exception:
            spectrum_db = None
        if spectrum_db is not None:
            for band, mask in band_masks.items():
                feat[band] = float(np.mean(spectrum_db[mask])) if mask.any() else _DB_FLOOR
            feat["spectrum"] = spectrum_db

        features.append(feat)
    return features
# ----------------------------
# 4. Frame Comparison
# ----------------------------
def _row_cosine(a, b):
    """Cosine similarity of two vectors (flattened).

    Returns 0.0 if either vector is all-zero or their sizes differ.
    """
    a = np.asarray(a, dtype=float).ravel()
    b = np.asarray(b, dtype=float).ravel()
    if a.size != b.size:
        return 0.0
    denom = np.linalg.norm(a) * np.linalg.norm(b)
    if denom == 0:
        return 0.0
    return float(np.dot(a, b) / denom)


def compare_frames_enhanced(near_feats, far_feats, metrics):
    """Compare near- and far-field features frame by frame.

    Args:
        near_feats, far_feats: Lists of per-frame feature dicts, as produced
            by ``extract_features_with_spectrum``. Each dict holds numeric
            features plus a ``"spectrum"`` array. Only the first
            ``min(len(near_feats), len(far_feats))`` frames are compared.
        metrics: Collection of metric names to compute. Recognised values are
            ``"Euclidean Distance"``, ``"Cosine Similarity"`` and
            ``"High-Freq Loss Ratio"``.

    Returns:
        DataFrame with these columns:

        * ``frame_index``
        * any requested metric columns (``euclidean_dist``,
          ``cosine_similarity``, ``high_freq_loss_db``)
        * ``spectral_overlap`` (always): cosine of the two dB spectra
        * ``combined_match_score``: the mean of ``spectral_overlap`` and, when
          requested, ``cosine_similarity``
    """
    metrics = metrics or ()
    n = min(len(near_feats), len(far_feats))
    if n == 0:
        return pd.DataFrame({"frame_index": [], "spectral_overlap": [], "combined_match_score": []})

    near_pairs = near_feats[:n]
    far_pairs = far_feats[:n]
    results = {"frame_index": list(range(n))}

    near_num = (pd.DataFrame(near_pairs).drop(columns=["spectrum"], errors="ignore")
                .select_dtypes(include=[np.number]))
    far_num = (pd.DataFrame(far_pairs).drop(columns=["spectrum"], errors="ignore")
               .select_dtypes(include=[np.number]))
    # Match features by column name, not position. A DataFrame built from dicts
    # takes its column order from the first dict, which can differ between the
    # two lists.
    shared = [c for c in near_num.columns if c in far_num.columns]
    near_vec = near_num[shared].to_numpy(dtype=float)
    far_vec = far_num[shared].to_numpy(dtype=float)

    if "Euclidean Distance" in metrics:
        results["euclidean_dist"] = np.linalg.norm(near_vec - far_vec, axis=1).tolist()
    if "Cosine Similarity" in metrics:
        results["cosine_similarity"] = [_row_cosine(a, b) for a, b in zip(near_vec, far_vec)]
    if "High-Freq Loss Ratio" in metrics:
        results["high_freq_loss_db"] = [
            float(nf["high_freq_energy"] - ff["high_freq_energy"])
            for nf, ff in zip(near_pairs, far_pairs)
        ]

    results["spectral_overlap"] = [
        _row_cosine(nf["spectrum"], ff["spectrum"]) for nf, ff in zip(near_pairs, far_pairs)
    ]

    # Average whichever similarity scores exist. This keeps the combined score
    # on the same scale whether or not cosine similarity was requested.
    components = [results["spectral_overlap"]]
    if "cosine_similarity" in results:
        components.append(results["cosine_similarity"])
    results["combined_match_score"] = np.mean(components, axis=0).tolist()
    return pd.DataFrame(results)
# ----------------------------
# 5. Dual Clustering & Feature Relation (NEW)
# ----------------------------
def perform_dual_clustering(near_df, far_df, cluster_features, algo, n_clusters, eps):
    """Cluster near- and far-field frames on the selected feature columns.

    Features are standardised with a scaler fitted on the near-field frames.
    The far-field frames are transformed with that same scaler.

    * ``"KMeans"``: fitted on near-field frames. Far-field frames are labelled
      by the same fitted model (``predict``), so labels are comparable across
      fields.
    * ``"Agglomerative"`` / ``"DBSCAN"``: fitted independently on each field.
      A label value in one field therefore does not identify the same group in
      the other field.
    * Any other value: every frame gets label ``"0"``.

    Args:
        near_df, far_df: Per-frame feature tables.
        cluster_features: Names of the feature columns to cluster on. Only
            names present in both tables are used.
        algo: ``"KMeans"``, ``"Agglomerative"`` or ``"DBSCAN"``.
        n_clusters: Requested cluster count, capped at the number of frames.
        eps: DBSCAN neighbourhood radius (in standardised units).

    Returns:
        ``(near_df, far_df)`` copies with a string ``"cluster"`` column. The
        inputs are returned unchanged when no usable feature is selected or
        either table has no rows.
    """
    if not cluster_features:
        return near_df, far_df
    # A feature must exist in both tables to be clustered on.
    valid_features = [f for f in cluster_features if f in near_df.columns and f in far_df.columns]
    if not valid_features or len(near_df) == 0 or len(far_df) == 0:
        return near_df, far_df

    X_near = np.nan_to_num(near_df[valid_features].to_numpy(dtype=float))
    X_far = np.nan_to_num(far_df[valid_features].to_numpy(dtype=float))
    scaler = StandardScaler()
    X_near_scaled = scaler.fit_transform(X_near)
    X_far_scaled = scaler.transform(X_far)

    # The count comes from a UI slider and may arrive as a float; the
    # clustering estimators are given an int.
    k = max(1, int(n_clusters))

    if algo == "KMeans":
        model = KMeans(n_clusters=min(k, len(X_near)), random_state=42, n_init=10)
        near_labels = model.fit_predict(X_near_scaled)
        far_labels = model.predict(X_far_scaled)
    elif algo == "Agglomerative":
        near_labels = AgglomerativeClustering(n_clusters=min(k, len(X_near))).fit_predict(X_near_scaled)
        far_labels = AgglomerativeClustering(n_clusters=min(k, len(X_far))).fit_predict(X_far_scaled)
    elif algo == "DBSCAN":
        model = DBSCAN(eps=eps, min_samples=3)
        near_labels = model.fit_predict(X_near_scaled)
        far_labels = model.fit_predict(X_far_scaled)
    else:
        # Integer zeros so the string label is "0", like the other branches.
        near_labels = np.zeros(len(X_near), dtype=int)
        far_labels = np.zeros(len(X_far), dtype=int)

    near_df = near_df.copy()
    near_df["cluster"] = np.asarray(near_labels).astype(str)
    far_df = far_df.copy()
    far_df["cluster"] = np.asarray(far_labels).astype(str)
    return near_df, far_df
def _safe_pearson(x, y):
    """Pearson r of two equal-length series.

    Non-finite pairs are ignored. Returns 0.0 when r is undefined (fewer than
    two usable points, or a constant series).
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    usable = np.isfinite(x) & np.isfinite(y)
    if usable.sum() < 2:
        return 0.0
    x, y = x[usable], y[usable]
    if np.ptp(x) == 0 or np.ptp(y) == 0:
        return 0.0
    try:
        r, _ = pearsonr(x, y)
    except ValueError:
        return 0.0
    return float(r) if np.isfinite(r) else 0.0


def compute_feature_correlations(near_df, far_df, quality_scores):
    """Correlate each shared numeric feature across fields and with match quality.

    For every numeric column present in both ``near_df`` and ``far_df``, two
    unweighted Pearson coefficients are computed over frames matched by
    position:

    * ``Near-Far Correlation``: the near-field feature vs. the same far-field
      feature ("do near and far move together?").
    * ``Correlation with Quality``: the near-field feature vs.
      ``quality_scores`` ("does this feature predict a good match?").

    Series of different lengths are truncated to the shorter one. Frames with a
    non-finite value in either series are ignored. A coefficient that is
    undefined (fewer than two usable frames, or a constant series) is reported
    as 0.0.

    Returns:
        DataFrame indexed by feature name with the two columns above, ready for
        plotting as a heatmap.
    """
    near_num = near_df.select_dtypes(include=[np.number])
    far_num = far_df.select_dtypes(include=[np.number])
    quality = np.asarray(quality_scores, dtype=float)
    common_cols = [c for c in near_num.columns if c in far_num.columns and c != "cluster"]

    n_pair = min(len(near_num), len(far_num))
    n_quality = min(len(near_num), len(quality))
    correlations = {}
    quality_corr = {}
    for col in common_cols:
        near_col = near_num[col].to_numpy(dtype=float)
        far_col = far_num[col].to_numpy(dtype=float)
        correlations[col] = _safe_pearson(near_col[:n_pair], far_col[:n_pair])
        quality_corr[col] = _safe_pearson(near_col[:n_quality], quality[:n_quality])
    return pd.DataFrame({"Near-Far Correlation": correlations, "Correlation with Quality": quality_corr})
# ----------------------------
# 6. Plotting Helpers
# ----------------------------
def generate_cluster_plot(df, x_attr, y_attr, title_suffix):
    """Scatter-plot two feature columns of ``df``, coloured by cluster label.

    Args:
        df: Per-frame feature table, normally with a ``"cluster"`` column.
        x_attr, y_attr: Column names for the x and y axes.
        title_suffix: Text inserted into the figure title (e.g. the field name).

    Returns:
        A plotly figure. When ``df`` is empty or either attribute is missing,
        a placeholder titled "No Data" is returned instead. Points are coloured
        by ``"cluster"`` only when that column exists (clustering may have been
        skipped), rather than failing on a missing column.
    """
    if len(df) == 0 or x_attr not in df.columns or y_attr not in df.columns:
        return px.scatter(title="No Data")
    return px.scatter(
        df,
        x=x_attr,
        y=y_attr,
        color="cluster" if "cluster" in df.columns else None,
        title=f"Clustering Analysis ({title_suffix}): {x_attr} vs {y_attr}",
        color_discrete_sequence=px.colors.qualitative.Bold,
    )
def update_cluster_view(view_mode, near_df, far_df, cluster_features):
    """Redraw the cluster scatter plot for the selected field.

    Plots the first two selected clustering features of either the near-field
    table (``view_mode == "Near Field"``) or the far-field table (any other
    mode). A placeholder figure is returned until an analysis has produced
    both tables, and while fewer than two features are selected.
    """
    if near_df is None or far_df is None:
        return px.scatter(title="Run Analysis First")
    if len(cluster_features) < 2:
        return px.scatter(title="Select at least 2 features")

    x_col = cluster_features[0]
    y_col = cluster_features[1]
    if view_mode == "Near Field":
        table, label = near_df, "Near Field"
    else:
        table, label = far_df, "Far Field"
    return generate_cluster_plot(table, x_col, y_col, label)
# ----------------------------
# 7. Main Analysis
# ----------------------------
def _upload_path(uploaded):
    """Return a filesystem path for a Gradio file-input value.

    NOTE(review): depending on the Gradio version, gr.File may deliver either a
    plain path string or a tempfile-like object exposing ``.name``. Both are
    accepted here.
    """
    return getattr(uploaded, "name", uploaded)


def _comparison_figure(comparison_df):
    """Plot similarity metrics (left axis) and high-frequency loss in dB (right axis) per frame."""
    fig = go.Figure()
    for col in ("cosine_similarity", "spectral_overlap", "combined_match_score"):
        if col in comparison_df.columns:
            fig.add_trace(go.Scatter(x=comparison_df["frame_index"], y=comparison_df[col], name=col, yaxis="y1"))
    if "high_freq_loss_db" in comparison_df.columns:
        fig.add_trace(go.Scatter(
            x=comparison_df["frame_index"], y=comparison_df["high_freq_loss_db"],
            name="High Freq Loss (dB)", line=dict(color="red", width=1), yaxis="y2",
        ))
    fig.update_layout(
        title="Comparison Metrics",
        yaxis=dict(title="Similarity"),
        yaxis2=dict(title="dB Loss", overlaying="y", side="right"),
    )
    return fig


def _spectral_diff_figure(near_feats, far_feats):
    """Build a heatmap of the near-minus-far dB spectrum for the middle frame both lists share."""
    n_common = min(len(near_feats), len(far_feats))
    fig = go.Figure()
    if n_common == 0:
        fig.update_layout(title="Spectral Diff (no frames)", height=350)
        return fig
    # Index within both lists, even if their lengths differ.
    mid = n_common // 2
    diff = near_feats[mid]["spectrum"] - far_feats[mid]["spectrum"]
    fig.add_trace(go.Heatmap(z=diff, colorscale="RdBu", zmid=0))
    fig.update_layout(title=f"Spectral Diff (Frame {mid})", height=350)
    return fig


def analyze_audio_pair(
    near_file, far_file,
    frame_length_ms, hop_length_ms, window_type,
    comparison_metrics, cluster_features, clustering_algo, n_clusters, dbscan_eps
):
    """Run the full near- vs far-field analysis and build every UI output.

    Pipeline:

    1. Load both files at the near-field file's native sample rate.
    2. Peak-normalise and align the two signals.
    3. Frame both signals and extract per-frame features.
    4. Compare the fields frame by frame.
    5. Cluster both fields.
    6. Build the figures.

    Returns:
        A 10-tuple, in the order of the ``btn.click`` outputs:

        1. comparison figure
        2. comparison table
        3. cluster figure
        4. near-field cluster table
        5. spectral-diff heatmap
        6. cluster-vs-quality overlay
        7. correlation heatmap
        8. scatter matrix
        9. near-field clustered DataFrame (kept in ``gr.State``)
        10. far-field clustered DataFrame (kept in ``gr.State``)

    Raises:
        gr.Error: if either file is missing.
    """
    if not near_file or not far_file:
        raise gr.Error("Upload both files.")
    cluster_features = list(cluster_features or [])

    # Load & align
    y_near, sr = librosa.load(_upload_path(near_file), sr=None)
    y_far, _ = librosa.load(_upload_path(far_file), sr=sr)
    y_near = librosa.util.normalize(y_near)
    y_far = librosa.util.normalize(y_far)
    y_near, y_far = align_signals(y_near, y_far)

    # Frame & extract features
    frames_near, _ = segment_audio(y_near, sr, frame_length_ms, hop_length_ms, window_type)
    frames_far, _ = segment_audio(y_far, sr, frame_length_ms, hop_length_ms, window_type)
    near_feats = extract_features_with_spectrum(frames_near, sr)
    far_feats = extract_features_with_spectrum(frames_far, sr)

    # Compare & cluster
    comparison_df = compare_frames_enhanced(near_feats, far_feats, comparison_metrics)
    near_df_raw = pd.DataFrame(near_feats).drop(columns=["spectrum"], errors="ignore")
    far_df_raw = pd.DataFrame(far_feats).drop(columns=["spectrum"], errors="ignore")
    near_clustered, far_clustered = perform_dual_clustering(
        near_df_raw, far_df_raw, cluster_features, clustering_algo, int(n_clusters), dbscan_eps
    )
    # Clustering is skipped (no "cluster" column) when no usable feature is
    # selected; colour by cluster only when labels exist.
    color_col = "cluster" if "cluster" in near_clustered.columns else None
    plotted = [f for f in cluster_features if f in near_clustered.columns]

    plot_comparison = _comparison_figure(comparison_df)
    init_cluster_plot = update_cluster_view("Near Field", near_clustered, far_clustered, cluster_features)
    spec_heatmap = _spectral_diff_figure(near_feats, far_feats)

    # Attach the per-frame match score (aligned on frame index) to the near-field table.
    near_clustered["match_quality"] = comparison_df["combined_match_score"]
    if plotted:
        overlay_fig = px.scatter(near_clustered, x=plotted[0], y="match_quality", color=color_col,
                                 title="Cluster vs Quality (Near Field)")
    else:
        overlay_fig = px.scatter(title="No features")

    # Feature relation heatmap
    corr_df = compute_feature_correlations(near_clustered, far_clustered, comparison_df["combined_match_score"])
    corr_fig = px.imshow(corr_df.T, text_auto=True, aspect="auto", color_continuous_scale="RdBu",
                         zmin=-1, zmax=1, title="Feature Correlation Analysis")

    # Scatter matrix of (up to) the first three selected features plus match quality.
    scatter_matrix_fig = px.scatter_matrix(near_clustered, dimensions=plotted[:3] + ["match_quality"],
                                           color=color_col,
                                           title="Inter-Feature Scatter Matrix (Near Field)")

    return (plot_comparison, comparison_df,
            init_cluster_plot, near_clustered,
            spec_heatmap, overlay_fig,
            corr_fig, scatter_matrix_fig,
            near_clustered, far_clustered)
def export_results(comparison_df, near_df, far_df):
    """Write the comparison table and both cluster tables to CSV files.

    The files are written into a fresh temporary directory, which is left in
    place (presumably so the UI can serve the files for download).

    Returns:
        Paths of ``comparison.csv``, ``near_clusters.csv`` and
        ``far_clusters.csv``, in that order.

    Raises:
        gr.Error: if any table is missing, e.g. when export is clicked before
            an analysis has been run.
    """
    tables = {
        "comparison.csv": comparison_df,
        "near_clusters.csv": near_df,
        "far_clusters.csv": far_df,
    }
    if any(table is None for table in tables.values()):
        raise gr.Error("Run Analysis First")

    out_dir = tempfile.mkdtemp()
    paths = []
    for filename, table in tables.items():
        path = os.path.join(out_dir, filename)
        pd.DataFrame(table).to_csv(path, index=False)
        paths.append(path)
    return paths
# ----------------------------
# 8. Gradio UI
# ----------------------------
# Features offered for clustering; these are the keys produced by
# extract_features_with_spectrum, minus "spectrum".
feature_list = ["rms", "spectral_centroid", "zcr", "spectral_flatness",
                "low_freq_energy", "mid_freq_energy", "high_freq_energy"] + [f"mfcc_{i}" for i in range(1, 14)]

with gr.Blocks(title="Audio Field Analyzer", theme=gr.themes.Soft()) as demo:
    # Clustered near/far DataFrames from the most recent analysis, read by the
    # view toggle and the CSV export.
    state_near_df = gr.State()
    state_far_df = gr.State()

    gr.Markdown("# 🎙️ Near vs Far Field Analyzer (Dual-Clustering)")
    with gr.Row():
        near_file = gr.File(label="Near-Field (Ref)", file_types=[".wav"])
        far_file = gr.File(label="Far-Field (Target)", file_types=[".wav"])

    with gr.Accordion("⚙️ Settings", open=False):
        frame_length_ms = gr.Slider(10, 200, value=30, label="Frame Length (ms)")
        hop_length_ms = gr.Slider(5, 100, value=15, label="Hop Length (ms)")
        window_type = gr.Dropdown(["hann", "hamming"], value="hann", label="Window")
        comparison_metrics = gr.CheckboxGroup(["Cosine Similarity", "High-Freq Loss Ratio"], value=["Cosine Similarity", "High-Freq Loss Ratio"], label="Metrics")
        cluster_features = gr.CheckboxGroup(feature_list, value=["spectral_centroid", "spectral_flatness", "rms"], label="Clustering Features")
        clustering_algo = gr.Dropdown(["KMeans", "Agglomerative"], value="KMeans", label="Algorithm")
        n_clusters = gr.Slider(2, 10, value=4, step=1, label="Clusters")
        # Only used by DBSCAN, which is not currently offered in the Algorithm dropdown.
        dbscan_eps = gr.Slider(0.1, 5.0, value=0.5, visible=False)

    btn = gr.Button("🚀 Analyze", variant="primary")

    with gr.Tabs():
        with gr.Tab("📈 Comparison"):
            comp_plot = gr.Plot()
            comp_table = gr.Dataframe()
        with gr.Tab("🧩 Phoneme Clustering"):
            view_toggle = gr.Radio(["Near Field", "Far Field"], value="Near Field", label="View Mode")
            cluster_plot = gr.Plot()
            cluster_table = gr.Dataframe()
        with gr.Tab("🔍 Spectral"):
            spec_heatmap = gr.Plot()
        with gr.Tab("🧭 Overlay"):
            overlay_plot = gr.Plot()
        with gr.Tab("🔗 Feature Relations"):
            gr.Markdown("### Correlation Heatmap & Scatter Matrix")
            corr_plot = gr.Plot(label="Correlation Heatmap")
            scatter_matrix_plot = gr.Plot(label="Scatter Matrix")
        with gr.Tab("📀 Export"):
            export_btn = gr.Button("Download CSVs")
            export_files = gr.Files()

    btn.click(
        fn=analyze_audio_pair,
        inputs=[near_file, far_file, frame_length_ms, hop_length_ms, window_type,
                comparison_metrics, cluster_features, clustering_algo, n_clusters, dbscan_eps],
        outputs=[comp_plot, comp_table,
                 cluster_plot, cluster_table,
                 spec_heatmap, overlay_plot,
                 corr_plot, scatter_matrix_plot,
                 state_near_df, state_far_df]
    )
    view_toggle.change(fn=update_cluster_view, inputs=[view_toggle, state_near_df, state_far_df, cluster_features], outputs=[cluster_plot])
    export_btn.click(fn=export_results, inputs=[comp_table, state_near_df, state_far_df], outputs=export_files)

if __name__ == "__main__":
    demo.launch()