rikhoffbauer2 commited on
Commit
26de08e
·
verified ·
1 Parent(s): ab7ecaf

Add evaluation.py

Browse files
Files changed (1) hide show
  1. evaluation.py +265 -0
evaluation.py ADDED
@@ -0,0 +1,265 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Evaluation engine: compare extracted drum samples against ground truth.
3
+
4
+ Runs the full pipeline on a synthetic song with known samples, then
5
+ matches extracted clusters to ground-truth samples and computes metrics.
6
+ """
7
+
8
+ import numpy as np
9
+ import librosa
10
+ import json
11
+ from dataclasses import dataclass
12
+ from typing import Optional
13
+ from collections import defaultdict
14
+
15
+ from quality_metrics import (
16
+ compute_all_reference_metrics,
17
+ drum_sample_score,
18
+ compute_si_sdr,
19
+ compute_envelope_correlation,
20
+ compute_mfcc_distance,
21
+ )
22
+
23
+
24
+ @dataclass
25
+ class MatchResult:
26
+ """Result of matching one extracted cluster to a ground-truth sample."""
27
+ cluster_label: str
28
+ gt_name: str
29
+ si_sdr: float
30
+ mfcc_distance: float
31
+ envelope_corr: float
32
+ spectral_convergence: float
33
+ sample_score: float # production quality score
34
+ n_hits_extracted: int
35
+ n_hits_gt: int
36
+ onset_precision_ms: float # mean onset error
37
+
38
+
39
+ @dataclass
40
+ class EvalReport:
41
+ """Full evaluation report for one extraction run."""
42
+ matches: list # [MatchResult]
43
+ unmatched_gt: list # GT samples with no extraction match
44
+ unmatched_clusters: list # Extracted clusters with no GT match
45
+ mean_si_sdr: float
46
+ mean_mfcc_dist: float
47
+ mean_env_corr: float
48
+ mean_sample_score: float
49
+ mean_onset_error_ms: float
50
+ hit_count_accuracy: float # extracted vs GT hit counts
51
+ overall_score: float # composite [0, 100]
52
+ params_used: dict # pipeline params that produced this result
53
+
54
+
55
+ def match_sample_to_gt(extracted: np.ndarray, gt_samples: dict,
56
+ sr: int) -> tuple[str, float]:
57
+ """Find the best-matching ground-truth sample for an extracted sample.
58
+ Uses MFCC + envelope correlation for matching."""
59
+ best_name = None
60
+ best_score = -np.inf
61
+
62
+ for gt_name, gt_audio in gt_samples.items():
63
+ # Compute matching score
64
+ n = min(len(extracted), len(gt_audio))
65
+ if n < 100:
66
+ continue
67
+
68
+ mfcc_dist = compute_mfcc_distance(gt_audio[:n], extracted[:n], sr)
69
+ env_corr = compute_envelope_correlation(gt_audio[:n], extracted[:n])
70
+
71
+ # Combined matching score (lower mfcc_dist + higher env_corr = better)
72
+ score = env_corr - mfcc_dist / 50.0 # normalize mfcc to similar scale
73
+
74
+ if score > best_score:
75
+ best_score = score
76
+ best_name = gt_name
77
+
78
+ return best_name, best_score
79
+
80
+
81
+ def compute_onset_errors(extracted_hits: list, gt_hits: list,
82
+ sample_name: str, tolerance_ms: float = 50.0) -> list:
83
+ """Compute onset time errors between extracted and GT hits for a given sample type.
84
+ Returns list of (gt_time, nearest_extracted_time, error_ms)."""
85
+ gt_times = [h['onset'] for h in gt_hits if h['sample'] == sample_name]
86
+ ext_times = sorted([h.onset_time for h in extracted_hits
87
+ if sample_name in h.rough_label.lower()])
88
+
89
+ errors = []
90
+ for gt_t in gt_times:
91
+ if not ext_times:
92
+ errors.append((gt_t, None, tolerance_ms))
93
+ continue
94
+ # Find nearest extracted onset
95
+ diffs = [abs(gt_t - et) for et in ext_times]
96
+ min_idx = np.argmin(diffs)
97
+ error_ms = diffs[min_idx] * 1000
98
+ errors.append((gt_t, ext_times[min_idx], error_ms))
99
+
100
+ return errors
101
+
102
+
103
+ def evaluate_extraction(
104
+ extracted_clusters: list,
105
+ gt_samples: dict, # {name: np.ndarray}
106
+ gt_hit_map: list, # [{sample, onset, velocity}, ...]
107
+ sr: int,
108
+ all_hits: list = None, # all DrumHit objects for onset analysis
109
+ pipeline_params: dict = None,
110
+ ) -> EvalReport:
111
+ """
112
+ Evaluate extracted clusters against ground truth.
113
+
114
+ Args:
115
+ extracted_clusters: list of DrumCluster from the pipeline
116
+ gt_samples: {name: audio_array} ground-truth one-shots
117
+ gt_hit_map: [{sample, onset, velocity}] from synthetic generator
118
+ sr: sample rate
119
+ all_hits: all DrumHit objects (for onset precision analysis)
120
+ pipeline_params: the params used for this extraction run
121
+ """
122
+ matches = []
123
+ matched_gt = set()
124
+ matched_clusters = set()
125
+
126
+ # Count GT hits per sample type
127
+ gt_hit_counts = defaultdict(int)
128
+ for h in gt_hit_map:
129
+ gt_hit_counts[h['sample']] += 1
130
+
131
+ # Count extracted hits per rough label
132
+ ext_hit_counts = defaultdict(int)
133
+ for cluster in extracted_clusters:
134
+ # Extract base label (e.g. "kick" from "kick_0")
135
+ base = cluster.label.rsplit('_', 1)[0]
136
+ ext_hit_counts[base] += cluster.count
137
+
138
+ # For each cluster, find best GT match
139
+ for cluster in extracted_clusters:
140
+ best_hit = cluster.best_hit
141
+ gt_name, match_score = match_sample_to_gt(best_hit.audio, gt_samples, sr)
142
+
143
+ if gt_name is None:
144
+ continue
145
+
146
+ matched_gt.add(gt_name)
147
+ matched_clusters.add(cluster.cluster_id)
148
+
149
+ # Compute reference metrics
150
+ gt_audio = gt_samples[gt_name]
151
+ ext_audio = best_hit.audio
152
+ n = min(len(gt_audio), len(ext_audio))
153
+
154
+ si_sdr = compute_si_sdr(gt_audio[:n], ext_audio[:n])
155
+ mfcc_dist = compute_mfcc_distance(gt_audio[:n], ext_audio[:n], sr)
156
+ env_corr = compute_envelope_correlation(gt_audio[:n], ext_audio[:n])
157
+ ref_metrics = compute_all_reference_metrics(gt_audio[:n], ext_audio[:n], sr)
158
+
159
+ # Sample quality score
160
+ base_label = cluster.label.rsplit('_', 1)[0]
161
+ score = drum_sample_score(ext_audio, sr, base_label)
162
+
163
+ # Onset precision (if we have hit data)
164
+ onset_errors = []
165
+ if all_hits:
166
+ errors = compute_onset_errors(
167
+ [h for h in all_hits if base_label in h.rough_label.lower()],
168
+ gt_hit_map, gt_name
169
+ )
170
+ onset_errors = [e[2] for e in errors if e[1] is not None]
171
+
172
+ mean_onset_err = np.mean(onset_errors) if onset_errors else 50.0
173
+
174
+ matches.append(MatchResult(
175
+ cluster_label=cluster.label,
176
+ gt_name=gt_name,
177
+ si_sdr=si_sdr,
178
+ mfcc_distance=mfcc_dist,
179
+ envelope_corr=env_corr,
180
+ spectral_convergence=ref_metrics['Spectral Convergence'],
181
+ sample_score=score['total'],
182
+ n_hits_extracted=ext_hit_counts.get(base_label, 0),
183
+ n_hits_gt=gt_hit_counts.get(gt_name, 0),
184
+ onset_precision_ms=mean_onset_err,
185
+ ))
186
+
187
+ # Unmatched
188
+ unmatched_gt = [n for n in gt_samples if n not in matched_gt]
189
+ unmatched_clusters = [c.label for c in extracted_clusters
190
+ if c.cluster_id not in matched_clusters]
191
+
192
+ # Aggregate metrics
193
+ if matches:
194
+ mean_si_sdr = np.mean([m.si_sdr for m in matches])
195
+ mean_mfcc = np.mean([m.mfcc_distance for m in matches])
196
+ mean_env = np.mean([m.envelope_corr for m in matches])
197
+ mean_score = np.mean([m.sample_score for m in matches])
198
+ mean_onset = np.mean([m.onset_precision_ms for m in matches])
199
+ else:
200
+ mean_si_sdr = -np.inf
201
+ mean_mfcc = np.inf
202
+ mean_env = 0.0
203
+ mean_score = 0.0
204
+ mean_onset = 50.0
205
+
206
+ # Hit count accuracy: how close are extracted counts to GT counts
207
+ total_gt = sum(gt_hit_counts.values())
208
+ total_ext = sum(ext_hit_counts.values())
209
+ hit_acc = 1.0 - abs(total_gt - total_ext) / (total_gt + 1e-8)
210
+ hit_acc = max(0, hit_acc)
211
+
212
+ # Overall composite score
213
+ # Weights: SI-SDR 25%, sample_score 25%, env_corr 20%, onset 15%, coverage 15%
214
+ coverage = len(matched_gt) / (len(gt_samples) + 1e-8)
215
+ si_sdr_norm = np.clip((mean_si_sdr + 5) / 25, 0, 1) # -5dB→0, 20dB→1
216
+ env_norm = np.clip(mean_env, 0, 1)
217
+ onset_norm = np.clip(1.0 - mean_onset / 50.0, 0, 1) # 0ms→1, 50ms→0
218
+ score_norm = mean_score / 100.0
219
+
220
+ overall = (si_sdr_norm * 0.25 + score_norm * 0.25 + env_norm * 0.20 +
221
+ onset_norm * 0.15 + coverage * 0.15) * 100
222
+
223
+ return EvalReport(
224
+ matches=matches,
225
+ unmatched_gt=unmatched_gt,
226
+ unmatched_clusters=unmatched_clusters,
227
+ mean_si_sdr=float(mean_si_sdr),
228
+ mean_mfcc_dist=float(mean_mfcc),
229
+ mean_env_corr=float(mean_env),
230
+ mean_sample_score=float(mean_score),
231
+ mean_onset_error_ms=float(mean_onset),
232
+ hit_count_accuracy=float(hit_acc),
233
+ overall_score=float(overall),
234
+ params_used=pipeline_params or {},
235
+ )
236
+
237
+
238
+ def report_to_dict(report: EvalReport) -> dict:
239
+ """Convert eval report to JSON-serializable dict."""
240
+ return {
241
+ 'overall_score': report.overall_score,
242
+ 'mean_si_sdr': report.mean_si_sdr,
243
+ 'mean_mfcc_dist': report.mean_mfcc_dist,
244
+ 'mean_env_corr': report.mean_env_corr,
245
+ 'mean_sample_score': report.mean_sample_score,
246
+ 'mean_onset_error_ms': report.mean_onset_error_ms,
247
+ 'hit_count_accuracy': report.hit_count_accuracy,
248
+ 'n_matched': len(report.matches),
249
+ 'n_unmatched_gt': len(report.unmatched_gt),
250
+ 'n_unmatched_clusters': len(report.unmatched_clusters),
251
+ 'unmatched_gt': report.unmatched_gt,
252
+ 'matches': [
253
+ {
254
+ 'cluster': m.cluster_label,
255
+ 'gt': m.gt_name,
256
+ 'si_sdr': m.si_sdr,
257
+ 'mfcc_dist': m.mfcc_distance,
258
+ 'env_corr': m.envelope_corr,
259
+ 'score': m.sample_score,
260
+ 'onset_ms': m.onset_precision_ms,
261
+ }
262
+ for m in report.matches
263
+ ],
264
+ 'params': report.params_used,
265
+ }