|
|
""" |
|
|
Visualization utilities for beat tracking evaluation.

This module provides functions to:
- Plot beat and downbeat predictions vs ground truth
- Create waveform visualizations with beat annotations
- Generate comparison plots for evaluation

Example usage:
    from exp.data.viz import plot_beats, plot_waveform_with_beats, save_figure

    # Plot beat comparison
    fig = plot_beats(pred_beats, gt_beats, pred_downbeats, gt_downbeats)
    save_figure(fig, "beat_comparison.png")

    # Plot waveform with beats
    fig = plot_waveform_with_beats(audio, sr, pred_beats, gt_beats)
    save_figure(fig, "waveform.png")
|
|
""" |
|
|
|
|
|
import numpy as np |
|
|
from pathlib import Path |
|
|
|
|
|
|
|
|
try: |
|
|
import matplotlib.pyplot as plt |
|
|
import matplotlib.patches as mpatches |
|
|
|
|
|
HAS_MATPLOTLIB = True |
|
|
except ImportError: |
|
|
HAS_MATPLOTLIB = False |
|
|
|
|
|
|
|
|
def _check_matplotlib():
    """Raise ``ImportError`` when matplotlib could not be imported.

    Consults the module-level ``HAS_MATPLOTLIB`` flag set by the guarded
    import at the top of the file, so plotting functions can fail with an
    actionable message instead of a ``NameError`` on ``plt``.

    Raises:
        ImportError: If ``HAS_MATPLOTLIB`` is false.
    """
    if HAS_MATPLOTLIB:
        return

    message = (
        "matplotlib is required for visualization. "
        "Install with: pip install matplotlib"
    )
    raise ImportError(message)
|
|
|
|
|
|
|
|
def plot_beats(
    pred_beats: list[float] | np.ndarray,
    gt_beats: list[float] | np.ndarray,
    pred_downbeats: list[float] | np.ndarray | None = None,
    gt_downbeats: list[float] | np.ndarray | None = None,
    title: str = "Beat Tracking Comparison",
    figsize: tuple[int, int] = (14, 4),
    time_range: tuple[float, float] | None = None,
) -> "plt.Figure":
    """
    Create a visualization comparing predicted and ground truth beats.

    Ground truth events are drawn as vertical ticks in the lower band
    (y in [0, 0.4]) and predictions in the upper band (y in [0.6, 1.0]).
    Downbeats, when given and non-empty, are drawn on top of the beats with
    a thicker, darker line.

    Args:
        pred_beats: Predicted beat times in seconds
        gt_beats: Ground truth beat times in seconds
        pred_downbeats: Predicted downbeat times (optional)
        gt_downbeats: Ground truth downbeat times (optional)
        title: Plot title
        figsize: Figure size (width, height)
        time_range: Optional tuple (start, end) to limit time range; events
            outside the closed interval are dropped and the x-axis is set
            to exactly this range

    Returns:
        matplotlib Figure object

    Raises:
        ImportError: If matplotlib is not installed.
    """
    _check_matplotlib()

    def _prepare(times):
        """Convert to a flat float array and keep only events in time_range."""
        if times is None:
            return None
        arr = np.asarray(times, dtype=float).ravel()
        if time_range is not None:
            lo, hi = time_range
            arr = arr[(arr >= lo) & (arr <= hi)]
        return arr

    # Normalise all inputs before creating the figure so that invalid input
    # raises without leaving an orphaned figure registered in pyplot.
    pred_beats = _prepare(pred_beats)
    gt_beats = _prepare(gt_beats)
    pred_downbeats = _prepare(pred_downbeats)
    gt_downbeats = _prepare(gt_downbeats)

    fig, ax = plt.subplots(figsize=figsize)

    # Lower band: ground truth. Upper band: prediction.
    ax.vlines(
        gt_beats, 0, 0.4, colors="green", alpha=0.7, linewidth=1.5, label="GT Beats"
    )
    ax.vlines(
        pred_beats,
        0.6,
        1.0,
        colors="blue",
        alpha=0.7,
        linewidth=1.5,
        label="Pred Beats",
    )

    # Downbeats are skipped when empty so they add no legend entry.
    if gt_downbeats is not None and gt_downbeats.size > 0:
        ax.vlines(
            gt_downbeats, 0, 0.4, colors="darkgreen", linewidth=3, label="GT Downbeats"
        )
    if pred_downbeats is not None and pred_downbeats.size > 0:
        ax.vlines(
            pred_downbeats,
            0.6,
            1.0,
            colors="darkblue",
            linewidth=3,
            label="Pred Downbeats",
        )

    ax.set_ylim(-0.1, 1.1)
    ax.set_yticks([0.2, 0.8])
    ax.set_yticklabels(["Ground Truth", "Prediction"])
    ax.set_xlabel("Time (seconds)")
    ax.set_title(title)
    ax.legend(loc="upper right", ncol=4)
    ax.grid(True, alpha=0.3)

    if time_range is not None:
        ax.set_xlim(time_range)
    else:
        # Take the extent over beats AND downbeats; otherwise a downbeat
        # later than the last beat would fall outside the visible range.
        series = (pred_beats, gt_beats, pred_downbeats, gt_downbeats)
        non_empty = [arr for arr in series if arr is not None and arr.size > 0]
        if non_empty:
            latest = max(float(arr.max()) for arr in non_empty)
            ax.set_xlim(0, latest + 0.5)

    # Lay out this figure explicitly rather than pyplot's "current" figure.
    fig.tight_layout()
    return fig
