""" Evaluation engine: compare extracted drum samples against ground truth. Runs the full pipeline on a synthetic song with known samples, then matches extracted clusters to ground-truth samples and computes metrics. """ import numpy as np import librosa import json from dataclasses import dataclass from typing import Optional from collections import defaultdict from quality_metrics import ( compute_all_reference_metrics, drum_sample_score, compute_si_sdr, compute_envelope_correlation, compute_mfcc_distance, ) @dataclass class MatchResult: """Result of matching one extracted cluster to a ground-truth sample.""" cluster_label: str gt_name: str si_sdr: float mfcc_distance: float envelope_corr: float spectral_convergence: float sample_score: float # production quality score n_hits_extracted: int n_hits_gt: int onset_precision_ms: float # mean onset error @dataclass class EvalReport: """Full evaluation report for one extraction run.""" matches: list # [MatchResult] unmatched_gt: list # GT samples with no extraction match unmatched_clusters: list # Extracted clusters with no GT match mean_si_sdr: float mean_mfcc_dist: float mean_env_corr: float mean_sample_score: float mean_onset_error_ms: float hit_count_accuracy: float # extracted vs GT hit counts overall_score: float # composite [0, 100] params_used: dict # pipeline params that produced this result def match_sample_to_gt(extracted: np.ndarray, gt_samples: dict, sr: int) -> tuple[str, float]: """Find the best-matching ground-truth sample for an extracted sample. Uses MFCC + envelope correlation for matching.""" best_name = None best_score = -np.inf for gt_name, gt_audio in gt_samples.items(): # Compute matching score n = min(len(extracted), len(gt_audio)) if n < 100: continue mfcc_dist = compute_mfcc_distance(gt_audio[:n], extracted[:n], sr) env_corr = compute_envelope_correlation(gt_audio[:n], extracted[:n]) # Combined matching score (lower mfcc_dist + higher env_corr = better) score = env_corr - mfcc_dist / 50.0 # normalize mfcc to similar scale if score > best_score: best_score = score best_name = gt_name return best_name, best_score def compute_onset_errors(extracted_hits: list, gt_hits: list, sample_name: str, tolerance_ms: float = 50.0) -> list: """Compute onset time errors between extracted and GT hits for a given sample type. Returns list of (gt_time, nearest_extracted_time, error_ms).""" gt_times = [h['onset'] for h in gt_hits if h['sample'] == sample_name] ext_times = sorted([h.onset_time for h in extracted_hits if sample_name in getattr(h, 'label', getattr(h, 'rough_label', '')).lower()]) errors = [] for gt_t in gt_times: if not ext_times: errors.append((gt_t, None, tolerance_ms)) continue # Find nearest extracted onset diffs = [abs(gt_t - et) for et in ext_times] min_idx = np.argmin(diffs) error_ms = diffs[min_idx] * 1000 errors.append((gt_t, ext_times[min_idx], error_ms)) return errors def evaluate_extraction( extracted_clusters: list, gt_samples: dict, # {name: np.ndarray} gt_hit_map: list, # [{sample, onset, velocity}, ...] sr: int, all_hits: list = None, # all DrumHit objects for onset analysis pipeline_params: dict = None, ) -> EvalReport: """ Evaluate extracted clusters against ground truth. 
def compute_onset_errors(extracted_hits: list, gt_hits: list,
                         sample_name: str, ext_label: Optional[str] = None,
                         tolerance_ms: float = 50.0) -> list:
    """Compute onset-time errors between extracted and GT hits for one sample type.

    ``sample_name`` selects the GT hits; ``ext_label`` (defaulting to
    ``sample_name``) selects the extracted hits by rough label. Returns a list
    of (gt_time, nearest_extracted_time, error_ms); a GT hit with no extracted
    onset at all is reported as (gt_time, None, tolerance_ms).
    """
    gt_times = [h['onset'] for h in gt_hits if h['sample'] == sample_name]
    label_key = (ext_label or sample_name).lower()
    ext_times = sorted(
        h.onset_time for h in extracted_hits
        if label_key in getattr(h, 'label', getattr(h, 'rough_label', '')).lower()
    )
    errors = []
    for gt_t in gt_times:
        if not ext_times:
            errors.append((gt_t, None, tolerance_ms))
            continue
        # Match each GT onset to the nearest extracted onset.
        diffs = [abs(gt_t - et) for et in ext_times]
        min_idx = int(np.argmin(diffs))
        error_ms = diffs[min_idx] * 1000
        errors.append((gt_t, ext_times[min_idx], error_ms))
    return errors


def evaluate_extraction(
    extracted_clusters: list,
    gt_samples: dict,                     # {name: np.ndarray}
    gt_hit_map: list,                     # [{sample, onset, velocity}, ...]
    sr: int,
    all_hits: Optional[list] = None,      # all DrumHit objects for onset analysis
    pipeline_params: Optional[dict] = None,
) -> EvalReport:
    """Evaluate extracted clusters against ground truth.

    Args:
        extracted_clusters: list of DrumCluster from the pipeline.
        gt_samples: {name: audio_array} ground-truth one-shots.
        gt_hit_map: [{sample, onset, velocity}] from the synthetic generator.
        sr: sample rate.
        all_hits: all DrumHit objects (for onset-precision analysis).
        pipeline_params: the params used for this extraction run.
    """
    matches = []
    matched_gt = set()
    matched_clusters = set()

    # Count GT hits per sample type.
    gt_hit_counts = defaultdict(int)
    for h in gt_hit_map:
        gt_hit_counts[h['sample']] += 1

    # Count extracted hits per rough label.
    ext_hit_counts = defaultdict(int)
    for cluster in extracted_clusters:
        # Extract the base label (e.g. "kick" from "kick_0").
        base = cluster.label.rsplit('_', 1)[0]
        ext_hit_counts[base] += cluster.count

    # For each cluster, find the best GT match.
    for cluster in extracted_clusters:
        best_hit = cluster.best_hit
        gt_name, _ = match_sample_to_gt(best_hit.audio, gt_samples, sr)
        if gt_name is None:
            continue
        matched_gt.add(gt_name)
        matched_clusters.add(cluster.cluster_id)

        # Reference metrics over the overlapping prefix.
        gt_audio = gt_samples[gt_name]
        ext_audio = best_hit.audio
        n = min(len(gt_audio), len(ext_audio))
        si_sdr = compute_si_sdr(gt_audio[:n], ext_audio[:n])
        mfcc_dist = compute_mfcc_distance(gt_audio[:n], ext_audio[:n], sr)
        env_corr = compute_envelope_correlation(gt_audio[:n], ext_audio[:n])
        ref_metrics = compute_all_reference_metrics(gt_audio[:n], ext_audio[:n], sr)

        # Reference-free sample quality score.
        base_label = cluster.label.rsplit('_', 1)[0]
        score = drum_sample_score(ext_audio, sr, base_label)

        # Onset precision (if we have per-hit data). GT hits are selected by
        # gt_name, extracted hits by the cluster's base label.
        onset_errors = []
        if all_hits:
            errors = compute_onset_errors(all_hits, gt_hit_map, gt_name,
                                          ext_label=base_label)
            onset_errors = [e[2] for e in errors if e[1] is not None]
        mean_onset_err = float(np.mean(onset_errors)) if onset_errors else 50.0

        matches.append(MatchResult(
            cluster_label=cluster.label,
            gt_name=gt_name,
            si_sdr=si_sdr,
            mfcc_distance=mfcc_dist,
            envelope_corr=env_corr,
            spectral_convergence=ref_metrics['Spectral Convergence'],
            sample_score=score['total'],
            n_hits_extracted=ext_hit_counts.get(base_label, 0),
            n_hits_gt=gt_hit_counts.get(gt_name, 0),
            onset_precision_ms=mean_onset_err,
        ))

    # Unmatched GT samples and clusters.
    unmatched_gt = [name for name in gt_samples if name not in matched_gt]
    unmatched_clusters = [c.label for c in extracted_clusters
                          if c.cluster_id not in matched_clusters]

    # Aggregate metrics.
    if matches:
        mean_si_sdr = np.mean([m.si_sdr for m in matches])
        mean_mfcc = np.mean([m.mfcc_distance for m in matches])
        mean_env = np.mean([m.envelope_corr for m in matches])
        mean_score = np.mean([m.sample_score for m in matches])
        mean_onset = np.mean([m.onset_precision_ms for m in matches])
    else:
        mean_si_sdr = -np.inf
        mean_mfcc = np.inf
        mean_env = 0.0
        mean_score = 0.0
        mean_onset = 50.0

    # Hit-count accuracy: how close the extracted total is to the GT total.
    total_gt = sum(gt_hit_counts.values())
    total_ext = sum(ext_hit_counts.values())
    hit_acc = max(0.0, 1.0 - abs(total_gt - total_ext) / (total_gt + 1e-8))

    # Overall composite score.
    # Weights: SI-SDR 25%, sample_score 25%, env_corr 20%, onset 15%, coverage 15%.
    coverage = len(matched_gt) / (len(gt_samples) + 1e-8)
    si_sdr_norm = np.clip((mean_si_sdr + 5) / 25, 0, 1)   # -5 dB -> 0, 20 dB -> 1
    env_norm = np.clip(mean_env, 0, 1)
    onset_norm = np.clip(1.0 - mean_onset / 50.0, 0, 1)   # 0 ms -> 1, 50 ms -> 0
    score_norm = mean_score / 100.0
    overall = (si_sdr_norm * 0.25 + score_norm * 0.25 + env_norm * 0.20
               + onset_norm * 0.15 + coverage * 0.15) * 100

    return EvalReport(
        matches=matches,
        unmatched_gt=unmatched_gt,
        unmatched_clusters=unmatched_clusters,
        mean_si_sdr=float(mean_si_sdr),
        mean_mfcc_dist=float(mean_mfcc),
        mean_env_corr=float(mean_env),
        mean_sample_score=float(mean_score),
        mean_onset_error_ms=float(mean_onset),
        hit_count_accuracy=float(hit_acc),
        overall_score=float(overall),
        params_used=pipeline_params or {},
    )
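
# Worked example of the composite score, for a hypothetical run with
# mean_si_sdr = 10 dB, mean_sample_score = 80, mean_env_corr = 0.9,
# mean_onset_error = 10 ms, and 4 of 4 GT samples matched:
#
#   si_sdr_norm = (10 + 5) / 25  = 0.60
#   score_norm  = 80 / 100       = 0.80
#   env_norm    =                  0.90
#   onset_norm  = 1 - 10 / 50    = 0.80
#   coverage    = 4 / 4          = 1.00
#
#   overall = (0.60*0.25 + 0.80*0.25 + 0.90*0.20 + 0.80*0.15 + 1.00*0.15) * 100
#           = 80.0
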
def report_to_dict(report: EvalReport) -> dict:
    """Convert an eval report to a JSON-serializable dict."""
    return {
        'overall_score': report.overall_score,
        'mean_si_sdr': report.mean_si_sdr,
        'mean_mfcc_dist': report.mean_mfcc_dist,
        'mean_env_corr': report.mean_env_corr,
        'mean_sample_score': report.mean_sample_score,
        'mean_onset_error_ms': report.mean_onset_error_ms,
        'hit_count_accuracy': report.hit_count_accuracy,
        'n_matched': len(report.matches),
        'n_unmatched_gt': len(report.unmatched_gt),
        'n_unmatched_clusters': len(report.unmatched_clusters),
        'unmatched_gt': report.unmatched_gt,
        'matches': [
            {
                'cluster': m.cluster_label,
                'gt': m.gt_name,
                # Cast per-match metrics, which may arrive as numpy scalars.
                'si_sdr': float(m.si_sdr),
                'mfcc_dist': float(m.mfcc_distance),
                'env_corr': float(m.envelope_corr),
                'score': float(m.sample_score),
                'onset_ms': float(m.onset_precision_ms),
            }
            for m in report.matches
        ],
        'params': report.params_used,
    }
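
# Minimal smoke-test sketch. The SimpleNamespace objects below are hypothetical
# stand-ins for the pipeline's DrumCluster/DrumHit, carrying only the attributes
# this module reads (cluster_id, label, count, best_hit, audio, onset_time);
# real runs should pass actual pipeline output instead.
if __name__ == "__main__":
    from types import SimpleNamespace

    sr = 22050
    rng = np.random.default_rng(0)
    # Synthetic "kick": a decaying noise burst standing in for a GT one-shot.
    n = sr // 4
    kick = rng.standard_normal(n) * np.exp(-np.linspace(0.0, 8.0, n))
    gt_samples = {"kick": kick}
    gt_hit_map = [{"sample": "kick", "onset": 0.5, "velocity": 1.0}]

    # One extracted cluster: the GT sample plus a little noise, onset 2 ms late.
    hit = SimpleNamespace(
        audio=kick + 0.05 * rng.standard_normal(n),
        onset_time=0.502,
        label="kick_0",
    )
    cluster = SimpleNamespace(cluster_id=0, label="kick_0", count=1, best_hit=hit)

    report = evaluate_extraction([cluster], gt_samples, gt_hit_map, sr,
                                 all_hits=[hit], pipeline_params={"demo": True})
    print(json.dumps(report_to_dict(report), indent=2))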