# Source snapshot metadata (Hugging Face upload page residue, kept as a comment
# so the module parses): JacobLinCool — "Upload folder using huggingface_hub"
# — commit 31bf74c (unverified)
"""
Visualization utilities for beat tracking evaluation.
This module provides functions to:
- Plot beat and downbeat predictions vs ground truth
- Create waveform visualizations with beat annotations
- Generate comparison plots for evaluation
Example usage:
from exp.data.viz import plot_beats, plot_waveform_with_beats, save_figure
# Plot beat comparison
fig = plot_beats(pred_beats, gt_beats, pred_downbeats, gt_downbeats)
save_figure(fig, "beat_comparison.png")
# Plot waveform with beats
fig = plot_waveform_with_beats(audio, sr, pred_beats, gt_beats)
save_figure(fig, "waveform.png")
"""
import numpy as np
from pathlib import Path
# Try to import matplotlib, but make it optional
try:
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
HAS_MATPLOTLIB = True
except ImportError:
HAS_MATPLOTLIB = False
def _check_matplotlib():
    """Raise ImportError with install instructions if matplotlib is unavailable."""
    if HAS_MATPLOTLIB:
        return
    raise ImportError(
        "matplotlib is required for visualization. "
        "Install with: pip install matplotlib"
    )
def plot_beats(
    pred_beats: list[float] | np.ndarray,
    gt_beats: list[float] | np.ndarray,
    pred_downbeats: list[float] | np.ndarray | None = None,
    gt_downbeats: list[float] | np.ndarray | None = None,
    title: str = "Beat Tracking Comparison",
    figsize: tuple[int, int] = (14, 4),
    time_range: tuple[float, float] | None = None,
) -> "plt.Figure":
    """
    Create a visualization comparing predicted and ground truth beats.

    Ground truth markers are drawn in the lower band (y in [0, 0.4]) and
    predictions in the upper band (y in [0.6, 1.0]); downbeats are drawn
    as thicker, darker lines on top of the regular beat markers.

    Args:
        pred_beats: Predicted beat times in seconds
        gt_beats: Ground truth beat times in seconds
        pred_downbeats: Predicted downbeat times (optional)
        gt_downbeats: Ground truth downbeat times (optional)
        title: Plot title
        figsize: Figure size (width, height)
        time_range: Optional tuple (start, end), in seconds; events outside
            this window are dropped and the x-axis is clamped to it

    Returns:
        matplotlib Figure object
    """
    _check_matplotlib()

    def _prep(times):
        # Normalize to a float ndarray exactly once and clip to the requested
        # window (the original code re-converted downbeats a second time).
        if times is None:
            return None
        arr = np.asarray(times, dtype=float)
        if time_range is not None:
            start, end = time_range
            arr = arr[(arr >= start) & (arr <= end)]
        return arr

    pred_beats = _prep(pred_beats)
    gt_beats = _prep(gt_beats)
    pred_downbeats = _prep(pred_downbeats)
    gt_downbeats = _prep(gt_downbeats)

    fig, ax = plt.subplots(figsize=figsize)

    # Regular beats: thin lines, GT lower band / prediction upper band.
    ax.vlines(
        gt_beats, 0, 0.4, colors="green", alpha=0.7, linewidth=1.5, label="GT Beats"
    )
    ax.vlines(
        pred_beats,
        0.6,
        1.0,
        colors="blue",
        alpha=0.7,
        linewidth=1.5,
        label="Pred Beats",
    )

    # Downbeats: thicker, darker lines drawn over the beat markers.
    if gt_downbeats is not None and len(gt_downbeats) > 0:
        ax.vlines(
            gt_downbeats, 0, 0.4, colors="darkgreen", linewidth=3, label="GT Downbeats"
        )
    if pred_downbeats is not None and len(pred_downbeats) > 0:
        ax.vlines(
            pred_downbeats,
            0.6,
            1.0,
            colors="darkblue",
            linewidth=3,
            label="Pred Downbeats",
        )

    # Styling
    ax.set_ylim(-0.1, 1.1)
    ax.set_yticks([0.2, 0.8])
    ax.set_yticklabels(["Ground Truth", "Prediction"])
    ax.set_xlabel("Time (seconds)")
    ax.set_title(title)
    ax.legend(loc="upper right", ncol=4)
    ax.grid(True, alpha=0.3)

    # Set x-axis range
    if time_range is not None:
        ax.set_xlim(time_range)
    else:
        # Fix: include downbeat times in the automatic x-limit so downbeats
        # past the last regular beat are not clipped off the plot.
        parts = [
            a
            for a in (pred_beats, gt_beats, pred_downbeats, gt_downbeats)
            if a is not None
        ]
        all_times = np.concatenate(parts) if parts else np.array([])
        if len(all_times) > 0:
            ax.set_xlim(0, np.max(all_times) + 0.5)

    plt.tight_layout()
    return fig
