""" Visualization utilities for beat tracking evaluation. This module provides functions to: - Plot beat and downbeat predictions vs ground truth - Create waveform visualizations with beat annotations - Generate comparison plots for evaluation Example usage: from exp.data.viz import plot_beats, plot_waveform_with_beats, save_figure # Plot beat comparison fig = plot_beats(pred_beats, gt_beats, pred_downbeats, gt_downbeats) save_figure(fig, "beat_comparison.png") # Plot waveform with beats fig = plot_waveform_with_beats(audio, sr, pred_beats, gt_beats) save_figure(fig, "waveform.png") """ import numpy as np from pathlib import Path # Try to import matplotlib, but make it optional try: import matplotlib.pyplot as plt import matplotlib.patches as mpatches HAS_MATPLOTLIB = True except ImportError: HAS_MATPLOTLIB = False def _check_matplotlib(): if not HAS_MATPLOTLIB: raise ImportError( "matplotlib is required for visualization. " "Install with: pip install matplotlib" ) def plot_beats( pred_beats: list[float] | np.ndarray, gt_beats: list[float] | np.ndarray, pred_downbeats: list[float] | np.ndarray | None = None, gt_downbeats: list[float] | np.ndarray | None = None, title: str = "Beat Tracking Comparison", figsize: tuple[int, int] = (14, 4), time_range: tuple[float, float] | None = None, ) -> "plt.Figure": """ Create a visualization comparing predicted and ground truth beats. Args: pred_beats: Predicted beat times in seconds gt_beats: Ground truth beat times in seconds pred_downbeats: Predicted downbeat times (optional) gt_downbeats: Ground truth downbeat times (optional) title: Plot title figsize: Figure size (width, height) time_range: Optional tuple (start, end) to limit time range Returns: matplotlib Figure object """ _check_matplotlib() fig, ax = plt.subplots(figsize=figsize) pred_beats = np.array(pred_beats) gt_beats = np.array(gt_beats) # Apply time range filter if time_range is not None: start, end = time_range pred_beats = pred_beats[(pred_beats >= start) & (pred_beats <= end)] gt_beats = gt_beats[(gt_beats >= start) & (gt_beats <= end)] if pred_downbeats is not None: pred_downbeats = np.array(pred_downbeats) pred_downbeats = pred_downbeats[ (pred_downbeats >= start) & (pred_downbeats <= end) ] if gt_downbeats is not None: gt_downbeats = np.array(gt_downbeats) gt_downbeats = gt_downbeats[(gt_downbeats >= start) & (gt_downbeats <= end)] # Plot ground truth beats ax.vlines( gt_beats, 0, 0.4, colors="green", alpha=0.7, linewidth=1.5, label="GT Beats" ) # Plot predicted beats ax.vlines( pred_beats, 0.6, 1.0, colors="blue", alpha=0.7, linewidth=1.5, label="Pred Beats", ) # Plot downbeats if provided if gt_downbeats is not None and len(gt_downbeats) > 0: gt_downbeats = np.array(gt_downbeats) ax.vlines( gt_downbeats, 0, 0.4, colors="darkgreen", linewidth=3, label="GT Downbeats" ) if pred_downbeats is not None and len(pred_downbeats) > 0: pred_downbeats = np.array(pred_downbeats) ax.vlines( pred_downbeats, 0.6, 1.0, colors="darkblue", linewidth=3, label="Pred Downbeats", ) # Styling ax.set_ylim(-0.1, 1.1) ax.set_yticks([0.2, 0.8]) ax.set_yticklabels(["Ground Truth", "Prediction"]) ax.set_xlabel("Time (seconds)") ax.set_title(title) ax.legend(loc="upper right", ncol=4) ax.grid(True, alpha=0.3) # Set x-axis range if time_range is not None: ax.set_xlim(time_range) else: all_times = np.concatenate([pred_beats, gt_beats]) if len(all_times) > 0: ax.set_xlim(0, np.max(all_times) + 0.5) plt.tight_layout() return fig def plot_waveform_with_beats( audio: np.ndarray, sr: int, pred_beats: list[float] | np.ndarray, gt_beats: list[float] | np.ndarray, pred_downbeats: list[float] | np.ndarray | None = None, gt_downbeats: list[float] | np.ndarray | None = None, title: str = "Waveform with Beat Annotations", figsize: tuple[int, int] = (14, 6), time_range: tuple[float, float] | None = None, ) -> "plt.Figure": """ Create a waveform visualization with beat annotations. Args: audio: Audio waveform sr: Sample rate pred_beats: Predicted beat times gt_beats: Ground truth beat times pred_downbeats: Predicted downbeat times (optional) gt_downbeats: Ground truth downbeat times (optional) title: Plot title figsize: Figure size time_range: Optional tuple (start, end) to limit time range Returns: matplotlib Figure object """ _check_matplotlib() fig, (ax1, ax2) = plt.subplots( 2, 1, figsize=figsize, sharex=True, height_ratios=[3, 1] ) # Time axis duration = len(audio) / sr t = np.linspace(0, duration, len(audio)) # Apply time range if time_range is not None: start, end = time_range start_idx = int(start * sr) end_idx = int(end * sr) t = t[start_idx:end_idx] audio_plot = audio[start_idx:end_idx] else: audio_plot = audio start, end = 0, duration # Plot waveform ax1.plot(t, audio_plot, color="gray", alpha=0.7, linewidth=0.5) ax1.set_ylabel("Amplitude") ax1.set_title(title) # Filter beats to time range pred_beats = np.array(pred_beats) gt_beats = np.array(gt_beats) pred_beats = pred_beats[(pred_beats >= start) & (pred_beats <= end)] gt_beats = gt_beats[(gt_beats >= start) & (gt_beats <= end)] # Plot beat markers on waveform audio_max = np.abs(audio_plot).max() if len(audio_plot) > 0 else 1.0 for beat in gt_beats: ax1.axvline(beat, color="green", alpha=0.5, linewidth=1) for beat in pred_beats: ax1.axvline(beat, color="blue", alpha=0.3, linewidth=1, linestyle="--") # Add downbeat markers (thicker lines) if gt_downbeats is not None: gt_downbeats = np.array(gt_downbeats) gt_downbeats = gt_downbeats[(gt_downbeats >= start) & (gt_downbeats <= end)] for db in gt_downbeats: ax1.axvline(db, color="darkgreen", alpha=0.8, linewidth=2) if pred_downbeats is not None: pred_downbeats = np.array(pred_downbeats) pred_downbeats = pred_downbeats[ (pred_downbeats >= start) & (pred_downbeats <= end) ] for db in pred_downbeats: ax1.axvline(db, color="darkblue", alpha=0.5, linewidth=2, linestyle="--") ax1.set_ylim(-audio_max * 1.1, audio_max * 1.1) # Beat comparison subplot ax2.vlines(gt_beats, 0, 0.4, colors="green", alpha=0.7, linewidth=1.5) ax2.vlines(pred_beats, 0.6, 1.0, colors="blue", alpha=0.7, linewidth=1.5) if gt_downbeats is not None and len(gt_downbeats) > 0: ax2.vlines(gt_downbeats, 0, 0.4, colors="darkgreen", linewidth=3) if pred_downbeats is not None and len(pred_downbeats) > 0: ax2.vlines(pred_downbeats, 0.6, 1.0, colors="darkblue", linewidth=3) ax2.set_ylim(-0.1, 1.1) ax2.set_yticks([0.2, 0.8]) ax2.set_yticklabels(["GT", "Pred"]) ax2.set_xlabel("Time (seconds)") # Legend legend_elements = [ mpatches.Patch(color="green", alpha=0.7, label="GT Beats"), mpatches.Patch(color="blue", alpha=0.7, label="Pred Beats"), mpatches.Patch(color="darkgreen", label="GT Downbeats"), mpatches.Patch(color="darkblue", label="Pred Downbeats"), ] ax1.legend(handles=legend_elements, loc="upper right", ncol=4) ax1.grid(True, alpha=0.3) ax2.grid(True, alpha=0.3) plt.tight_layout() return fig def plot_evaluation_summary( results: dict, title: str = "Evaluation Summary", figsize: tuple[int, int] = (12, 8), ) -> "plt.Figure": """ Create a summary visualization of evaluation results. Args: results: Results dict from evaluate_all title: Plot title figsize: Figure size Returns: matplotlib Figure object """ _check_matplotlib() fig, axes = plt.subplots(2, 2, figsize=figsize) # F1 by threshold for beats ax1 = axes[0, 0] if "beat_f1_by_threshold" in results: thresholds = sorted(results["beat_f1_by_threshold"].keys()) f1_scores = [results["beat_f1_by_threshold"][t] for t in thresholds] ax1.bar(range(len(thresholds)), f1_scores, color="steelblue", alpha=0.8) ax1.set_xticks(range(len(thresholds))) ax1.set_xticklabels([f"{t}ms" for t in thresholds], rotation=45) ax1.set_ylabel("F1 Score") ax1.set_title("Beat F1 by Threshold") ax1.set_ylim(0, 1) ax1.grid(True, alpha=0.3) # F1 by threshold for downbeats ax2 = axes[0, 1] if "downbeat_f1_by_threshold" in results: thresholds = sorted(results["downbeat_f1_by_threshold"].keys()) f1_scores = [results["downbeat_f1_by_threshold"][t] for t in thresholds] ax2.bar(range(len(thresholds)), f1_scores, color="coral", alpha=0.8) ax2.set_xticks(range(len(thresholds))) ax2.set_xticklabels([f"{t}ms" for t in thresholds], rotation=45) ax2.set_ylabel("F1 Score") ax2.set_title("Downbeat F1 by Threshold") ax2.set_ylim(0, 1) ax2.grid(True, alpha=0.3) # Continuity metrics for beats ax3 = axes[1, 0] if "beat_continuity" in results: metrics = ["CMLc", "CMLt", "AMLc", "AMLt"] values = [results["beat_continuity"][m] for m in metrics] colors = ["#2E86AB", "#A23B72", "#F18F01", "#C73E1D"] bars = ax3.bar(metrics, values, color=colors, alpha=0.8) ax3.set_ylabel("Score") ax3.set_title("Beat Continuity Metrics") ax3.set_ylim(0, 1) ax3.grid(True, alpha=0.3) # Add value labels for bar, val in zip(bars, values): ax3.text( bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.02, f"{val:.3f}", ha="center", fontsize=9, ) # Continuity metrics for downbeats ax4 = axes[1, 1] if "downbeat_continuity" in results: metrics = ["CMLc", "CMLt", "AMLc", "AMLt"] values = [results["downbeat_continuity"][m] for m in metrics] colors = ["#2E86AB", "#A23B72", "#F18F01", "#C73E1D"] bars = ax4.bar(metrics, values, color=colors, alpha=0.8) ax4.set_ylabel("Score") ax4.set_title("Downbeat Continuity Metrics") ax4.set_ylim(0, 1) ax4.grid(True, alpha=0.3) # Add value labels for bar, val in zip(bars, values): ax4.text( bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.02, f"{val:.3f}", ha="center", fontsize=9, ) fig.suptitle(title, fontsize=14, fontweight="bold") plt.tight_layout() return fig def save_figure( fig: "plt.Figure", path: str | Path, dpi: int = 150, ) -> None: """ Save a matplotlib figure to file. Args: fig: Figure to save path: Output file path dpi: Resolution (dots per inch) """ _check_matplotlib() path = Path(path) path.parent.mkdir(parents=True, exist_ok=True) fig.savefig(str(path), dpi=dpi, bbox_inches="tight") plt.close(fig) if __name__ == "__main__": # Demo _check_matplotlib() print("Visualization demo...") # Generate synthetic data np.random.seed(42) gt_beats = np.arange(0, 10, 0.5) gt_downbeats = np.arange(0, 10, 2.0) pred_beats = gt_beats + np.random.normal(0, 0.02, len(gt_beats)) pred_downbeats = gt_downbeats + np.random.normal(0, 0.01, len(gt_downbeats)) # Generate fake audio sr = 16000 duration = 10.0 t = np.arange(int(duration * sr)) / sr audio = np.sin(2 * np.pi * 220 * t) * 0.3 # Create plots fig1 = plot_beats( pred_beats, gt_beats, pred_downbeats, gt_downbeats, title="Beat Comparison Demo" ) save_figure(fig1, "/tmp/beat_comparison_demo.png") print("Saved /tmp/beat_comparison_demo.png") fig2 = plot_waveform_with_beats( audio, sr, pred_beats, gt_beats, pred_downbeats, gt_downbeats, title="Waveform Demo", time_range=(2, 8), ) save_figure(fig2, "/tmp/waveform_demo.png") print("Saved /tmp/waveform_demo.png") # Fake evaluation results results = { "beat_f1_by_threshold": { 3: 0.5, 6: 0.7, 9: 0.85, 12: 0.9, 15: 0.95, 18: 0.96, 21: 0.97, 24: 0.97, 27: 0.98, 30: 0.98, }, "downbeat_f1_by_threshold": { 3: 0.6, 6: 0.8, 9: 0.9, 12: 0.95, 15: 0.97, 18: 0.98, 21: 0.98, 24: 0.99, 27: 0.99, 30: 0.99, }, "beat_continuity": {"CMLc": 0.75, "CMLt": 0.92, "AMLc": 0.80, "AMLt": 0.95}, "downbeat_continuity": {"CMLc": 0.85, "CMLt": 0.95, "AMLc": 0.88, "AMLt": 0.97}, } fig3 = plot_evaluation_summary(results, title="Evaluation Summary Demo") save_figure(fig3, "/tmp/eval_summary_demo.png") print("Saved /tmp/eval_summary_demo.png")