# JacobLinCool's picture
# Upload folder using huggingface_hub
# 31bf74c unverified
"""
Evaluation utilities for beat and downbeat detection.
This module provides functions to evaluate beat/downbeat predictions against
ground truth annotations using F1-scores at various timing thresholds and
continuity-based metrics (CMLt, AMLt).
The evaluation metrics include:
- **F1-scores**: Computed at each timing threshold (default: 3 ms to 30 ms in 3 ms steps)
- **Weighted F1**: Weighted mean of the per-threshold F1-scores, with weights
  inversely proportional to the threshold (e.g., 3ms: 1, 6ms: 1/2)
- **CMLt (Correct Metrical Level Total)**: Accuracy at the correct metrical level
- **AMLt (Any Metrical Level Total)**: Accuracy allowing for metrical variations
  (double/half tempo, off-beat, etc.)
- **CMLc/AMLc**: Continuous versions (longest correct segment)
Example usage:
    from ..data.eval import (
        evaluate_beats, evaluate_all, compute_weighted_f1,
        compute_continuity_metrics, format_results
    )

    # Evaluate a single track
    results = evaluate_beats(pred_beats, gt_beats)
    print(f"Weighted F1: {results['weighted_f1']:.4f}")
    print(f"CMLt: {results['continuity']['CMLt']:.4f}")
    print(f"AMLt: {results['continuity']['AMLt']:.4f}")

    # Evaluate with custom thresholds
    results = evaluate_beats(pred_beats, gt_beats, thresholds_ms=[5, 10, 20])

    # Evaluate all tracks in a dataset
    summary = evaluate_all(predictions, ground_truths)
    print(format_results(summary))
"""
from typing import Sequence
import numpy as np
import mir_eval
# Default F1 timing thresholds in milliseconds: 3 ms up to and including
# 30 ms, in steps of 3 ms.
DEFAULT_THRESHOLDS_MS = list(range(3, 31, 3))
# Default start time (seconds) for the mir_eval continuity metrics: beats
# earlier than this are trimmed before scoring. Use 0.0 to keep every beat.
DEFAULT_MIN_BEAT_TIME = 5.0
def match_events(
    pred: np.ndarray,
    gt: np.ndarray,
    tolerance_sec: float,
) -> tuple[int, int, int]:
    """
    Match predicted events to ground truth events within a tolerance.

    Uses greedy matching. Ground truth events are visited in ascending time
    order. Each one is matched to the closest still-unmatched prediction whose
    absolute time difference is at most ``tolerance_sec``. Every prediction is
    used at most once.

    Because the matching is greedy, it can find fewer pairs than an optimal
    (maximum bipartite) matching when the tolerance windows of neighbouring
    events overlap.

    Args:
        pred: Predicted event times in seconds, shape (N,). Any 1-D
            array-like is accepted; the caller's data is not modified.
        gt: Ground truth event times in seconds, shape (M,).
        tolerance_sec: Maximum absolute time difference for a match (seconds).

    Returns:
        Tuple of (true_positives, false_positives, false_negatives), where
        false_positives = N - true_positives and
        false_negatives = M - true_positives.
    """
    pred = np.asarray(pred)
    gt = np.asarray(gt)

    if len(gt) == 0:
        # Nothing to match against: every prediction is a false positive.
        return 0, len(pred), 0
    if len(pred) == 0:
        # No predictions: every ground truth event is missed.
        return 0, 0, len(gt)

    # np.sort returns copies, so the caller's arrays are left untouched.
    pred = np.sort(pred)
    gt = np.sort(gt)
    matched_pred = np.zeros(len(pred), dtype=bool)

    tp = 0
    for gt_time in gt:
        diffs = np.abs(pred - gt_time)
        # Predictions inside the window that have not been claimed yet.
        candidates = np.flatnonzero((diffs <= tolerance_sec) & ~matched_pred)
        if candidates.size:
            # Closest candidate wins; ties go to the earliest prediction.
            best_idx = candidates[np.argmin(diffs[candidates])]
            matched_pred[best_idx] = True
            tp += 1

    fp = len(pred) - tp
    fn = len(gt) - tp
    return tp, fp, fn
def compute_f1(tp: int, fp: int, fn: int) -> tuple[float, float, float]:
    """
    Turn match counts into precision, recall and F1-score.

    Any ratio whose denominator is not positive is reported as 0.0 rather
    than raising ZeroDivisionError.

    Args:
        tp: True positives
        fp: False positives
        fn: False negatives

    Returns:
        Tuple of (precision, recall, f1_score)
    """

    def ratio(numerator, denominator):
        # Guarded division: empty denominators score zero.
        return numerator / denominator if denominator > 0 else 0.0

    precision = ratio(tp, tp + fp)
    recall = ratio(tp, tp + fn)
    # Harmonic mean of precision and recall.
    f1 = ratio(2 * precision * recall, precision + recall)
    return precision, recall, f1