def plot_waveform_with_beats(
    audio: np.ndarray,
    sr: int,
    pred_beats: list[float] | np.ndarray,
    gt_beats: list[float] | np.ndarray,
    pred_downbeats: list[float] | np.ndarray | None = None,
    gt_downbeats: list[float] | np.ndarray | None = None,
    title: str = "Waveform with Beat Annotations",
    figsize: tuple[int, int] = (14, 6),
    time_range: tuple[float, float] | None = None,
) -> "plt.Figure":
    """
    Create a waveform visualization with beat annotations.

    Produces a two-panel figure sharing the time axis: the top panel shows
    the waveform with vertical beat/downbeat markers overlaid (GT solid,
    predictions dashed), and the bottom panel shows the same markers alone
    as a compact GT-vs-prediction comparison strip.

    Args:
        audio: Audio waveform (assumes a mono 1-D array — TODO confirm)
        sr: Sample rate in Hz
        pred_beats: Predicted beat times in seconds
        gt_beats: Ground truth beat times in seconds
        pred_downbeats: Predicted downbeat times (optional)
        gt_downbeats: Ground truth downbeat times (optional)
        title: Plot title
        figsize: Figure size (width, height)
        time_range: Optional tuple (start, end), in seconds, limiting the
            displayed window; samples and beats outside it are dropped

    Returns:
        matplotlib Figure object
    """
    _check_matplotlib()
    # Top panel (waveform) is 3x the height of the bottom comparison strip.
    fig, (ax1, ax2) = plt.subplots(
        2, 1, figsize=figsize, sharex=True, height_ratios=[3, 1]
    )
    # Time axis in seconds, one point per sample.
    duration = len(audio) / sr
    t = np.linspace(0, duration, len(audio))
    # Apply time range by slicing both the time axis and the samples;
    # (start, end) is also reused below to filter the beat/downbeat lists.
    if time_range is not None:
        start, end = time_range
        start_idx = int(start * sr)
        end_idx = int(end * sr)
        t = t[start_idx:end_idx]
        audio_plot = audio[start_idx:end_idx]
    else:
        audio_plot = audio
        start, end = 0, duration
    # Plot waveform
    ax1.plot(t, audio_plot, color="gray", alpha=0.7, linewidth=0.5)
    ax1.set_ylabel("Amplitude")
    ax1.set_title(title)
    # Keep only beats inside the displayed window.
    pred_beats = np.array(pred_beats)
    gt_beats = np.array(gt_beats)
    pred_beats = pred_beats[(pred_beats >= start) & (pred_beats <= end)]
    gt_beats = gt_beats[(gt_beats >= start) & (gt_beats <= end)]
    # Beat markers over the waveform: GT solid green, predictions dashed blue.
    # audio_max falls back to 1.0 so an empty window still gets valid y-limits.
    audio_max = np.abs(audio_plot).max() if len(audio_plot) > 0 else 1.0
    for beat in gt_beats:
        ax1.axvline(beat, color="green", alpha=0.5, linewidth=1)
    for beat in pred_beats:
        ax1.axvline(beat, color="blue", alpha=0.3, linewidth=1, linestyle="--")
    # Add downbeat markers (thicker lines), filtered to the same window.
    if gt_downbeats is not None:
        gt_downbeats = np.array(gt_downbeats)
        gt_downbeats = gt_downbeats[(gt_downbeats >= start) & (gt_downbeats <= end)]
        for db in gt_downbeats:
            ax1.axvline(db, color="darkgreen", alpha=0.8, linewidth=2)
    if pred_downbeats is not None:
        pred_downbeats = np.array(pred_downbeats)
        pred_downbeats = pred_downbeats[
            (pred_downbeats >= start) & (pred_downbeats <= end)
        ]
        for db in pred_downbeats:
            ax1.axvline(db, color="darkblue", alpha=0.5, linewidth=2, linestyle="--")
    # Symmetric y-limits with 10% headroom around the peak amplitude.
    ax1.set_ylim(-audio_max * 1.1, audio_max * 1.1)
    # Bottom strip: GT in the lower band (0-0.4), predictions in the upper (0.6-1.0).
    ax2.vlines(gt_beats, 0, 0.4, colors="green", alpha=0.7, linewidth=1.5)
    ax2.vlines(pred_beats, 0.6, 1.0, colors="blue", alpha=0.7, linewidth=1.5)
    if gt_downbeats is not None and len(gt_downbeats) > 0:
        ax2.vlines(gt_downbeats, 0, 0.4, colors="darkgreen", linewidth=3)
    if pred_downbeats is not None and len(pred_downbeats) > 0:
        ax2.vlines(pred_downbeats, 0.6, 1.0, colors="darkblue", linewidth=3)
    ax2.set_ylim(-0.1, 1.1)
    ax2.set_yticks([0.2, 0.8])
    ax2.set_yticklabels(["GT", "Pred"])
    ax2.set_xlabel("Time (seconds)")
    # Proxy patches for the legend (the axvline/vlines calls above carry no labels).
    legend_elements = [
        mpatches.Patch(color="green", alpha=0.7, label="GT Beats"),
        mpatches.Patch(color="blue", alpha=0.7, label="Pred Beats"),
        mpatches.Patch(color="darkgreen", label="GT Downbeats"),
        mpatches.Patch(color="darkblue", label="Pred Downbeats"),
    ]
    ax1.legend(handles=legend_elements, loc="upper right", ncol=4)
    ax1.grid(True, alpha=0.3)
    ax2.grid(True, alpha=0.3)
    plt.tight_layout()
    return fig
