""" |
|
|
Evaluation utilities for beat and downbeat detection. |
|
|
|
|
|
This module provides functions to evaluate beat/downbeat predictions against |
|
|
ground truth annotations using F1-scores at various timing thresholds and |
|
|
continuity-based metrics (CMLt, AMLt). |
|
|
|
|
|
The evaluation metrics include: |
|
|
- **F1-scores**: Calculated for timing thresholds from 3ms to 30ms |
|
|
- **Weighted F1**: Weights are inversely proportional to threshold (e.g., 3ms: 1, 6ms: 1/2) |
|
|
- **CMLt (Correct Metrical Level Total)**: Accuracy at the correct metrical level |
|
|
- **AMLt (Any Metrical Level Total)**: Accuracy allowing for metrical variations |
|
|
(double/half tempo, off-beat, etc.) |
|
|
- **CMLc/AMLc**: Continuous versions (longest correct segment) |
|
|
|
|
|
Example usage: |
|
|
from ..data.eval import ( |
|
|
evaluate_beats, evaluate_all, compute_weighted_f1, |
|
|
compute_continuity_metrics, format_results |
|
|
) |
|
|
|
|
|
# Evaluate single track |
|
|
results = evaluate_beats(pred_beats, gt_beats) |
|
|
print(f"Weighted F1: {results['weighted_f1']:.4f}") |
|
|
print(f"CMLt: {results['continuity']['CMLt']:.4f}") |
|
|
print(f"AMLt: {results['continuity']['AMLt']:.4f}") |
|
|
|
|
|
# Evaluate with custom thresholds |
|
|
results = evaluate_beats(pred_beats, gt_beats, thresholds_ms=[5, 10, 20]) |
|
|
|
|
|
# Evaluate all tracks in dataset |
|
|
summary = evaluate_all(predictions, ground_truths) |
|
|
print(format_results(summary)) |
|
|
""" |
|
|
|
|
|

from typing import Sequence

import numpy as np
import mir_eval

# Timing thresholds (in milliseconds) at which F1-scores are computed.
DEFAULT_THRESHOLDS_MS = [3, 6, 9, 12, 15, 18, 21, 24, 27, 30]

# Beats before this time (in seconds) are trimmed prior to computing
# continuity metrics; 5.0 s is the mir_eval convention.
DEFAULT_MIN_BEAT_TIME = 5.0


def match_events(
    pred: np.ndarray,
    gt: np.ndarray,
    tolerance_sec: float,
) -> tuple[int, int, int]:
    """
    Match predicted events to ground truth events within a tolerance.

    Uses greedy matching: each ground truth event is matched to the closest
    unmatched prediction within the tolerance window.

    Args:
        pred: Predicted event times in seconds, shape (N,)
        gt: Ground truth event times in seconds, shape (M,)
        tolerance_sec: Maximum time difference for a match (in seconds)

    Returns:
        Tuple of (true_positives, false_positives, false_negatives)
    """
    if len(gt) == 0:
        return 0, len(pred), 0
    if len(pred) == 0:
        return 0, 0, len(gt)

    pred = np.sort(pred)
    gt = np.sort(gt)

    matched_pred = np.zeros(len(pred), dtype=bool)
    matched_gt = np.zeros(len(gt), dtype=bool)

    for i, gt_time in enumerate(gt):
        # Candidate predictions: within tolerance and not yet matched.
        diffs = np.abs(pred - gt_time)
        candidates = np.where((diffs <= tolerance_sec) & ~matched_pred)[0]

        if len(candidates) > 0:
            # Greedily take the closest remaining prediction.
            best_idx = candidates[np.argmin(diffs[candidates])]
            matched_pred[best_idx] = True
            matched_gt[i] = True

    tp = int(matched_gt.sum())
    # Every unmatched prediction is a false positive; every unmatched
    # ground truth event is a false negative.
    fp = len(pred) - tp
    fn = len(gt) - tp

    return tp, fp, fn
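
# Illustration (hypothetical values): with a 10 ms tolerance, a prediction at
# 1.001 s matches the beat at 1.0 s, the prediction at 2.0 s matches exactly,
# and the beat at 3.0 s goes unmatched:
#
#     match_events(np.array([1.001, 2.0]), np.array([1.0, 2.0, 3.0]), 0.01)
#     # -> (2, 0, 1)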


def compute_f1(tp: int, fp: int, fn: int) -> tuple[float, float, float]:
    """
    Compute precision, recall, and F1-score from TP, FP, FN counts.

    Args:
        tp: True positives
        fp: False positives
        fn: False negatives

    Returns:
        Tuple of (precision, recall, f1_score)
    """
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    f1 = (
        2 * precision * recall / (precision + recall)
        if (precision + recall) > 0
        else 0.0
    )
    return precision, recall, f1
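
# Worked example: tp=1, fp=1, fn=1 gives precision = recall = 0.5, and
# F1 = 2 * 0.5 * 0.5 / (0.5 + 0.5) = 0.5:
#
#     compute_f1(1, 1, 1)
#     # -> (0.5, 0.5, 0.5)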


def compute_weighted_f1(
    f1_scores: dict[int, float],
    thresholds_ms: Sequence[int] | None = None,
) -> float:
    """
    Compute weighted F1-score where weights are inversely proportional to threshold.

    The weight for threshold T ms is 1 / (T / min_threshold).
    For example, with thresholds [3, 6, 9, ...]:
    - 3ms: weight = 1
    - 6ms: weight = 0.5
    - 9ms: weight = 0.333...

    Args:
        f1_scores: Dict mapping threshold (ms) to F1-score
        thresholds_ms: List of thresholds used (for weight calculation)

    Returns:
        Weighted F1-score
    """
    if not f1_scores:
        return 0.0

    if thresholds_ms is None:
        thresholds_ms = sorted(f1_scores.keys())

    min_threshold = min(thresholds_ms)
    total_weight = 0.0
    weighted_sum = 0.0

    for t in thresholds_ms:
        if t in f1_scores:
            weight = min_threshold / t
            weighted_sum += weight * f1_scores[t]
            total_weight += weight

    return weighted_sum / total_weight if total_weight > 0 else 0.0
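
# Worked example: F1 of 0.9 at 3 ms (weight 1.0) and 0.6 at 6 ms (weight 0.5)
# give (1.0 * 0.9 + 0.5 * 0.6) / (1.0 + 0.5) = 1.2 / 1.5 = 0.8:
#
#     compute_weighted_f1({3: 0.9, 6: 0.6})
#     # -> 0.8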


def compute_continuity_metrics(
    pred_times: Sequence[float],
    gt_times: Sequence[float],
    min_beat_time: float = DEFAULT_MIN_BEAT_TIME,
    phase_threshold: float = 0.175,
    period_threshold: float = 0.175,
) -> dict:
    """
    Compute continuity-based beat tracking metrics using mir_eval.

    These metrics evaluate beat tracking accuracy accounting for metrical level:
    - CMLt (Correct Metrical Level Total): Accuracy at the correct metrical level
    - AMLt (Any Metrical Level Total): Accuracy allowing for metrical variations
      (double/half tempo, off-beat, etc.)
    - CMLc/AMLc: Continuous versions (longest correct segment)

    Args:
        pred_times: Predicted beat times in seconds
        gt_times: Ground truth beat times in seconds
        min_beat_time: Minimum time to start evaluation (default: 5.0s).
            Set to 0.0 to use all beats, but note that early beats
            may not have stable inter-beat intervals.
        phase_threshold: Maximum phase error as ratio of beat interval (default: 0.175)
        period_threshold: Maximum period error as ratio of beat interval (default: 0.175)

    Returns:
        Dict containing:
        - 'CMLc': Correct Metrical Level Continuous
        - 'CMLt': Correct Metrical Level Total
        - 'AMLc': Any Metrical Level Continuous
        - 'AMLt': Any Metrical Level Total
    """
    pred_arr = np.sort(np.array(pred_times, dtype=np.float64))
    gt_arr = np.sort(np.array(gt_times, dtype=np.float64))

    # Discard beats before min_beat_time, per standard practice.
    pred_trimmed = mir_eval.beat.trim_beats(pred_arr, min_beat_time=min_beat_time)
    gt_trimmed = mir_eval.beat.trim_beats(gt_arr, min_beat_time=min_beat_time)

    # Continuity metrics need at least two beats to define an interval.
    if len(gt_trimmed) < 2 or len(pred_trimmed) < 2:
        return {
            "CMLc": 0.0,
            "CMLt": 0.0,
            "AMLc": 0.0,
            "AMLt": 0.0,
        }

    CMLc, CMLt, AMLc, AMLt = mir_eval.beat.continuity(
        gt_trimmed,
        pred_trimmed,
        continuity_phase_threshold=phase_threshold,
        continuity_period_threshold=period_threshold,
    )

    return {
        "CMLc": float(CMLc),
        "CMLt": float(CMLt),
        "AMLc": float(AMLc),
        "AMLt": float(AMLt),
    }
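
# Sanity check (illustrative): a prediction identical to the ground truth
# scores 1.0 on all four metrics once beats before min_beat_time are trimmed:
#
#     beats = np.arange(0.5, 30.0, 0.5)
#     compute_continuity_metrics(beats, beats)
#     # -> {'CMLc': 1.0, 'CMLt': 1.0, 'AMLc': 1.0, 'AMLt': 1.0}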


def evaluate_beats(
    pred_times: Sequence[float],
    gt_times: Sequence[float],
    thresholds_ms: Sequence[int] | None = None,
    min_beat_time: float = DEFAULT_MIN_BEAT_TIME,
) -> dict:
    """
    Evaluate beat predictions against ground truth at multiple thresholds.

    Args:
        pred_times: Predicted beat times in seconds
        gt_times: Ground truth beat times in seconds
        thresholds_ms: Timing thresholds in milliseconds (default: 3ms to 30ms)
        min_beat_time: Minimum time for continuity metrics (default: 5.0s)

    Returns:
        Dict containing:
        - 'per_threshold': Dict[threshold_ms, {'precision', 'recall', 'f1'}]
        - 'f1_scores': Dict[threshold_ms, f1_score] (convenience access)
        - 'weighted_f1': Weighted F1-score across all thresholds
        - 'continuity': Dict with CMLc, CMLt, AMLc, AMLt metrics
        - 'num_predictions': Number of predictions
        - 'num_ground_truth': Number of ground truth events
    """
    if thresholds_ms is None:
        thresholds_ms = DEFAULT_THRESHOLDS_MS

    pred_arr = np.array(pred_times, dtype=np.float64)
    gt_arr = np.array(gt_times, dtype=np.float64)

    per_threshold = {}
    f1_scores = {}

    for threshold_ms in thresholds_ms:
        tolerance_sec = threshold_ms / 1000.0
        tp, fp, fn = match_events(pred_arr, gt_arr, tolerance_sec)
        precision, recall, f1 = compute_f1(tp, fp, fn)

        per_threshold[threshold_ms] = {
            "precision": precision,
            "recall": recall,
            "f1": f1,
            "tp": tp,
            "fp": fp,
            "fn": fn,
        }
        f1_scores[threshold_ms] = f1

    weighted_f1 = compute_weighted_f1(f1_scores, thresholds_ms)
    continuity = compute_continuity_metrics(pred_times, gt_times, min_beat_time)

    return {
        "per_threshold": per_threshold,
        "f1_scores": f1_scores,
        "weighted_f1": weighted_f1,
        "continuity": continuity,
        "num_predictions": len(pred_arr),
        "num_ground_truth": len(gt_arr),
    }
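
# Typical access pattern (keys as documented above): the 3 ms F1 lives at
# results["f1_scores"][3], the threshold-weighted summary at
# results["weighted_f1"], and continuity scores under results["continuity"].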


def evaluate_track(
    pred_beats: Sequence[float],
    pred_downbeats: Sequence[float],
    gt_beats: Sequence[float],
    gt_downbeats: Sequence[float],
    thresholds_ms: Sequence[int] | None = None,
    min_beat_time: float = DEFAULT_MIN_BEAT_TIME,
) -> dict:
    """
    Evaluate both beat and downbeat predictions for a single track.

    Args:
        pred_beats: Predicted beat times in seconds
        pred_downbeats: Predicted downbeat times in seconds
        gt_beats: Ground truth beat times in seconds
        gt_downbeats: Ground truth downbeat times in seconds
        thresholds_ms: Timing thresholds in milliseconds
        min_beat_time: Minimum time for continuity metrics (default: 5.0s)

    Returns:
        Dict containing:
        - 'beats': Results from evaluate_beats for beats
        - 'downbeats': Results from evaluate_beats for downbeats
        - 'combined_weighted_f1': Average of beat and downbeat weighted F1
    """
    beat_results = evaluate_beats(pred_beats, gt_beats, thresholds_ms, min_beat_time)
    downbeat_results = evaluate_beats(
        pred_downbeats, gt_downbeats, thresholds_ms, min_beat_time
    )

    combined_weighted_f1 = (
        beat_results["weighted_f1"] + downbeat_results["weighted_f1"]
    ) / 2

    return {
        "beats": beat_results,
        "downbeats": downbeat_results,
        "combined_weighted_f1": combined_weighted_f1,
    }
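
# Worked example of the combination (hypothetical values): a beat weighted F1
# of 0.90 and a downbeat weighted F1 of 0.80 yield
# combined_weighted_f1 = (0.90 + 0.80) / 2 = 0.85.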


def evaluate_all(
    predictions: Sequence[dict],
    ground_truths: Sequence[dict],
    thresholds_ms: Sequence[int] | None = None,
    min_beat_time: float = DEFAULT_MIN_BEAT_TIME,
    verbose: bool = False,
) -> dict:
    """
    Evaluate predictions for multiple tracks.

    Args:
        predictions: List of dicts with 'beats' and 'downbeats' keys
        ground_truths: List of dicts with 'beats' and 'downbeats' keys
        thresholds_ms: Timing thresholds in milliseconds
        min_beat_time: Minimum time for continuity metrics (default: 5.0s)
        verbose: If True, print per-track results

    Returns:
        Dict containing:
        - 'per_track': List of per-track results
        - 'mean_beat_weighted_f1': Mean weighted F1 for beats
        - 'mean_downbeat_weighted_f1': Mean weighted F1 for downbeats
        - 'mean_combined_weighted_f1': Mean combined weighted F1
        - 'beat_f1_by_threshold': Mean F1 per threshold for beats
        - 'downbeat_f1_by_threshold': Mean F1 per threshold for downbeats
        - 'beat_continuity': Mean continuity metrics for beats
        - 'downbeat_continuity': Mean continuity metrics for downbeats
    """
    if len(predictions) != len(ground_truths):
        raise ValueError(
            f"Number of predictions ({len(predictions)}) must match "
            f"number of ground truths ({len(ground_truths)})"
        )

    if thresholds_ms is None:
        thresholds_ms = DEFAULT_THRESHOLDS_MS

    per_track = []
    beat_weighted_f1s = []
    downbeat_weighted_f1s = []
    combined_weighted_f1s = []

    beat_f1_by_threshold = {t: [] for t in thresholds_ms}
    downbeat_f1_by_threshold = {t: [] for t in thresholds_ms}

    beat_continuity = {"CMLc": [], "CMLt": [], "AMLc": [], "AMLt": []}
    downbeat_continuity = {"CMLc": [], "CMLt": [], "AMLc": [], "AMLt": []}

    for i, (pred, gt) in enumerate(zip(predictions, ground_truths)):
        result = evaluate_track(
            pred_beats=pred["beats"],
            pred_downbeats=pred["downbeats"],
            gt_beats=gt["beats"],
            gt_downbeats=gt["downbeats"],
            thresholds_ms=thresholds_ms,
            min_beat_time=min_beat_time,
        )

        per_track.append(result)
        beat_weighted_f1s.append(result["beats"]["weighted_f1"])
        downbeat_weighted_f1s.append(result["downbeats"]["weighted_f1"])
        combined_weighted_f1s.append(result["combined_weighted_f1"])

        for t in thresholds_ms:
            beat_f1_by_threshold[t].append(result["beats"]["f1_scores"][t])
            downbeat_f1_by_threshold[t].append(result["downbeats"]["f1_scores"][t])

        for metric in ["CMLc", "CMLt", "AMLc", "AMLt"]:
            beat_continuity[metric].append(result["beats"]["continuity"][metric])
            downbeat_continuity[metric].append(
                result["downbeats"]["continuity"][metric]
            )

        if verbose:
            beat_cont = result["beats"]["continuity"]
            print(
                f"Track {i}: Beat F1={result['beats']['weighted_f1']:.4f}, "
                f"CMLt={beat_cont['CMLt']:.4f}, AMLt={beat_cont['AMLt']:.4f}, "
                f"Downbeat F1={result['downbeats']['weighted_f1']:.4f}, "
                f"Combined={result['combined_weighted_f1']:.4f}"
            )

    return {
        "per_track": per_track,
        "mean_beat_weighted_f1": float(np.mean(beat_weighted_f1s)),
        "mean_downbeat_weighted_f1": float(np.mean(downbeat_weighted_f1s)),
        "mean_combined_weighted_f1": float(np.mean(combined_weighted_f1s)),
        "beat_f1_by_threshold": {
            t: float(np.mean(v)) for t, v in beat_f1_by_threshold.items()
        },
        "downbeat_f1_by_threshold": {
            t: float(np.mean(v)) for t, v in downbeat_f1_by_threshold.items()
        },
        "beat_continuity": {
            metric: float(np.mean(values)) for metric, values in beat_continuity.items()
        },
        "downbeat_continuity": {
            metric: float(np.mean(values))
            for metric, values in downbeat_continuity.items()
        },
        "num_tracks": len(predictions),
    }


def format_results(results: dict, title: str = "Evaluation Results") -> str:
    """
    Format evaluation results as a human-readable string.

    Args:
        results: Results dict from evaluate_all or evaluate_track
        title: Title for the report

    Returns:
        Formatted string report
    """
    lines = [title, "=" * len(title), ""]

    # Multi-track summary (output of evaluate_all).
    if "num_tracks" in results:
        lines.append(f"Number of tracks: {results['num_tracks']}")
        lines.append("")
        lines.append("Overall Metrics:")
        lines.append(
            f"  Mean Beat Weighted F1: {results['mean_beat_weighted_f1']:.4f}"
        )
        lines.append(
            f"  Mean Downbeat Weighted F1: {results['mean_downbeat_weighted_f1']:.4f}"
        )
        lines.append(
            f"  Mean Combined Weighted F1: {results['mean_combined_weighted_f1']:.4f}"
        )
        lines.append("")

        lines.append("Beat F1 by Threshold:")
        for t, f1 in sorted(results["beat_f1_by_threshold"].items()):
            lines.append(f"  {t:2d}ms: {f1:.4f}")
        lines.append("")

        lines.append("Downbeat F1 by Threshold:")
        for t, f1 in sorted(results["downbeat_f1_by_threshold"].items()):
            lines.append(f"  {t:2d}ms: {f1:.4f}")
        lines.append("")

        if "beat_continuity" in results:
            lines.append("Beat Continuity Metrics:")
            bc = results["beat_continuity"]
            lines.append(f"  CMLt: {bc['CMLt']:.4f} (Correct Metrical Level Total)")
            lines.append(f"  AMLt: {bc['AMLt']:.4f} (Any Metrical Level Total)")
            lines.append(
                f"  CMLc: {bc['CMLc']:.4f} (Correct Metrical Level Continuous)"
            )
            lines.append(f"  AMLc: {bc['AMLc']:.4f} (Any Metrical Level Continuous)")
            lines.append("")

        if "downbeat_continuity" in results:
            lines.append("Downbeat Continuity Metrics:")
            dc = results["downbeat_continuity"]
            lines.append(f"  CMLt: {dc['CMLt']:.4f} (Correct Metrical Level Total)")
            lines.append(f"  AMLt: {dc['AMLt']:.4f} (Any Metrical Level Total)")
            lines.append(
                f"  CMLc: {dc['CMLc']:.4f} (Correct Metrical Level Continuous)"
            )
            lines.append(f"  AMLc: {dc['AMLc']:.4f} (Any Metrical Level Continuous)")

    # Single-track report (output of evaluate_track).
    elif "beats" in results and "downbeats" in results:
        lines.append("Beat Detection:")
        lines.append(f"  Weighted F1: {results['beats']['weighted_f1']:.4f}")
        lines.append(f"  Predictions: {results['beats']['num_predictions']}")
        lines.append(f"  Ground Truth: {results['beats']['num_ground_truth']}")

        if "continuity" in results["beats"]:
            bc = results["beats"]["continuity"]
            lines.append(f"  CMLt: {bc['CMLt']:.4f} AMLt: {bc['AMLt']:.4f}")
            lines.append(f"  CMLc: {bc['CMLc']:.4f} AMLc: {bc['AMLc']:.4f}")
        lines.append("")

        lines.append("Downbeat Detection:")
        lines.append(f"  Weighted F1: {results['downbeats']['weighted_f1']:.4f}")
        lines.append(f"  Predictions: {results['downbeats']['num_predictions']}")
        lines.append(f"  Ground Truth: {results['downbeats']['num_ground_truth']}")

        if "continuity" in results["downbeats"]:
            dc = results["downbeats"]["continuity"]
            lines.append(f"  CMLt: {dc['CMLt']:.4f} AMLt: {dc['AMLt']:.4f}")
            lines.append(f"  CMLc: {dc['CMLc']:.4f} AMLc: {dc['AMLc']:.4f}")
        lines.append("")

        lines.append(f"Combined Weighted F1: {results['combined_weighted_f1']:.4f}")

    return "\n".join(lines)


if __name__ == "__main__":
    print("Running evaluation demo...\n")

    # Synthetic ground truth: beats every 0.5 s (120 BPM) for 30 s,
    # downbeats every 2.0 s (one per 4/4 bar).
    gt_beats = np.arange(0, 30, 0.5).tolist()
    gt_downbeats = np.arange(0, 30, 2.0).tolist()

    # Simulated predictions: ground truth plus Gaussian timing jitter,
    # with the last two beats dropped and one spurious beat at 15.25 s.
    np.random.seed(42)
    pred_beats = (
        np.array(gt_beats) + np.random.normal(0, 0.005, len(gt_beats))
    ).tolist()
    pred_beats = pred_beats[:-2]
    pred_beats.append(15.25)

    pred_downbeats = (
        np.array(gt_downbeats) + np.random.normal(0, 0.003, len(gt_downbeats))
    ).tolist()

    # Single-track evaluation.
    results = evaluate_track(
        pred_beats=pred_beats,
        pred_downbeats=pred_downbeats,
        gt_beats=gt_beats,
        gt_downbeats=gt_downbeats,
    )

    print(format_results(results, "Single Track Demo"))
    print("\n" + "=" * 50 + "\n")

    # Multi-track evaluation over two copies of the same track.
    predictions = [
        {"beats": pred_beats, "downbeats": pred_downbeats},
        {"beats": pred_beats, "downbeats": pred_downbeats},
    ]
    ground_truths = [
        {"beats": gt_beats, "downbeats": gt_downbeats},
        {"beats": gt_beats, "downbeats": gt_downbeats},
    ]

    all_results = evaluate_all(predictions, ground_truths, verbose=True)
    print()
    print(format_results(all_results, "Multi-Track Demo"))