|
|
|
|
|
|
|
|
def plot_waveform_with_beats(
    audio: np.ndarray,
    sr: int,
    pred_beats: list[float] | np.ndarray,
    gt_beats: list[float] | np.ndarray,
    pred_downbeats: list[float] | np.ndarray | None = None,
    gt_downbeats: list[float] | np.ndarray | None = None,
    title: str = "Waveform with Beat Annotations",
    figsize: tuple[int, int] = (14, 6),
    time_range: tuple[float, float] | None = None,
) -> "plt.Figure":
    """
    Create a waveform visualization with beat annotations.

    The figure has two stacked axes sharing the x-axis: the waveform with
    beat/downbeat marker lines on top, and a compact ground-truth vs
    prediction tick panel below.

    Args:
        audio: Audio waveform; its first axis is treated as the sample axis
            (NOTE(review): presumably 1-D mono - confirm against callers)
        sr: Sample rate
        pred_beats: Predicted beat times
        gt_beats: Ground truth beat times
        pred_downbeats: Predicted downbeat times (optional)
        gt_downbeats: Ground truth downbeat times (optional)
        title: Plot title
        figsize: Figure size
        time_range: Optional tuple (start, end) to limit time range; the
            sample window is clamped to the available audio

    Returns:
        matplotlib Figure object

    Raises:
        ImportError: If matplotlib is not installed.
    """
    _check_matplotlib()

    audio = np.asarray(audio)
    n_samples = len(audio)
    duration = n_samples / sr

    if time_range is not None:
        start, end = time_range
        # Clamp to [0, n_samples]: a negative start would otherwise become a
        # negative slice index and select samples from the END of the audio.
        start_idx = min(max(int(start * sr), 0), n_samples)
        end_idx = min(max(int(end * sr), start_idx), n_samples)
    else:
        start, end = 0, duration
        start_idx, end_idx = 0, n_samples

    audio_plot = audio[start_idx:end_idx]
    # Sample i occurs at exactly i / sr seconds. An endpoint-inclusive
    # linspace over [0, duration] would stretch the axis by N / (N - 1).
    t = np.arange(start_idx, end_idx) / sr

    def _window(times):
        """Convert to a flat float array restricted to [start, end]."""
        if times is None:
            return None
        arr = np.asarray(times, dtype=float).ravel()
        return arr[(arr >= start) & (arr <= end)]

    # Normalise all annotations before creating the figure so that invalid
    # input raises without leaving an orphaned figure registered in pyplot.
    pred_beats = _window(pred_beats)
    gt_beats = _window(gt_beats)
    gt_downbeats = _window(gt_downbeats)
    pred_downbeats = _window(pred_downbeats)

    audio_max = float(np.abs(audio_plot).max()) if len(audio_plot) > 0 else 1.0
    if not audio_max > 0:
        # Silent (or NaN) audio would give identical lower/upper y-limits.
        audio_max = 1.0

    # height_ratios goes through gridspec_kw because older matplotlib
    # releases do not accept it as a direct plt.subplots keyword.
    fig, (ax1, ax2) = plt.subplots(
        2,
        1,
        figsize=figsize,
        sharex=True,
        gridspec_kw={"height_ratios": [3, 1]},
    )

    # --- Top panel: waveform with full-height marker lines -----------------
    ax1.plot(t, audio_plot, color="gray", alpha=0.7, linewidth=0.5)
    ax1.set_ylabel("Amplitude")
    ax1.set_title(title)

    for beat in gt_beats:
        ax1.axvline(beat, color="green", alpha=0.5, linewidth=1)
    for beat in pred_beats:
        ax1.axvline(beat, color="blue", alpha=0.3, linewidth=1, linestyle="--")

    if gt_downbeats is not None:
        for db in gt_downbeats:
            ax1.axvline(db, color="darkgreen", alpha=0.8, linewidth=2)
    if pred_downbeats is not None:
        for db in pred_downbeats:
            ax1.axvline(db, color="darkblue", alpha=0.5, linewidth=2, linestyle="--")

    ax1.set_ylim(-audio_max * 1.1, audio_max * 1.1)

    # --- Bottom panel: ground truth (lower band) vs prediction (upper) -----
    ax2.vlines(gt_beats, 0, 0.4, colors="green", alpha=0.7, linewidth=1.5)
    ax2.vlines(pred_beats, 0.6, 1.0, colors="blue", alpha=0.7, linewidth=1.5)

    if gt_downbeats is not None and gt_downbeats.size > 0:
        ax2.vlines(gt_downbeats, 0, 0.4, colors="darkgreen", linewidth=3)
    if pred_downbeats is not None and pred_downbeats.size > 0:
        ax2.vlines(pred_downbeats, 0.6, 1.0, colors="darkblue", linewidth=3)

    ax2.set_ylim(-0.1, 1.1)
    ax2.set_yticks([0.2, 0.8])
    ax2.set_yticklabels(["GT", "Pred"])
    ax2.set_xlabel("Time (seconds)")

    # The marker lines carry no labels, so the legend is built from proxy
    # patches; all four entries are always listed.
    legend_elements = [
        mpatches.Patch(color="green", alpha=0.7, label="GT Beats"),
        mpatches.Patch(color="blue", alpha=0.7, label="Pred Beats"),
        mpatches.Patch(color="darkgreen", label="GT Downbeats"),
        mpatches.Patch(color="darkblue", label="Pred Downbeats"),
    ]
    ax1.legend(handles=legend_elements, loc="upper right", ncol=4)

    ax1.grid(True, alpha=0.3)
    ax2.grid(True, alpha=0.3)

    # Lay out this figure explicitly rather than pyplot's "current" figure.
    fig.tight_layout()
    return fig