def _plot_f1_panel(ax, f1_by_threshold: dict, bar_color: str, panel_title: str) -> None:
    """Draw a bar chart of F1 score per tolerance threshold (ms) on *ax*."""
    thresholds = sorted(f1_by_threshold.keys())
    f1_scores = [f1_by_threshold[t] for t in thresholds]
    ax.bar(range(len(thresholds)), f1_scores, color=bar_color, alpha=0.8)
    ax.set_xticks(range(len(thresholds)))
    ax.set_xticklabels([f"{t}ms" for t in thresholds], rotation=45)
    ax.set_ylabel("F1 Score")
    ax.set_title(panel_title)
    ax.set_ylim(0, 1)
    ax.grid(True, alpha=0.3)


def _plot_continuity_panel(ax, continuity: dict, panel_title: str) -> None:
    """Draw a bar chart of the CMLc/CMLt/AMLc/AMLt continuity scores on *ax*,
    printing each bar's value just above it."""
    metrics = ["CMLc", "CMLt", "AMLc", "AMLt"]
    values = [continuity[m] for m in metrics]
    colors = ["#2E86AB", "#A23B72", "#F18F01", "#C73E1D"]
    bars = ax.bar(metrics, values, color=colors, alpha=0.8)
    ax.set_ylabel("Score")
    ax.set_title(panel_title)
    ax.set_ylim(0, 1)
    ax.grid(True, alpha=0.3)
    # Value labels just above each bar.
    for bar, val in zip(bars, values):
        ax.text(
            bar.get_x() + bar.get_width() / 2,
            bar.get_height() + 0.02,
            f"{val:.3f}",
            ha="center",
            fontsize=9,
        )


def plot_evaluation_summary(
    results: dict,
    title: str = "Evaluation Summary",
    figsize: tuple[int, int] = (12, 8),
) -> "plt.Figure":
    """
    Create a 2x2 summary visualization of evaluation results.

    Each panel is drawn only when its key is present in *results*:
      top-left:     results["beat_f1_by_threshold"]      {threshold_ms: f1}
      top-right:    results["downbeat_f1_by_threshold"]  {threshold_ms: f1}
      bottom-left:  results["beat_continuity"]           {"CMLc"/"CMLt"/"AMLc"/"AMLt": score}
      bottom-right: results["downbeat_continuity"]       {"CMLc"/"CMLt"/"AMLc"/"AMLt": score}

    Args:
        results: Results dict (presumably produced by evaluate_all — see
            the module docstring; verify against the caller)
        title: Figure suptitle
        figsize: Figure size (width, height)

    Returns:
        matplotlib Figure object
    """
    _check_matplotlib()
    fig, axes = plt.subplots(2, 2, figsize=figsize)
    # The four panels shared ~40 lines of duplicated drawing code; it now
    # lives in the two private helpers above.
    if "beat_f1_by_threshold" in results:
        _plot_f1_panel(
            axes[0, 0], results["beat_f1_by_threshold"], "steelblue",
            "Beat F1 by Threshold",
        )
    if "downbeat_f1_by_threshold" in results:
        _plot_f1_panel(
            axes[0, 1], results["downbeat_f1_by_threshold"], "coral",
            "Downbeat F1 by Threshold",
        )
    if "beat_continuity" in results:
        _plot_continuity_panel(
            axes[1, 0], results["beat_continuity"], "Beat Continuity Metrics"
        )
    if "downbeat_continuity" in results:
        _plot_continuity_panel(
            axes[1, 1], results["downbeat_continuity"], "Downbeat Continuity Metrics"
        )
    fig.suptitle(title, fontsize=14, fontweight="bold")
    plt.tight_layout()
    return fig