def compute_weighted_f1(
f1_scores: dict[int, float],
thresholds_ms: Sequence[int] | None = None,
) -> float:
"""
Compute weighted F1-score where weights are inversely proportional to threshold.
The weight for threshold T ms is 1 / (T / min_threshold).
For example, with thresholds [3, 6, 9, ...]:
- 3ms: weight = 1
- 6ms: weight = 0.5
- 9ms: weight = 0.333...
Args:
f1_scores: Dict mapping threshold (ms) to F1-score
thresholds_ms: List of thresholds used (for weight calculation)
Returns:
Weighted F1-score
"""
if not f1_scores:
return 0.0
if thresholds_ms is None:
thresholds_ms = sorted(f1_scores.keys())
min_threshold = min(thresholds_ms)
total_weight = 0.0
weighted_sum = 0.0
for t in thresholds_ms:
if t in f1_scores:
weight = min_threshold / t # 3ms -> 1, 6ms -> 0.5, etc.
weighted_sum += weight * f1_scores[t]
total_weight += weight
return weighted_sum / total_weight if total_weight > 0 else 0.0
def compute_continuity_metrics(
    pred_times: Sequence[float],
    gt_times: Sequence[float],
    min_beat_time: float = DEFAULT_MIN_BEAT_TIME,
    phase_threshold: float = 0.175,
    period_threshold: float = 0.175,
) -> dict:
    """
    Compute mir_eval's continuity-based beat tracking metrics.

    Both sequences are converted to sorted float64 arrays. They are then
    trimmed with ``mir_eval.beat.trim_beats`` and scored with
    ``mir_eval.beat.continuity``.

    Metrics:
    - CMLt (Correct Metric Level Total): Accuracy at the correct metrical level
    - AMLt (Any Metric Level Total): Accuracy allowing for metrical variations
      (double/half tempo, off-beat, etc.)
    - CMLc/AMLc: Continuous versions (longest correct segment)

    Args:
        pred_times: Predicted beat times in seconds
        gt_times: Ground truth beat times in seconds
        min_beat_time: Beats earlier than this (seconds) are trimmed before
            scoring (default: 5.0s). Set to 0.0 to use all beats, but note
            that early beats may not have stable inter-beat intervals.
        phase_threshold: Forwarded as ``continuity_phase_threshold``
            (maximum phase error as a ratio of the beat interval).
        period_threshold: Forwarded as ``continuity_period_threshold``
            (maximum period error as a ratio of the beat interval).

    Returns:
        Dict with float values for 'CMLc', 'CMLt', 'AMLc' and 'AMLt'
        (in that order). All four are 0.0 when either trimmed sequence has
        fewer than two beats.
    """
    metric_names = ("CMLc", "CMLt", "AMLc", "AMLt")

    # Sort, then drop beats before min_beat_time (standard preprocessing).
    estimated, reference = (
        mir_eval.beat.trim_beats(
            np.sort(np.array(times, dtype=np.float64)),
            min_beat_time=min_beat_time,
        )
        for times in (pred_times, gt_times)
    )

    # Fewer than two beats leaves no inter-beat interval to evaluate.
    if min(len(estimated), len(reference)) < 2:
        return dict.fromkeys(metric_names, 0.0)

    # mir_eval returns the scores in the order (CMLc, CMLt, AMLc, AMLt).
    scores = mir_eval.beat.continuity(
        reference,
        estimated,
        continuity_phase_threshold=phase_threshold,
        continuity_period_threshold=period_threshold,
    )
    return {name: float(value) for name, value in zip(metric_names, scores)}