|
|
|
|
|
|
|
|
def _draw_threshold_panel(ax, f1_by_threshold, color, panel_title):
    """Draw one bar per threshold (sorted ascending) showing its F1 score."""
    thresholds = sorted(f1_by_threshold)
    scores = [f1_by_threshold[t] for t in thresholds]
    positions = range(len(thresholds))

    ax.bar(positions, scores, color=color, alpha=0.8)
    ax.set_xticks(positions)
    ax.set_xticklabels([f"{t}ms" for t in thresholds], rotation=45)
    ax.set_ylabel("F1 Score")
    ax.set_title(panel_title)
    ax.set_ylim(0, 1)
    ax.grid(True, alpha=0.3)


def _draw_continuity_panel(ax, continuity, panel_title):
    """Draw the four continuity scores as value-annotated bars."""
    metrics = ["CMLc", "CMLt", "AMLc", "AMLt"]
    values = [continuity[m] for m in metrics]
    colors = ["#2E86AB", "#A23B72", "#F18F01", "#C73E1D"]

    bars = ax.bar(metrics, values, color=colors, alpha=0.8)
    ax.set_ylabel("Score")
    ax.set_title(panel_title)
    ax.set_ylim(0, 1)
    ax.grid(True, alpha=0.3)

    # Print each score just above its bar, formatted to three decimals.
    for bar, val in zip(bars, values):
        ax.text(
            bar.get_x() + bar.get_width() / 2,
            bar.get_height() + 0.02,
            f"{val:.3f}",
            ha="center",
            fontsize=9,
        )


def plot_evaluation_summary(
    results: dict,
    title: str = "Evaluation Summary",
    figsize: tuple[int, int] = (12, 8),
) -> "plt.Figure":
    """
    Create a summary visualization of evaluation results.

    Builds a 2x2 grid: F1-by-threshold bar charts on the top row (beats
    left, downbeats right) and continuity-metric bar charts on the bottom
    row (beats left, downbeats right). A panel is only filled when its key
    is present in ``results``; otherwise that axes is left empty.

    Args:
        results: Results dict from evaluate_all. Keys read here:
            ``"beat_f1_by_threshold"`` / ``"downbeat_f1_by_threshold"``
            (mapping threshold -> F1; thresholds are labelled with an "ms"
            suffix) and ``"beat_continuity"`` / ``"downbeat_continuity"``
            (mapping with the keys "CMLc", "CMLt", "AMLc", "AMLt").
        title: Plot title
        figsize: Figure size

    Returns:
        matplotlib Figure object

    Raises:
        ImportError: If matplotlib is not installed.
        KeyError: If a continuity mapping is present but lacks one of the
            four metric keys.
    """
    _check_matplotlib()

    fig, axes = plt.subplots(2, 2, figsize=figsize)

    # (axes, results key, bar colour, panel title) for the top row.
    threshold_panels = (
        (axes[0, 0], "beat_f1_by_threshold", "steelblue", "Beat F1 by Threshold"),
        (axes[0, 1], "downbeat_f1_by_threshold", "coral", "Downbeat F1 by Threshold"),
    )
    for ax, key, color, panel_title in threshold_panels:
        if key in results:
            _draw_threshold_panel(ax, results[key], color, panel_title)

    # (axes, results key, panel title) for the bottom row.
    continuity_panels = (
        (axes[1, 0], "beat_continuity", "Beat Continuity Metrics"),
        (axes[1, 1], "downbeat_continuity", "Downbeat Continuity Metrics"),
    )
    for ax, key, panel_title in continuity_panels:
        if key in results:
            _draw_continuity_panel(ax, results[key], panel_title)

    fig.suptitle(title, fontsize=14, fontweight="bold")
    plt.tight_layout()
    return fig