def save_figure(
    fig: "plt.Figure",
    path: str | Path,
    dpi: int = 150,
) -> None:
    """
    Save a matplotlib figure to *path* and close it.

    Missing parent directories are created automatically. The figure is
    closed after saving, so it cannot be reused by the caller.

    Args:
        fig: Figure to save
        path: Output file path
        dpi: Resolution (dots per inch)
    """
    _check_matplotlib()
    out = Path(path)
    out.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(str(out), dpi=dpi, bbox_inches="tight")
    plt.close(fig)
if __name__ == "__main__":
    # Demo: render all three plot types with synthetic data and save them
    # under /tmp.
    _check_matplotlib()
    print("Visualization demo...")

    # Synthetic annotations: beats every 0.5 s, downbeats every 2 s, and
    # "predictions" that are the ground truth plus small Gaussian jitter.
    # NOTE: the two np.random.normal calls must stay in this order so the
    # seeded draw sequence (and thus the output) is reproducible.
    np.random.seed(42)
    gt_beats = np.arange(0, 10, 0.5)
    gt_downbeats = np.arange(0, 10, 2.0)
    pred_beats = gt_beats + np.random.normal(0, 0.02, len(gt_beats))
    pred_downbeats = gt_downbeats + np.random.normal(0, 0.01, len(gt_downbeats))

    # Synthetic audio: 10 s of a 220 Hz sine at 16 kHz, amplitude 0.3.
    sr = 16000
    duration = 10.0
    timeline = np.arange(int(duration * sr)) / sr
    audio = 0.3 * np.sin(2 * np.pi * 220 * timeline)

    # Beat-comparison plot over the full duration.
    comparison_fig = plot_beats(
        pred_beats, gt_beats, pred_downbeats, gt_downbeats, title="Beat Comparison Demo"
    )
    save_figure(comparison_fig, "/tmp/beat_comparison_demo.png")
    print("Saved /tmp/beat_comparison_demo.png")

    # Waveform plot restricted to the 2-8 s window.
    waveform_fig = plot_waveform_with_beats(
        audio,
        sr,
        pred_beats,
        gt_beats,
        pred_downbeats,
        gt_downbeats,
        title="Waveform Demo",
        time_range=(2, 8),
    )
    save_figure(waveform_fig, "/tmp/waveform_demo.png")
    print("Saved /tmp/waveform_demo.png")

    # Hand-written evaluation results shaped like plot_evaluation_summary
    # expects (F1 per tolerance in ms, plus continuity metrics).
    results = {
        "beat_f1_by_threshold": {
            3: 0.5,
            6: 0.7,
            9: 0.85,
            12: 0.9,
            15: 0.95,
            18: 0.96,
            21: 0.97,
            24: 0.97,
            27: 0.98,
            30: 0.98,
        },
        "downbeat_f1_by_threshold": {
            3: 0.6,
            6: 0.8,
            9: 0.9,
            12: 0.95,
            15: 0.97,
            18: 0.98,
            21: 0.98,
            24: 0.99,
            27: 0.99,
            30: 0.99,
        },
        "beat_continuity": {"CMLc": 0.75, "CMLt": 0.92, "AMLc": 0.80, "AMLt": 0.95},
        "downbeat_continuity": {"CMLc": 0.85, "CMLt": 0.95, "AMLc": 0.88, "AMLt": 0.97},
    }
    summary_fig = plot_evaluation_summary(results, title="Evaluation Summary Demo")
    save_figure(summary_fig, "/tmp/eval_summary_demo.png")
    print("Saved /tmp/eval_summary_demo.png")