def evaluate_beats(
    pred_times: Sequence[float],
    gt_times: Sequence[float],
    thresholds_ms: Sequence[int] | None = None,
    min_beat_time: float = DEFAULT_MIN_BEAT_TIME,
) -> dict:
    """
    Score one stream of event times (beats or downbeats) against ground truth.

    Precision, recall and F1 are computed with ``match_events`` at every
    threshold, using all events. The continuity metrics come from
    ``compute_continuity_metrics``, which trims events before
    ``min_beat_time``.

    Args:
        pred_times: Predicted beat times in seconds
        gt_times: Ground truth beat times in seconds
        thresholds_ms: Timing thresholds in milliseconds (default: 3ms to 30ms)
        min_beat_time: Minimum time for continuity metrics (default: 5.0s)

    Returns:
        Dict containing:
        - 'per_threshold': Dict[threshold_ms, {'precision', 'recall', 'f1',
          'tp', 'fp', 'fn'}]
        - 'f1_scores': Dict[threshold_ms, f1_score] (convenience access)
        - 'weighted_f1': Weighted F1-score across all thresholds
        - 'continuity': Dict with CMLc, CMLt, AMLc, AMLt metrics
        - 'num_predictions': Number of predictions
        - 'num_ground_truth': Number of ground truth events
    """
    thresholds = DEFAULT_THRESHOLDS_MS if thresholds_ms is None else thresholds_ms
    estimated = np.array(pred_times, dtype=np.float64)
    reference = np.array(gt_times, dtype=np.float64)

    per_threshold = {}
    for threshold_ms in thresholds:
        # Thresholds are given in milliseconds; matching works in seconds.
        tp, fp, fn = match_events(estimated, reference, threshold_ms / 1000.0)
        precision, recall, f1 = compute_f1(tp, fp, fn)
        per_threshold[threshold_ms] = {
            "precision": precision,
            "recall": recall,
            "f1": f1,
            "tp": tp,
            "fp": fp,
            "fn": fn,
        }

    f1_scores = {t: row["f1"] for t, row in per_threshold.items()}

    return {
        "per_threshold": per_threshold,
        "f1_scores": f1_scores,
        "weighted_f1": compute_weighted_f1(f1_scores, thresholds),
        "continuity": compute_continuity_metrics(pred_times, gt_times, min_beat_time),
        "num_predictions": len(estimated),
        "num_ground_truth": len(reference),
    }
def evaluate_track(
    pred_beats: Sequence[float],
    pred_downbeats: Sequence[float],
    gt_beats: Sequence[float],
    gt_downbeats: Sequence[float],
    thresholds_ms: Sequence[int] | None = None,
    min_beat_time: float = DEFAULT_MIN_BEAT_TIME,
) -> dict:
    """
    Evaluate the beat and downbeat streams of a single track.

    Each stream is scored independently with ``evaluate_beats`` using the
    same thresholds and ``min_beat_time``.

    Args:
        pred_beats: Predicted beat times in seconds
        pred_downbeats: Predicted downbeat times in seconds
        gt_beats: Ground truth beat times in seconds
        gt_downbeats: Ground truth downbeat times in seconds
        thresholds_ms: Timing thresholds in milliseconds
        min_beat_time: Minimum time for continuity metrics (default: 5.0s)

    Returns:
        Dict containing:
        - 'beats': evaluate_beats result for the beats
        - 'downbeats': evaluate_beats result for the downbeats
        - 'combined_weighted_f1': Mean of the two weighted F1-scores
    """
    track = {
        "beats": evaluate_beats(pred_beats, gt_beats, thresholds_ms, min_beat_time),
        "downbeats": evaluate_beats(
            pred_downbeats, gt_downbeats, thresholds_ms, min_beat_time
        ),
    }
    # Plain (unweighted) average of the two streams.
    track["combined_weighted_f1"] = (
        track["beats"]["weighted_f1"] + track["downbeats"]["weighted_f1"]
    ) / 2
    return track
def _mean_or_nan(values: Sequence[float]) -> float:
    """
    Return the mean of ``values`` as a Python float, or NaN if it is empty.

    Same result as ``float(np.mean(values))``, but an empty input does not
    trigger NumPy's "Mean of empty slice" RuntimeWarning.
    """
    if len(values) == 0:
        return float("nan")
    return float(np.mean(values))