|
|
|
|
|
|
|
|
def save_figure(
    fig: "plt.Figure",
    path: str | Path,
    dpi: int = 150,
) -> None:
    """
    Save a matplotlib figure to file.

    Missing parent directories of ``path`` are created. The figure is
    closed afterwards, whether or not saving succeeded, so that pyplot
    does not keep it registered (and its memory alive) when many figures
    are produced in a loop.

    Args:
        fig: Figure to save
        path: Output file path
        dpi: Resolution (dots per inch)

    Raises:
        ImportError: If matplotlib is not installed.
    """
    _check_matplotlib()

    out_path = Path(path)
    try:
        out_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(str(out_path), dpi=dpi, bbox_inches="tight")
    finally:
        # Always release the figure; any error from mkdir/savefig still
        # propagates to the caller after the close.
        plt.close(fig)
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Demo entry point: renders each plot type from synthetic data and
    # writes the images under /tmp.
    _check_matplotlib()
    print("Visualization demo...")

    # Synthetic annotations: a beat every 0.5 s and a downbeat every 2 s
    # over 10 s; the "predictions" are the ground truth plus Gaussian jitter.
    # The seed and the order of the two draws keep the output reproducible.
    np.random.seed(42)
    gt_beats = np.arange(0, 10, 0.5)
    gt_downbeats = np.arange(0, 10, 2.0)
    beat_jitter = np.random.normal(0, 0.02, len(gt_beats))
    downbeat_jitter = np.random.normal(0, 0.01, len(gt_downbeats))
    pred_beats = gt_beats + beat_jitter
    pred_downbeats = gt_downbeats + downbeat_jitter

    # Synthetic audio: a 220 Hz sine at amplitude 0.3, 10 s at 16 kHz.
    sr = 16000
    duration = 10.0
    t = np.arange(int(duration * sr)) / sr
    audio = np.sin(2 * np.pi * 220 * t) * 0.3

    # 1) Beat comparison plot.
    fig1 = plot_beats(
        pred_beats=pred_beats,
        gt_beats=gt_beats,
        pred_downbeats=pred_downbeats,
        gt_downbeats=gt_downbeats,
        title="Beat Comparison Demo",
    )
    save_figure(fig1, "/tmp/beat_comparison_demo.png")
    print("Saved /tmp/beat_comparison_demo.png")

    # 2) Waveform plot restricted to the 2 s - 8 s window.
    fig2 = plot_waveform_with_beats(
        audio=audio,
        sr=sr,
        pred_beats=pred_beats,
        gt_beats=gt_beats,
        pred_downbeats=pred_downbeats,
        gt_downbeats=gt_downbeats,
        title="Waveform Demo",
        time_range=(2, 8),
    )
    save_figure(fig2, "/tmp/waveform_demo.png")
    print("Saved /tmp/waveform_demo.png")

    # 3) Evaluation summary from hand-written example scores, keyed by the
    #    thresholds 3, 6, ..., 30.
    thresholds = range(3, 31, 3)
    beat_f1 = (0.5, 0.7, 0.85, 0.9, 0.95, 0.96, 0.97, 0.97, 0.98, 0.98)
    downbeat_f1 = (0.6, 0.8, 0.9, 0.95, 0.97, 0.98, 0.98, 0.99, 0.99, 0.99)
    results = {
        "beat_f1_by_threshold": dict(zip(thresholds, beat_f1)),
        "downbeat_f1_by_threshold": dict(zip(thresholds, downbeat_f1)),
        "beat_continuity": {"CMLc": 0.75, "CMLt": 0.92, "AMLc": 0.80, "AMLt": 0.95},
        "downbeat_continuity": {"CMLc": 0.85, "CMLt": 0.95, "AMLc": 0.88, "AMLt": 0.97},
    }
    fig3 = plot_evaluation_summary(results, title="Evaluation Summary Demo")
    save_figure(fig3, "/tmp/eval_summary_demo.png")
    print("Saved /tmp/eval_summary_demo.png")
|
|
|