""" Evaluation utilities for beat and downbeat detection. This module provides functions to evaluate beat/downbeat predictions against ground truth annotations using F1-scores at various timing thresholds and continuity-based metrics (CMLt, AMLt). The evaluation metrics include: - **F1-scores**: Calculated for timing thresholds from 3ms to 30ms - **Weighted F1**: Weights are inversely proportional to threshold (e.g., 3ms: 1, 6ms: 1/2) - **CMLt (Correct Metrical Level Total)**: Accuracy at the correct metrical level - **AMLt (Any Metrical Level Total)**: Accuracy allowing for metrical variations (double/half tempo, off-beat, etc.) - **CMLc/AMLc**: Continuous versions (longest correct segment) Example usage: from ..data.eval import ( evaluate_beats, evaluate_all, compute_weighted_f1, compute_continuity_metrics, format_results ) # Evaluate single track results = evaluate_beats(pred_beats, gt_beats) print(f"Weighted F1: {results['weighted_f1']:.4f}") print(f"CMLt: {results['continuity']['CMLt']:.4f}") print(f"AMLt: {results['continuity']['AMLt']:.4f}") # Evaluate with custom thresholds results = evaluate_beats(pred_beats, gt_beats, thresholds_ms=[5, 10, 20]) # Evaluate all tracks in dataset summary = evaluate_all(predictions, ground_truths) print(format_results(summary)) """ from typing import Sequence import numpy as np import mir_eval # Default timing thresholds in milliseconds (3ms to 30ms, step 3ms) DEFAULT_THRESHOLDS_MS = [3, 6, 9, 12, 15, 18, 21, 24, 27, 30] # Default minimum beat time for mir_eval metrics (can be set to 0 to use all beats) DEFAULT_MIN_BEAT_TIME = 5.0 def match_events( pred: np.ndarray, gt: np.ndarray, tolerance_sec: float, ) -> tuple[int, int, int]: """ Match predicted events to ground truth events within a tolerance. Uses greedy matching: each ground truth event is matched to the closest unmatched prediction within the tolerance window. Args: pred: Predicted event times in seconds, shape (N,) gt: Ground truth event times in seconds, shape (M,) tolerance_sec: Maximum time difference for a match (in seconds) Returns: Tuple of (true_positives, false_positives, false_negatives) """ if len(gt) == 0: return 0, len(pred), 0 if len(pred) == 0: return 0, 0, len(gt) pred = np.sort(pred) gt = np.sort(gt) matched_pred = np.zeros(len(pred), dtype=bool) matched_gt = np.zeros(len(gt), dtype=bool) # For each ground truth, find the closest unmatched prediction for i, gt_time in enumerate(gt): # Find predictions within tolerance diffs = np.abs(pred - gt_time) candidates = np.where((diffs <= tolerance_sec) & ~matched_pred)[0] if len(candidates) > 0: # Match to closest candidate best_idx = candidates[np.argmin(diffs[candidates])] matched_pred[best_idx] = True matched_gt[i] = True tp = int(matched_gt.sum()) fp = int((~matched_pred).sum() == 0 and len(pred) - tp or len(pred) - tp) fn = int(len(gt) - tp) # Recalculate fp correctly fp = len(pred) - tp return tp, fp, fn def compute_f1(tp: int, fp: int, fn: int) -> tuple[float, float, float]: """ Compute precision, recall, and F1-score from TP, FP, FN counts. 
def compute_f1(tp: int, fp: int, fn: int) -> tuple[float, float, float]:
    """
    Compute precision, recall, and F1-score from TP, FP, FN counts.

    Args:
        tp: True positives
        fp: False positives
        fn: False negatives

    Returns:
        Tuple of (precision, recall, f1_score)
    """
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    f1 = (
        2 * precision * recall / (precision + recall)
        if (precision + recall) > 0
        else 0.0
    )
    return precision, recall, f1


def compute_weighted_f1(
    f1_scores: dict[int, float],
    thresholds_ms: Sequence[int] | None = None,
) -> float:
    """
    Compute a weighted F1-score where weights are inversely proportional to
    the threshold.

    The weight for threshold T ms is min_threshold / T. For example, with
    thresholds [3, 6, 9, ...]:

    - 3ms: weight = 1
    - 6ms: weight = 0.5
    - 9ms: weight = 0.333...

    Args:
        f1_scores: Dict mapping threshold (ms) to F1-score
        thresholds_ms: List of thresholds used (for weight calculation)

    Returns:
        Weighted F1-score
    """
    if not f1_scores:
        return 0.0
    if thresholds_ms is None:
        thresholds_ms = sorted(f1_scores.keys())

    min_threshold = min(thresholds_ms)
    total_weight = 0.0
    weighted_sum = 0.0
    for t in thresholds_ms:
        if t in f1_scores:
            weight = min_threshold / t  # 3ms -> 1, 6ms -> 0.5, etc.
            weighted_sum += weight * f1_scores[t]
            total_weight += weight

    return weighted_sum / total_weight if total_weight > 0 else 0.0


def compute_continuity_metrics(
    pred_times: Sequence[float],
    gt_times: Sequence[float],
    min_beat_time: float = DEFAULT_MIN_BEAT_TIME,
    phase_threshold: float = 0.175,
    period_threshold: float = 0.175,
) -> dict:
    """
    Compute continuity-based beat tracking metrics using mir_eval.

    These metrics evaluate beat tracking accuracy accounting for metrical level:

    - CMLt (Correct Metrical Level Total): Accuracy at the correct metrical level
    - AMLt (Any Metrical Level Total): Accuracy allowing for metrical variations
      (double/half tempo, off-beat, etc.)
    - CMLc/AMLc: Continuous versions (longest correct segment)

    Args:
        pred_times: Predicted beat times in seconds
        gt_times: Ground truth beat times in seconds
        min_beat_time: Minimum time to start evaluation (default: 5.0s).
            Set to 0.0 to use all beats, but note that early beats may not
            have stable inter-beat intervals.
        phase_threshold: Maximum phase error as a ratio of the beat interval
            (default: 0.175)
        period_threshold: Maximum period error as a ratio of the beat interval
            (default: 0.175)

    Returns:
        Dict containing:
        - 'CMLc': Correct Metrical Level Continuous
        - 'CMLt': Correct Metrical Level Total
        - 'AMLc': Any Metrical Level Continuous
        - 'AMLt': Any Metrical Level Total
    """
    pred_arr = np.sort(np.array(pred_times, dtype=np.float64))
    gt_arr = np.sort(np.array(gt_times, dtype=np.float64))

    # Trim beats before min_beat_time (standard preprocessing)
    pred_trimmed = mir_eval.beat.trim_beats(pred_arr, min_beat_time=min_beat_time)
    gt_trimmed = mir_eval.beat.trim_beats(gt_arr, min_beat_time=min_beat_time)

    # Handle edge cases where trimming results in too few beats
    if len(gt_trimmed) < 2 or len(pred_trimmed) < 2:
        return {
            "CMLc": 0.0,
            "CMLt": 0.0,
            "AMLc": 0.0,
            "AMLt": 0.0,
        }

    # Compute continuity metrics
    CMLc, CMLt, AMLc, AMLt = mir_eval.beat.continuity(
        gt_trimmed,
        pred_trimmed,
        continuity_phase_threshold=phase_threshold,
        continuity_period_threshold=period_threshold,
    )

    return {
        "CMLc": float(CMLc),
        "CMLt": float(CMLt),
        "AMLc": float(AMLc),
        "AMLt": float(AMLt),
    }
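
# Illustrative sketch of the CML/AML distinction (synthetic input; exact
# values depend on mir_eval's trimming and thresholds): a tracker running at
# double tempo is at the wrong metrical level, so it scores ~0 on CMLt, but
# close to 1 on AMLt, which also accepts double/half-tempo and off-beat
# variants of the reference:
#
#     _gt = np.arange(5.0, 30.0, 0.5)       # reference beats every 0.5 s
#     _double = np.arange(5.0, 30.0, 0.25)  # double-tempo predictions
#     compute_continuity_metrics(_double, _gt)  # CMLt ~ 0.0, AMLt ~ 1.0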
def evaluate_beats(
    pred_times: Sequence[float],
    gt_times: Sequence[float],
    thresholds_ms: Sequence[int] | None = None,
    min_beat_time: float = DEFAULT_MIN_BEAT_TIME,
) -> dict:
    """
    Evaluate beat predictions against ground truth at multiple thresholds.

    Args:
        pred_times: Predicted beat times in seconds
        gt_times: Ground truth beat times in seconds
        thresholds_ms: Timing thresholds in milliseconds (default: 3ms to 30ms)
        min_beat_time: Minimum time for continuity metrics (default: 5.0s)

    Returns:
        Dict containing:
        - 'per_threshold': Dict[threshold_ms, {'precision', 'recall', 'f1'}]
        - 'f1_scores': Dict[threshold_ms, f1_score] (convenience access)
        - 'weighted_f1': Weighted F1-score across all thresholds
        - 'continuity': Dict with CMLc, CMLt, AMLc, AMLt metrics
        - 'num_predictions': Number of predictions
        - 'num_ground_truth': Number of ground truth events
    """
    if thresholds_ms is None:
        thresholds_ms = DEFAULT_THRESHOLDS_MS

    pred_arr = np.array(pred_times, dtype=np.float64)
    gt_arr = np.array(gt_times, dtype=np.float64)

    per_threshold = {}
    f1_scores = {}
    for threshold_ms in thresholds_ms:
        tolerance_sec = threshold_ms / 1000.0
        tp, fp, fn = match_events(pred_arr, gt_arr, tolerance_sec)
        precision, recall, f1 = compute_f1(tp, fp, fn)
        per_threshold[threshold_ms] = {
            "precision": precision,
            "recall": recall,
            "f1": f1,
            "tp": tp,
            "fp": fp,
            "fn": fn,
        }
        f1_scores[threshold_ms] = f1

    weighted_f1 = compute_weighted_f1(f1_scores, thresholds_ms)
    continuity = compute_continuity_metrics(pred_times, gt_times, min_beat_time)

    return {
        "per_threshold": per_threshold,
        "f1_scores": f1_scores,
        "weighted_f1": weighted_f1,
        "continuity": continuity,
        "num_predictions": len(pred_arr),
        "num_ground_truth": len(gt_arr),
    }


def evaluate_track(
    pred_beats: Sequence[float],
    pred_downbeats: Sequence[float],
    gt_beats: Sequence[float],
    gt_downbeats: Sequence[float],
    thresholds_ms: Sequence[int] | None = None,
    min_beat_time: float = DEFAULT_MIN_BEAT_TIME,
) -> dict:
    """
    Evaluate both beat and downbeat predictions for a single track.

    Args:
        pred_beats: Predicted beat times in seconds
        pred_downbeats: Predicted downbeat times in seconds
        gt_beats: Ground truth beat times in seconds
        gt_downbeats: Ground truth downbeat times in seconds
        thresholds_ms: Timing thresholds in milliseconds
        min_beat_time: Minimum time for continuity metrics (default: 5.0s)

    Returns:
        Dict containing:
        - 'beats': Results from evaluate_beats for beats
        - 'downbeats': Results from evaluate_beats for downbeats
        - 'combined_weighted_f1': Average of beat and downbeat weighted F1
    """
    beat_results = evaluate_beats(pred_beats, gt_beats, thresholds_ms, min_beat_time)
    downbeat_results = evaluate_beats(
        pred_downbeats, gt_downbeats, thresholds_ms, min_beat_time
    )

    combined_weighted_f1 = (
        beat_results["weighted_f1"] + downbeat_results["weighted_f1"]
    ) / 2

    return {
        "beats": beat_results,
        "downbeats": downbeat_results,
        "combined_weighted_f1": combined_weighted_f1,
    }
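
# Usage sketch for evaluate_beats (synthetic data; the values shown are what
# the functions above produce for this toy input): a constant 2 ms offset
# stays inside even the tightest 3 ms window, so every threshold scores 1.0:
#
#     _gt = np.arange(0.0, 30.0, 0.5).tolist()    # 120 BPM ground truth
#     _pred = (np.asarray(_gt) + 0.002).tolist()  # every beat 2 ms late
#     _res = evaluate_beats(_pred, _gt)
#     _res["f1_scores"][3]        # 1.0 -- a 2 ms error fits the 3 ms window
#     _res["weighted_f1"]         # 1.0
#     _res["continuity"]["CMLt"]  # ~1.0 -- tiny phase error vs. the 0.5 s period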
def evaluate_all(
    predictions: Sequence[dict],
    ground_truths: Sequence[dict],
    thresholds_ms: Sequence[int] | None = None,
    min_beat_time: float = DEFAULT_MIN_BEAT_TIME,
    verbose: bool = False,
) -> dict:
    """
    Evaluate predictions for multiple tracks.

    Args:
        predictions: List of dicts with 'beats' and 'downbeats' keys
        ground_truths: List of dicts with 'beats' and 'downbeats' keys
        thresholds_ms: Timing thresholds in milliseconds
        min_beat_time: Minimum time for continuity metrics (default: 5.0s)
        verbose: If True, print per-track results

    Returns:
        Dict containing:
        - 'per_track': List of per-track results
        - 'mean_beat_weighted_f1': Mean weighted F1 for beats
        - 'mean_downbeat_weighted_f1': Mean weighted F1 for downbeats
        - 'mean_combined_weighted_f1': Mean combined weighted F1
        - 'beat_f1_by_threshold': Mean F1 per threshold for beats
        - 'downbeat_f1_by_threshold': Mean F1 per threshold for downbeats
        - 'beat_continuity': Mean continuity metrics for beats
        - 'downbeat_continuity': Mean continuity metrics for downbeats
    """
    if len(predictions) != len(ground_truths):
        raise ValueError(
            f"Number of predictions ({len(predictions)}) must match "
            f"number of ground truths ({len(ground_truths)})"
        )
    if thresholds_ms is None:
        thresholds_ms = DEFAULT_THRESHOLDS_MS

    per_track = []
    beat_weighted_f1s = []
    downbeat_weighted_f1s = []
    combined_weighted_f1s = []
    beat_f1_by_threshold = {t: [] for t in thresholds_ms}
    downbeat_f1_by_threshold = {t: [] for t in thresholds_ms}

    # Continuity metrics tracking
    beat_continuity = {"CMLc": [], "CMLt": [], "AMLc": [], "AMLt": []}
    downbeat_continuity = {"CMLc": [], "CMLt": [], "AMLc": [], "AMLt": []}

    for i, (pred, gt) in enumerate(zip(predictions, ground_truths)):
        result = evaluate_track(
            pred_beats=pred["beats"],
            pred_downbeats=pred["downbeats"],
            gt_beats=gt["beats"],
            gt_downbeats=gt["downbeats"],
            thresholds_ms=thresholds_ms,
            min_beat_time=min_beat_time,
        )
        per_track.append(result)
        beat_weighted_f1s.append(result["beats"]["weighted_f1"])
        downbeat_weighted_f1s.append(result["downbeats"]["weighted_f1"])
        combined_weighted_f1s.append(result["combined_weighted_f1"])

        for t in thresholds_ms:
            beat_f1_by_threshold[t].append(result["beats"]["f1_scores"][t])
            downbeat_f1_by_threshold[t].append(result["downbeats"]["f1_scores"][t])

        # Track continuity metrics
        for metric in ["CMLc", "CMLt", "AMLc", "AMLt"]:
            beat_continuity[metric].append(result["beats"]["continuity"][metric])
            downbeat_continuity[metric].append(
                result["downbeats"]["continuity"][metric]
            )

        if verbose:
            beat_cont = result["beats"]["continuity"]
            print(
                f"Track {i}: Beat F1={result['beats']['weighted_f1']:.4f}, "
                f"CMLt={beat_cont['CMLt']:.4f}, AMLt={beat_cont['AMLt']:.4f}, "
                f"Downbeat F1={result['downbeats']['weighted_f1']:.4f}, "
                f"Combined={result['combined_weighted_f1']:.4f}"
            )

    return {
        "per_track": per_track,
        "mean_beat_weighted_f1": float(np.mean(beat_weighted_f1s)),
        "mean_downbeat_weighted_f1": float(np.mean(downbeat_weighted_f1s)),
        "mean_combined_weighted_f1": float(np.mean(combined_weighted_f1s)),
        "beat_f1_by_threshold": {
            t: float(np.mean(v)) for t, v in beat_f1_by_threshold.items()
        },
        "downbeat_f1_by_threshold": {
            t: float(np.mean(v)) for t, v in downbeat_f1_by_threshold.items()
        },
        "beat_continuity": {
            metric: float(np.mean(values))
            for metric, values in beat_continuity.items()
        },
        "downbeat_continuity": {
            metric: float(np.mean(values))
            for metric, values in downbeat_continuity.items()
        },
        "num_tracks": len(predictions),
    }
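
# Input-shape sketch for evaluate_all (hypothetical values): predictions and
# ground truths are parallel lists of per-track dicts keyed by "beats" and
# "downbeats".
#
#     _preds = [{"beats": [0.5, 1.0, 1.5], "downbeats": [0.5]}]
#     _gts = [{"beats": [0.5, 1.0, 1.5], "downbeats": [0.5]}]
#     _summary = evaluate_all(_preds, _gts, thresholds_ms=[10, 20])
#     _summary["mean_beat_weighted_f1"]  # 1.0 -- exact match
#
# Note: with the default min_beat_time=5.0, every beat in this short toy
# track is trimmed away, so its continuity metrics fall back to 0.0.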
def format_results(results: dict, title: str = "Evaluation Results") -> str:
    """
    Format evaluation results as a human-readable string.

    Args:
        results: Results dict from evaluate_all or evaluate_track
        title: Title for the report

    Returns:
        Formatted string report
    """
    lines = [title, "=" * len(title), ""]

    # Check if this is aggregate results (from evaluate_all)
    if "num_tracks" in results:
        lines.append(f"Number of tracks: {results['num_tracks']}")
        lines.append("")
        lines.append("Overall Metrics:")
        lines.append(
            f"  Mean Beat Weighted F1: {results['mean_beat_weighted_f1']:.4f}"
        )
        lines.append(
            f"  Mean Downbeat Weighted F1: {results['mean_downbeat_weighted_f1']:.4f}"
        )
        lines.append(
            f"  Mean Combined Weighted F1: {results['mean_combined_weighted_f1']:.4f}"
        )
        lines.append("")

        lines.append("Beat F1 by Threshold:")
        for t, f1 in sorted(results["beat_f1_by_threshold"].items()):
            lines.append(f"  {t:2d}ms: {f1:.4f}")
        lines.append("")

        lines.append("Downbeat F1 by Threshold:")
        for t, f1 in sorted(results["downbeat_f1_by_threshold"].items()):
            lines.append(f"  {t:2d}ms: {f1:.4f}")
        lines.append("")

        # Continuity metrics
        if "beat_continuity" in results:
            lines.append("Beat Continuity Metrics:")
            bc = results["beat_continuity"]
            lines.append(f"  CMLt: {bc['CMLt']:.4f} (Correct Metrical Level Total)")
            lines.append(f"  AMLt: {bc['AMLt']:.4f} (Any Metrical Level Total)")
            lines.append(
                f"  CMLc: {bc['CMLc']:.4f} (Correct Metrical Level Continuous)"
            )
            lines.append(f"  AMLc: {bc['AMLc']:.4f} (Any Metrical Level Continuous)")
            lines.append("")

        if "downbeat_continuity" in results:
            lines.append("Downbeat Continuity Metrics:")
            dc = results["downbeat_continuity"]
            lines.append(f"  CMLt: {dc['CMLt']:.4f} (Correct Metrical Level Total)")
            lines.append(f"  AMLt: {dc['AMLt']:.4f} (Any Metrical Level Total)")
            lines.append(
                f"  CMLc: {dc['CMLc']:.4f} (Correct Metrical Level Continuous)"
            )
            lines.append(f"  AMLc: {dc['AMLc']:.4f} (Any Metrical Level Continuous)")

    # Single track results (from evaluate_track)
    elif "beats" in results and "downbeats" in results:
        lines.append("Beat Detection:")
        lines.append(f"  Weighted F1: {results['beats']['weighted_f1']:.4f}")
        lines.append(f"  Predictions: {results['beats']['num_predictions']}")
        lines.append(f"  Ground Truth: {results['beats']['num_ground_truth']}")

        # Beat continuity metrics
        if "continuity" in results["beats"]:
            bc = results["beats"]["continuity"]
            lines.append(f"  CMLt: {bc['CMLt']:.4f}  AMLt: {bc['AMLt']:.4f}")
            lines.append(f"  CMLc: {bc['CMLc']:.4f}  AMLc: {bc['AMLc']:.4f}")
        lines.append("")

        lines.append("Downbeat Detection:")
        lines.append(f"  Weighted F1: {results['downbeats']['weighted_f1']:.4f}")
        lines.append(f"  Predictions: {results['downbeats']['num_predictions']}")
        lines.append(f"  Ground Truth: {results['downbeats']['num_ground_truth']}")

        # Downbeat continuity metrics
        if "continuity" in results["downbeats"]:
            dc = results["downbeats"]["continuity"]
            lines.append(f"  CMLt: {dc['CMLt']:.4f}  AMLt: {dc['AMLt']:.4f}")
            lines.append(f"  CMLc: {dc['CMLc']:.4f}  AMLc: {dc['AMLc']:.4f}")
        lines.append("")

        lines.append(f"Combined Weighted F1: {results['combined_weighted_f1']:.4f}")

    return "\n".join(lines)
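
# Shape of the aggregate report produced by format_results (numeric values
# are placeholders, not real results):
#
#     Evaluation Results
#     ==================
#
#     Number of tracks: 2
#
#     Overall Metrics:
#       Mean Beat Weighted F1: 0.8765
#       ...
#
#     Beat F1 by Threshold:
#        3ms: 0.6543
#       ...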
if __name__ == "__main__":
    # Demo with synthetic data
    print("Running evaluation demo...\n")

    # Simulate ground truth beats at regular intervals (30s to have beats after 5s)
    gt_beats = np.arange(0, 30, 0.5).tolist()  # Beat every 0.5s for 30s
    gt_downbeats = np.arange(0, 30, 2.0).tolist()  # Downbeat every 2s

    # Simulate predictions with some noise and missed/extra detections
    np.random.seed(42)
    pred_beats = (
        np.array(gt_beats) + np.random.normal(0, 0.005, len(gt_beats))
    ).tolist()
    pred_beats = pred_beats[:-2]  # Miss last 2 beats
    pred_beats.append(15.25)  # Add false positive
    pred_downbeats = (
        np.array(gt_downbeats) + np.random.normal(0, 0.003, len(gt_downbeats))
    ).tolist()

    # Evaluate single track
    results = evaluate_track(
        pred_beats=pred_beats,
        pred_downbeats=pred_downbeats,
        gt_beats=gt_beats,
        gt_downbeats=gt_downbeats,
    )
    print(format_results(results, "Single Track Demo"))

    print("\n" + "=" * 50 + "\n")

    # Multi-track demo
    predictions = [
        {"beats": pred_beats, "downbeats": pred_downbeats},
        {"beats": pred_beats, "downbeats": pred_downbeats},
    ]
    ground_truths = [
        {"beats": gt_beats, "downbeats": gt_downbeats},
        {"beats": gt_beats, "downbeats": gt_downbeats},
    ]
    all_results = evaluate_all(predictions, ground_truths, verbose=True)
    print()
    print(format_results(all_results, "Multi-Track Demo"))