def evaluate_all(
    predictions: Sequence[dict],
    ground_truths: Sequence[dict],
    thresholds_ms: Sequence[int] | None = None,
    min_beat_time: float = DEFAULT_MIN_BEAT_TIME,
    verbose: bool = False,
) -> dict:
    """
    Evaluate predictions for multiple tracks and average the results.

    Each track is scored with ``evaluate_track``. Every aggregate is the
    unweighted mean of the per-track values.

    Args:
        predictions: List of dicts with 'beats' and 'downbeats' keys
        ground_truths: List of dicts with 'beats' and 'downbeats' keys,
            aligned index-by-index with ``predictions``
        thresholds_ms: Timing thresholds in milliseconds (default: 3ms to
            30ms). Any iterable is accepted; it is materialized once.
        min_beat_time: Minimum time for continuity metrics (default: 5.0s)
        verbose: If True, print per-track results

    Returns:
        Dict containing:
        - 'per_track': List of per-track results (from evaluate_track)
        - 'mean_beat_weighted_f1': Mean weighted F1 for beats
        - 'mean_downbeat_weighted_f1': Mean weighted F1 for downbeats
        - 'mean_combined_weighted_f1': Mean combined weighted F1
        - 'beat_f1_by_threshold': Mean F1 per threshold for beats
        - 'downbeat_f1_by_threshold': Mean F1 per threshold for downbeats
        - 'beat_continuity': Mean continuity metrics for beats
        - 'downbeat_continuity': Mean continuity metrics for downbeats
        - 'num_tracks': Number of evaluated tracks
        When no tracks are given, every mean is NaN.

    Raises:
        ValueError: If ``predictions`` and ``ground_truths`` differ in length.
    """
    if len(predictions) != len(ground_truths):
        raise ValueError(
            f"Number of predictions ({len(predictions)}) must match "
            f"number of ground truths ({len(ground_truths)})"
        )
    # Materialize once: the thresholds are iterated several times below and
    # again inside evaluate_track, which would exhaust a one-shot iterator.
    thresholds_ms = list(
        DEFAULT_THRESHOLDS_MS if thresholds_ms is None else thresholds_ms
    )
    continuity_keys = ("CMLc", "CMLt", "AMLc", "AMLt")

    per_track = []
    beat_weighted_f1s = []
    downbeat_weighted_f1s = []
    combined_weighted_f1s = []
    beat_f1_by_threshold = {t: [] for t in thresholds_ms}
    downbeat_f1_by_threshold = {t: [] for t in thresholds_ms}
    beat_continuity = {metric: [] for metric in continuity_keys}
    downbeat_continuity = {metric: [] for metric in continuity_keys}

    for i, (pred, gt) in enumerate(zip(predictions, ground_truths)):
        result = evaluate_track(
            pred_beats=pred["beats"],
            pred_downbeats=pred["downbeats"],
            gt_beats=gt["beats"],
            gt_downbeats=gt["downbeats"],
            thresholds_ms=thresholds_ms,
            min_beat_time=min_beat_time,
        )
        per_track.append(result)
        beat_weighted_f1s.append(result["beats"]["weighted_f1"])
        downbeat_weighted_f1s.append(result["downbeats"]["weighted_f1"])
        combined_weighted_f1s.append(result["combined_weighted_f1"])
        for t in thresholds_ms:
            beat_f1_by_threshold[t].append(result["beats"]["f1_scores"][t])
            downbeat_f1_by_threshold[t].append(result["downbeats"]["f1_scores"][t])
        for metric in continuity_keys:
            beat_continuity[metric].append(result["beats"]["continuity"][metric])
            downbeat_continuity[metric].append(
                result["downbeats"]["continuity"][metric]
            )
        if verbose:
            beat_cont = result["beats"]["continuity"]
            print(
                f"Track {i}: Beat F1={result['beats']['weighted_f1']:.4f}, "
                f"CMLt={beat_cont['CMLt']:.4f}, AMLt={beat_cont['AMLt']:.4f}, "
                f"Downbeat F1={result['downbeats']['weighted_f1']:.4f}, "
                f"Combined={result['combined_weighted_f1']:.4f}"
            )

    return {
        "per_track": per_track,
        "mean_beat_weighted_f1": _mean_or_nan(beat_weighted_f1s),
        "mean_downbeat_weighted_f1": _mean_or_nan(downbeat_weighted_f1s),
        "mean_combined_weighted_f1": _mean_or_nan(combined_weighted_f1s),
        "beat_f1_by_threshold": {
            t: _mean_or_nan(v) for t, v in beat_f1_by_threshold.items()
        },
        "downbeat_f1_by_threshold": {
            t: _mean_or_nan(v) for t, v in downbeat_f1_by_threshold.items()
        },
        "beat_continuity": {
            metric: _mean_or_nan(values)
            for metric, values in beat_continuity.items()
        },
        "downbeat_continuity": {
            metric: _mean_or_nan(values)
            for metric, values in downbeat_continuity.items()
        },
        "num_tracks": len(predictions),
    }
