| """ |
| Evaluation engine: compare extracted drum samples against ground truth. |
| |
| Runs the full pipeline on a synthetic song with known samples, then |
| matches extracted clusters to ground-truth samples and computes metrics. |
| """ |
|
|
import json
from collections import defaultdict
from dataclasses import dataclass
from typing import Optional

import numpy as np
import librosa

from quality_metrics import (
    compute_all_reference_metrics,
    drum_sample_score,
    compute_si_sdr,
    compute_envelope_correlation,
    compute_mfcc_distance,
)
|
|
|
|
@dataclass
class MatchResult:
    """Result of matching one extracted cluster to a ground-truth sample."""
    cluster_label: str
    gt_name: str
    si_sdr: float
    mfcc_distance: float
    envelope_corr: float
    spectral_convergence: float
    sample_score: float
    n_hits_extracted: int
    n_hits_gt: int
    onset_precision_ms: float
|
|
|
|
@dataclass
class EvalReport:
    """Full evaluation report for one extraction run."""
    matches: list[MatchResult]
    unmatched_gt: list[str]
    unmatched_clusters: list[str]
    mean_si_sdr: float
    mean_mfcc_dist: float
    mean_env_corr: float
    mean_sample_score: float
    mean_onset_error_ms: float
    hit_count_accuracy: float
    overall_score: float
    params_used: dict
|
|
|
|
def match_sample_to_gt(extracted: np.ndarray, gt_samples: dict,
                       sr: int) -> tuple[Optional[str], float]:
    """Find the best-matching ground-truth sample for an extracted sample.

    Matching combines envelope correlation (higher is better) with a
    scaled MFCC distance penalty (lower is better).
    """
    best_name = None
    best_score = -np.inf

    for gt_name, gt_audio in gt_samples.items():
        # Compare on the overlapping region; skip pairs that are too short.
        n = min(len(extracted), len(gt_audio))
        if n < 100:
            continue

        mfcc_dist = compute_mfcc_distance(gt_audio[:n], extracted[:n], sr)
        env_corr = compute_envelope_correlation(gt_audio[:n], extracted[:n])

        # Combined score: envelope correlation minus a scaled MFCC penalty.
        score = env_corr - mfcc_dist / 50.0

        if score > best_score:
            best_score = score
            best_name = gt_name

    return best_name, best_score
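
# A hedged usage sketch (not part of the pipeline): two synthetic one-shots
# stand in for real ground-truth samples, and 22050 Hz is an assumed rate.
#
#     rng = np.random.default_rng(0)
#     t = np.arange(4096)
#     gt = {
#         "kick":  np.sin(2 * np.pi * 60 * t / 22050) * np.exp(-t / 2000),
#         "snare": rng.standard_normal(4096) * np.exp(-t / 800),
#     }
#     query = gt["kick"] + 0.01 * rng.standard_normal(4096)
#     name, score = match_sample_to_gt(query, gt, sr=22050)
#     # Expect name == "kick": the near-copy should have the higher envelope
#     # correlation and the smaller MFCC distance.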
|
|
|
|
def compute_onset_errors(extracted_hits: list, gt_hits: list,
                         sample_name: str, tolerance_ms: float = 50.0) -> list:
    """Compute onset-time errors between extracted and GT hits for one sample type.

    Returns a list of (gt_time, nearest_extracted_time, error_ms) tuples;
    nearest_extracted_time is None when no extracted hit is available.
    """
    gt_times = [h['onset'] for h in gt_hits if h['sample'] == sample_name]
    # Extracted hits may carry either a refined `label` or only a `rough_label`.
    ext_times = sorted(h.onset_time for h in extracted_hits
                       if sample_name in getattr(h, 'label', getattr(h, 'rough_label', '')).lower())

    errors = []
    for gt_t in gt_times:
        if not ext_times:
            # No extracted hit at all: charge the full tolerance as the error.
            errors.append((gt_t, None, tolerance_ms))
            continue

        # Nearest extracted onset to this ground-truth onset.
        diffs = [abs(gt_t - et) for et in ext_times]
        min_idx = int(np.argmin(diffs))
        error_ms = diffs[min_idx] * 1000
        errors.append((gt_t, ext_times[min_idx], error_ms))

    return errors
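
# A hedged example: two GT kick onsets and one extracted hit at 0.505 s.
# `SimpleNamespace` stands in for the pipeline's DrumHit objects.
#
#     from types import SimpleNamespace
#     hits = [SimpleNamespace(onset_time=0.505, label="kick_0")]
#     gt = [{"sample": "kick", "onset": 0.5}, {"sample": "kick", "onset": 1.0}]
#     compute_onset_errors(hits, gt, "kick")
#     # -> [(0.5, 0.505, ~5.0), (1.0, 0.505, ~495.0)]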
|
|
|
|
def evaluate_extraction(
    extracted_clusters: list,
    gt_samples: dict,
    gt_hit_map: list,
    sr: int,
    all_hits: Optional[list] = None,
    pipeline_params: Optional[dict] = None,
) -> EvalReport:
    """
    Evaluate extracted clusters against ground truth.

    Args:
        extracted_clusters: list of DrumCluster from the pipeline
        gt_samples: {name: audio_array} ground-truth one-shots
        gt_hit_map: [{sample, onset, velocity}] dicts from the synthetic generator
        sr: sample rate
        all_hits: all DrumHit objects (for onset-precision analysis)
        pipeline_params: the parameters used for this extraction run
    """
    matches = []
    matched_gt = set()
    matched_clusters = set()

    # Ground-truth hit counts per sample name.
    gt_hit_counts = defaultdict(int)
    for h in gt_hit_map:
        gt_hit_counts[h['sample']] += 1

    # Extracted hit counts per base label (e.g. "kick_1" -> "kick").
    ext_hit_counts = defaultdict(int)
    for cluster in extracted_clusters:
        base = cluster.label.rsplit('_', 1)[0]
        ext_hit_counts[base] += cluster.count

    # Match each extracted cluster to its closest ground-truth sample.
    for cluster in extracted_clusters:
        best_hit = cluster.best_hit
        gt_name, _ = match_sample_to_gt(best_hit.audio, gt_samples, sr)

        if gt_name is None:
            continue

        matched_gt.add(gt_name)
        matched_clusters.add(cluster.cluster_id)

        # Reference metrics on the overlapping region of the two samples.
        gt_audio = gt_samples[gt_name]
        ext_audio = best_hit.audio
        n = min(len(gt_audio), len(ext_audio))

        si_sdr = compute_si_sdr(gt_audio[:n], ext_audio[:n])
        mfcc_dist = compute_mfcc_distance(gt_audio[:n], ext_audio[:n], sr)
        env_corr = compute_envelope_correlation(gt_audio[:n], ext_audio[:n])
        ref_metrics = compute_all_reference_metrics(gt_audio[:n], ext_audio[:n], sr)

        # Reference-free quality score for the extracted sample itself.
        base_label = cluster.label.rsplit('_', 1)[0]
        score = drum_sample_score(ext_audio, sr, base_label)

        # Onset precision: compare this cluster's hits to the GT hit map.
        onset_errors = []
        if all_hits:
            errors = compute_onset_errors(
                [h for h in all_hits
                 if base_label in getattr(h, 'label', getattr(h, 'rough_label', '')).lower()],
                gt_hit_map, gt_name,
            )
            onset_errors = [e[2] for e in errors if e[1] is not None]

        mean_onset_err = np.mean(onset_errors) if onset_errors else 50.0

        matches.append(MatchResult(
            cluster_label=cluster.label,
            gt_name=gt_name,
            si_sdr=si_sdr,
            mfcc_distance=mfcc_dist,
            envelope_corr=env_corr,
            spectral_convergence=ref_metrics['Spectral Convergence'],
            sample_score=score['total'],
            n_hits_extracted=ext_hit_counts.get(base_label, 0),
            n_hits_gt=gt_hit_counts.get(gt_name, 0),
            onset_precision_ms=mean_onset_err,
        ))

    # GT samples never matched, and clusters never matched to any GT sample.
    unmatched_gt = [name for name in gt_samples if name not in matched_gt]
    unmatched_clusters = [c.label for c in extracted_clusters
                          if c.cluster_id not in matched_clusters]

    # Aggregate per-match metrics; use worst-case defaults when nothing matched.
    if matches:
        mean_si_sdr = np.mean([m.si_sdr for m in matches])
        mean_mfcc = np.mean([m.mfcc_distance for m in matches])
        mean_env = np.mean([m.envelope_corr for m in matches])
        mean_score = np.mean([m.sample_score for m in matches])
        mean_onset = np.mean([m.onset_precision_ms for m in matches])
    else:
        mean_si_sdr = -np.inf
        mean_mfcc = np.inf
        mean_env = 0.0
        mean_score = 0.0
        mean_onset = 50.0

    # Hit-count accuracy: how close the total number of extracted hits is to GT.
    total_gt = sum(gt_hit_counts.values())
    total_ext = sum(ext_hit_counts.values())
    hit_acc = 1.0 - abs(total_gt - total_ext) / (total_gt + 1e-8)
    hit_acc = max(0.0, hit_acc)

    # Overall score: weighted blend of normalized metrics, scaled to 0-100.
    coverage = len(matched_gt) / (len(gt_samples) + 1e-8)
    si_sdr_norm = np.clip((mean_si_sdr + 5) / 25, 0, 1)
    env_norm = np.clip(mean_env, 0, 1)
    onset_norm = np.clip(1.0 - mean_onset / 50.0, 0, 1)
    score_norm = mean_score / 100.0

    overall = (si_sdr_norm * 0.25 + score_norm * 0.25 + env_norm * 0.20 +
               onset_norm * 0.15 + coverage * 0.15) * 100

    return EvalReport(
        matches=matches,
        unmatched_gt=unmatched_gt,
        unmatched_clusters=unmatched_clusters,
        mean_si_sdr=float(mean_si_sdr),
        mean_mfcc_dist=float(mean_mfcc),
        mean_env_corr=float(mean_env),
        mean_sample_score=float(mean_score),
        mean_onset_error_ms=float(mean_onset),
        hit_count_accuracy=float(hit_acc),
        overall_score=float(overall),
        params_used=pipeline_params or {},
    )
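
# Worked example of the overall score (illustrative numbers only): with
# mean SI-SDR = 10 dB, sample score = 70, envelope corr = 0.9, mean onset
# error = 10 ms, and 3 of 4 GT samples matched:
#
#     si_sdr_norm = (10 + 5) / 25 = 0.6
#     score_norm  = 70 / 100      = 0.7
#     env_norm    = 0.9
#     onset_norm  = 1 - 10 / 50   = 0.8
#     coverage    = 3 / 4         = 0.75
#     overall = (0.6*0.25 + 0.7*0.25 + 0.9*0.20 + 0.8*0.15 + 0.75*0.15) * 100
#             ≈ 73.75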
|
|
|
|
def report_to_dict(report: EvalReport) -> dict:
    """Convert an eval report to a JSON-serializable dict."""
    return {
        'overall_score': report.overall_score,
        'mean_si_sdr': report.mean_si_sdr,
        'mean_mfcc_dist': report.mean_mfcc_dist,
        'mean_env_corr': report.mean_env_corr,
        'mean_sample_score': report.mean_sample_score,
        'mean_onset_error_ms': report.mean_onset_error_ms,
        'hit_count_accuracy': report.hit_count_accuracy,
        'n_matched': len(report.matches),
        'n_unmatched_gt': len(report.unmatched_gt),
        'n_unmatched_clusters': len(report.unmatched_clusters),
        'unmatched_gt': report.unmatched_gt,
        'matches': [
            {
                # Cast per-match metrics to float: they may be numpy scalars,
                # which json.dumps cannot always serialize directly.
                'cluster': m.cluster_label,
                'gt': m.gt_name,
                'si_sdr': float(m.si_sdr),
                'mfcc_dist': float(m.mfcc_distance),
                'env_corr': float(m.envelope_corr),
                'score': float(m.sample_score),
                'onset_ms': float(m.onset_precision_ms),
            }
            for m in report.matches
        ],
        'params': report.params_used,
    }
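

# A minimal end-to-end sketch (hedged): the module and function names below
# (`synthetic_song.generate_song`, `pipeline.extract_drum_samples`) are
# placeholders for this project's synthetic generator and extraction pipeline,
# not confirmed APIs. Only `evaluate_extraction` and `report_to_dict` come
# from this file.
if __name__ == '__main__':
    from synthetic_song import generate_song        # hypothetical generator
    from pipeline import extract_drum_samples       # hypothetical pipeline entry point

    # Render a synthetic song with a known hit map and one-shot samples.
    audio, sr, gt_samples, gt_hit_map = generate_song(seed=0)

    # Run the extraction pipeline and evaluate it against ground truth.
    params = {'onset_threshold': 0.3}                # illustrative params only
    clusters, all_hits = extract_drum_samples(audio, sr, **params)
    report = evaluate_extraction(clusters, gt_samples, gt_hit_map, sr,
                                 all_hits=all_hits, pipeline_params=params)

    # Serialize the report for later comparison across parameter sweeps.
    print(json.dumps(report_to_dict(report), indent=2))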
|
|