"""
Evaluation engine: compare extracted drum samples against ground truth.
Runs the full pipeline on a synthetic song with known samples, then
matches extracted clusters to ground-truth samples and computes metrics.
"""
import numpy as np
import json
from dataclasses import dataclass
from typing import Optional
from collections import defaultdict
from quality_metrics import (
compute_all_reference_metrics,
drum_sample_score,
compute_si_sdr,
compute_envelope_correlation,
compute_mfcc_distance,
)
@dataclass
class MatchResult:
"""Result of matching one extracted cluster to a ground-truth sample."""
cluster_label: str
gt_name: str
si_sdr: float
mfcc_distance: float
envelope_corr: float
spectral_convergence: float
sample_score: float # production quality score
n_hits_extracted: int
n_hits_gt: int
onset_precision_ms: float # mean onset error
@dataclass
class EvalReport:
"""Full evaluation report for one extraction run."""
matches: list # [MatchResult]
unmatched_gt: list # GT samples with no extraction match
unmatched_clusters: list # Extracted clusters with no GT match
mean_si_sdr: float
mean_mfcc_dist: float
mean_env_corr: float
mean_sample_score: float
mean_onset_error_ms: float
    hit_count_accuracy: float  # agreement between total extracted and total GT hit counts
overall_score: float # composite [0, 100]
params_used: dict # pipeline params that produced this result
def match_sample_to_gt(extracted: np.ndarray, gt_samples: dict,
sr: int) -> tuple[str, float]:
"""Find the best-matching ground-truth sample for an extracted sample.
Uses MFCC + envelope correlation for matching."""
best_name = None
best_score = -np.inf
for gt_name, gt_audio in gt_samples.items():
# Compute matching score
n = min(len(extracted), len(gt_audio))
if n < 100:
continue
mfcc_dist = compute_mfcc_distance(gt_audio[:n], extracted[:n], sr)
env_corr = compute_envelope_correlation(gt_audio[:n], extracted[:n])
# Combined matching score (lower mfcc_dist + higher env_corr = better)
score = env_corr - mfcc_dist / 50.0 # normalize mfcc to similar scale
if score > best_score:
best_score = score
best_name = gt_name
return best_name, best_score
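# Illustrative usage sketch (not part of the pipeline; the variable names below are
# hypothetical). Matching a single extracted one-shot against ground-truth audio:
#
#   gt_samples = {'kick': kick_audio, 'snare': snare_audio}  # np.ndarray one-shots
#   name, score = match_sample_to_gt(extracted_audio, gt_samples, sr=44100)
#   # `name` is the best-matching GT key, or None if nothing overlapped by >= 100
#   # samples; `score` is env_corr - mfcc_dist / 50, so higher means a closer match.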
def compute_onset_errors(extracted_hits: list, gt_hits: list,
sample_name: str, tolerance_ms: float = 50.0) -> list:
"""Compute onset time errors between extracted and GT hits for a given sample type.
Returns list of (gt_time, nearest_extracted_time, error_ms)."""
gt_times = [h['onset'] for h in gt_hits if h['sample'] == sample_name]
ext_times = sorted([h.onset_time for h in extracted_hits
if sample_name in getattr(h, 'label', getattr(h, 'rough_label', '')).lower()])
errors = []
for gt_t in gt_times:
if not ext_times:
errors.append((gt_t, None, tolerance_ms))
continue
# Find nearest extracted onset
diffs = [abs(gt_t - et) for et in ext_times]
min_idx = np.argmin(diffs)
error_ms = diffs[min_idx] * 1000
errors.append((gt_t, ext_times[min_idx], error_ms))
return errors
def evaluate_extraction(
extracted_clusters: list,
gt_samples: dict, # {name: np.ndarray}
gt_hit_map: list, # [{sample, onset, velocity}, ...]
sr: int,
    all_hits: Optional[list] = None,       # all DrumHit objects for onset analysis
    pipeline_params: Optional[dict] = None,
) -> EvalReport:
"""
Evaluate extracted clusters against ground truth.
Args:
extracted_clusters: list of DrumCluster from the pipeline
gt_samples: {name: audio_array} ground-truth one-shots
gt_hit_map: [{sample, onset, velocity}] from synthetic generator
sr: sample rate
all_hits: all DrumHit objects (for onset precision analysis)
pipeline_params: the params used for this extraction run
"""
matches = []
matched_gt = set()
matched_clusters = set()
# Count GT hits per sample type
gt_hit_counts = defaultdict(int)
for h in gt_hit_map:
gt_hit_counts[h['sample']] += 1
# Count extracted hits per rough label
ext_hit_counts = defaultdict(int)
for cluster in extracted_clusters:
# Extract base label (e.g. "kick" from "kick_0")
base = cluster.label.rsplit('_', 1)[0]
ext_hit_counts[base] += cluster.count
# For each cluster, find best GT match
for cluster in extracted_clusters:
best_hit = cluster.best_hit
gt_name, match_score = match_sample_to_gt(best_hit.audio, gt_samples, sr)
if gt_name is None:
continue
matched_gt.add(gt_name)
matched_clusters.add(cluster.cluster_id)
# Compute reference metrics
gt_audio = gt_samples[gt_name]
ext_audio = best_hit.audio
n = min(len(gt_audio), len(ext_audio))
si_sdr = compute_si_sdr(gt_audio[:n], ext_audio[:n])
mfcc_dist = compute_mfcc_distance(gt_audio[:n], ext_audio[:n], sr)
env_corr = compute_envelope_correlation(gt_audio[:n], ext_audio[:n])
ref_metrics = compute_all_reference_metrics(gt_audio[:n], ext_audio[:n], sr)
# Sample quality score
base_label = cluster.label.rsplit('_', 1)[0]
score = drum_sample_score(ext_audio, sr, base_label)
# Onset precision (if we have hit data)
onset_errors = []
if all_hits:
errors = compute_onset_errors(
[h for h in all_hits if base_label in getattr(h, 'label', getattr(h, 'rough_label', '')).lower()],
gt_hit_map, gt_name
)
onset_errors = [e[2] for e in errors if e[1] is not None]
mean_onset_err = np.mean(onset_errors) if onset_errors else 50.0
matches.append(MatchResult(
cluster_label=cluster.label,
gt_name=gt_name,
si_sdr=si_sdr,
mfcc_distance=mfcc_dist,
envelope_corr=env_corr,
spectral_convergence=ref_metrics['Spectral Convergence'],
sample_score=score['total'],
n_hits_extracted=ext_hit_counts.get(base_label, 0),
n_hits_gt=gt_hit_counts.get(gt_name, 0),
onset_precision_ms=mean_onset_err,
))
# Unmatched
    unmatched_gt = [name for name in gt_samples if name not in matched_gt]
unmatched_clusters = [c.label for c in extracted_clusters
if c.cluster_id not in matched_clusters]
# Aggregate metrics
if matches:
mean_si_sdr = np.mean([m.si_sdr for m in matches])
mean_mfcc = np.mean([m.mfcc_distance for m in matches])
mean_env = np.mean([m.envelope_corr for m in matches])
mean_score = np.mean([m.sample_score for m in matches])
mean_onset = np.mean([m.onset_precision_ms for m in matches])
else:
mean_si_sdr = -np.inf
mean_mfcc = np.inf
mean_env = 0.0
mean_score = 0.0
mean_onset = 50.0
# Hit count accuracy: how close are extracted counts to GT counts
total_gt = sum(gt_hit_counts.values())
total_ext = sum(ext_hit_counts.values())
hit_acc = 1.0 - abs(total_gt - total_ext) / (total_gt + 1e-8)
hit_acc = max(0, hit_acc)
# Overall composite score
# Weights: SI-SDR 25%, sample_score 25%, env_corr 20%, onset 15%, coverage 15%
coverage = len(matched_gt) / (len(gt_samples) + 1e-8)
si_sdr_norm = np.clip((mean_si_sdr + 5) / 25, 0, 1) # -5dB→0, 20dB→1
env_norm = np.clip(mean_env, 0, 1)
onset_norm = np.clip(1.0 - mean_onset / 50.0, 0, 1) # 0ms→1, 50ms→0
score_norm = mean_score / 100.0
overall = (si_sdr_norm * 0.25 + score_norm * 0.25 + env_norm * 0.20 +
onset_norm * 0.15 + coverage * 0.15) * 100
return EvalReport(
matches=matches,
unmatched_gt=unmatched_gt,
unmatched_clusters=unmatched_clusters,
mean_si_sdr=float(mean_si_sdr),
mean_mfcc_dist=float(mean_mfcc),
mean_env_corr=float(mean_env),
mean_sample_score=float(mean_score),
mean_onset_error_ms=float(mean_onset),
hit_count_accuracy=float(hit_acc),
overall_score=float(overall),
params_used=pipeline_params or {},
)
def report_to_dict(report: EvalReport) -> dict:
"""Convert eval report to JSON-serializable dict."""
return {
'overall_score': report.overall_score,
'mean_si_sdr': report.mean_si_sdr,
'mean_mfcc_dist': report.mean_mfcc_dist,
'mean_env_corr': report.mean_env_corr,
'mean_sample_score': report.mean_sample_score,
'mean_onset_error_ms': report.mean_onset_error_ms,
'hit_count_accuracy': report.hit_count_accuracy,
'n_matched': len(report.matches),
'n_unmatched_gt': len(report.unmatched_gt),
'n_unmatched_clusters': len(report.unmatched_clusters),
'unmatched_gt': report.unmatched_gt,
'matches': [
        {
            'cluster': m.cluster_label,
            'gt': m.gt_name,
            'si_sdr': float(m.si_sdr),
            'mfcc_dist': float(m.mfcc_distance),
            'env_corr': float(m.envelope_corr),
            'score': float(m.sample_score),
            'onset_ms': float(m.onset_precision_ms),
        }
for m in report.matches
],
'params': report.params_used,
}
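if __name__ == "__main__":
    # Minimal smoke-test sketch, not the real evaluation entry point: it exercises
    # evaluate_extraction / report_to_dict with tiny synthetic stand-ins. The
    # SimpleNamespace objects below mimic the pipeline's DrumCluster / DrumHit
    # interfaces (assumed attributes: label, count, cluster_id, best_hit, audio,
    # onset_time); a real run would pass objects produced by the extraction pipeline.
    from types import SimpleNamespace

    sr = 22050
    rng = np.random.default_rng(0)
    # Fake ground-truth kick: a decaying 60 Hz burst, 0.3 s long
    t = np.arange(int(0.3 * sr)) / sr
    kick = (np.sin(2 * np.pi * 60 * t) * np.exp(-t * 20)).astype(np.float32)
    gt_samples = {'kick': kick}
    gt_hit_map = [{'sample': 'kick', 'onset': 0.5, 'velocity': 1.0},
                  {'sample': 'kick', 'onset': 1.0, 'velocity': 0.9}]

    # Pretend the pipeline recovered the same sample with a little added noise
    recovered = kick + 0.01 * rng.standard_normal(len(kick)).astype(np.float32)
    hit = SimpleNamespace(audio=recovered, onset_time=0.502, label='kick')
    cluster = SimpleNamespace(label='kick_0', count=2, cluster_id=0, best_hit=hit)

    report = evaluate_extraction([cluster], gt_samples, gt_hit_map, sr,
                                 all_hits=[hit], pipeline_params={'demo': True})
    print(json.dumps(report_to_dict(report), indent=2))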