def format_results(results: dict, title: str = "Evaluation Results") -> str:
    """
    Render evaluation results as a plain-text report.

    Two layouts are supported:
    - aggregate results (detected by a 'num_tracks' key, as produced by
      evaluate_all), and
    - single-track results (with 'beats' and 'downbeats' keys, as produced by
      evaluate_track).
    For any other dict, only the underlined title is returned.

    Args:
        results: Results dict from evaluate_all or evaluate_track
        title: Title for the report

    Returns:
        The report as one newline-joined string.
    """
    out = [title, "=" * len(title), ""]

    if "num_tracks" in results:
        out += [
            f"Number of tracks: {results['num_tracks']}",
            "",
            "Overall Metrics:",
            f" Mean Beat Weighted F1: {results['mean_beat_weighted_f1']:.4f}",
            f" Mean Downbeat Weighted F1: {results['mean_downbeat_weighted_f1']:.4f}",
            f" Mean Combined Weighted F1: {results['mean_combined_weighted_f1']:.4f}",
            "",
        ]
        # Per-threshold tables, thresholds in ascending order.
        for heading, key in (
            ("Beat F1 by Threshold:", "beat_f1_by_threshold"),
            ("Downbeat F1 by Threshold:", "downbeat_f1_by_threshold"),
        ):
            out.append(heading)
            out += [f" {t:2d}ms: {f1:.4f}" for t, f1 in sorted(results[key].items())]
            out.append("")

        # Continuity metrics
        if "beat_continuity" in results:
            bc = results["beat_continuity"]
            out += [
                "Beat Continuity Metrics:",
                f" CMLt: {bc['CMLt']:.4f} (Correct Metrical Level Total)",
                f" AMLt: {bc['AMLt']:.4f} (Any Metrical Level Total)",
                f" CMLc: {bc['CMLc']:.4f} (Correct Metrical Level Continuous)",
                f" AMLc: {bc['AMLc']:.4f} (Any Metrical Level Continuous)",
                "",
            ]
        if "downbeat_continuity" in results:
            dc = results["downbeat_continuity"]
            out += [
                "Downbeat Continuity Metrics:",
                f" CMLt: {dc['CMLt']:.4f} (Correct Metrical Level Total)",
                f" AMLt: {dc['AMLt']:.4f} (Any Metrical Level Total)",
                f" CMLc: {dc['CMLc']:.4f} (Correct Metrical Level Continuous)",
                f" AMLc: {dc['AMLc']:.4f} (Any Metrical Level Continuous)",
            ]

    elif "beats" in results and "downbeats" in results:
        # Single-track layout: one identical section per stream.
        for heading, key in (
            ("Beat Detection:", "beats"),
            ("Downbeat Detection:", "downbeats"),
        ):
            stream = results[key]
            out += [
                heading,
                f" Weighted F1: {stream['weighted_f1']:.4f}",
                f" Predictions: {stream['num_predictions']}",
                f" Ground Truth: {stream['num_ground_truth']}",
            ]
            if "continuity" in stream:
                cont = stream["continuity"]
                out += [
                    f" CMLt: {cont['CMLt']:.4f} AMLt: {cont['AMLt']:.4f}",
                    f" CMLc: {cont['CMLc']:.4f} AMLc: {cont['AMLc']:.4f}",
                ]
            out.append("")
        out.append(f"Combined Weighted F1: {results['combined_weighted_f1']:.4f}")

    return "\n".join(out)
if __name__ == "__main__":
    # Demo with synthetic data. The track is 30 s long so that beats remain
    # after the default 5 s trimming used by the continuity metrics.
    print("Running evaluation demo...\n")

    gt_beats = np.arange(0, 30, 0.5).tolist()  # a beat every 0.5 s
    gt_downbeats = np.arange(0, 30, 2.0).tolist()  # a downbeat every 2 s

    # Fake predictions by jittering the ground truth (fixed seed, so the demo
    # is reproducible). For beats, also drop the last two events and add one
    # spurious detection.
    np.random.seed(42)
    beat_jitter = np.random.normal(0, 0.005, len(gt_beats))
    noisy_beats = (np.asarray(gt_beats) + beat_jitter).tolist()
    pred_beats = noisy_beats[:-2] + [15.25]
    downbeat_jitter = np.random.normal(0, 0.003, len(gt_downbeats))
    pred_downbeats = (np.asarray(gt_downbeats) + downbeat_jitter).tolist()

    # Single-track report
    single_track = evaluate_track(
        pred_beats=pred_beats,
        pred_downbeats=pred_downbeats,
        gt_beats=gt_beats,
        gt_downbeats=gt_downbeats,
    )
    print(format_results(single_track, "Single Track Demo"))
    print("\n" + "=" * 50 + "\n")

    # Multi-track report: the same synthetic track evaluated twice.
    n_copies = 2
    predictions = [
        {"beats": pred_beats, "downbeats": pred_downbeats} for _ in range(n_copies)
    ]
    ground_truths = [
        {"beats": gt_beats, "downbeats": gt_downbeats} for _ in range(n_copies)
    ]
    all_results = evaluate_all(predictions, ground_truths, verbose=True)
    print()
    print(format_results(all_results, "Multi-Track Demo"))