"""
Evaluation engine: compare extracted drum samples against ground truth.

Runs the full pipeline on a synthetic song with known samples, then
matches extracted clusters to ground-truth samples and computes metrics.
"""

import numpy as np
import json
from dataclasses import dataclass
from typing import Optional
from collections import defaultdict

from quality_metrics import (
    compute_all_reference_metrics,
    drum_sample_score,
    compute_si_sdr,
    compute_envelope_correlation,
    compute_mfcc_distance,
)
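

def _hit_label(hit) -> str:
    """Return a hit's label, falling back to rough_label when absent.

    Hits from the pipeline may carry either a final `label` or only an
    earlier `rough_label`; the onset-matching call sites below rely on
    this fallback order.
    """
    return getattr(hit, 'label', getattr(hit, 'rough_label', ''))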


@dataclass
class MatchResult:
    """Result of matching one extracted cluster to a ground-truth sample."""
    cluster_label: str
    gt_name: str
    si_sdr: float
    mfcc_distance: float
    envelope_corr: float
    spectral_convergence: float
    sample_score: float        # production quality score
    n_hits_extracted: int
    n_hits_gt: int
    onset_precision_ms: float  # mean onset error


@dataclass
class EvalReport:
    """Full evaluation report for one extraction run."""
    matches: list                      # [MatchResult]
    unmatched_gt: list                 # GT samples with no extraction match
    unmatched_clusters: list           # Extracted clusters with no GT match
    mean_si_sdr: float
    mean_mfcc_dist: float
    mean_env_corr: float
    mean_sample_score: float
    mean_onset_error_ms: float
    hit_count_accuracy: float          # 1 - relative error between total extracted and GT hit counts
    overall_score: float               # composite [0, 100]
    params_used: dict                  # pipeline params that produced this result


def match_sample_to_gt(extracted: np.ndarray, gt_samples: dict,
                       sr: int) -> tuple[Optional[str], float]:
    """Find the best-matching ground-truth sample for an extracted sample.
    Uses MFCC distance plus envelope correlation for matching; returns
    (None, -inf) when no GT sample overlaps by at least 100 samples."""
    best_name = None
    best_score = -np.inf

    for gt_name, gt_audio in gt_samples.items():
        # Compute matching score
        n = min(len(extracted), len(gt_audio))
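        # Require at least 100 samples of overlap for a meaningful comparison.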
        if n < 100:
            continue

        mfcc_dist = compute_mfcc_distance(gt_audio[:n], extracted[:n], sr)
        env_corr = compute_envelope_correlation(gt_audio[:n], extracted[:n])

        # Combined matching score: higher envelope correlation and lower
        # MFCC distance are both better. Dividing the MFCC distance by 50
        # brings it onto roughly the same scale as the correlation.
        score = env_corr - mfcc_dist / 50.0

        if score > best_score:
            best_score = score
            best_name = gt_name

    return best_name, best_score


def compute_onset_errors(extracted_hits: list, gt_hits: list,
                          sample_name: str, tolerance_ms: float = 50.0) -> list:
    """Compute onset time errors between extracted and GT hits for a given sample type.
    Returns list of (gt_time, nearest_extracted_time, error_ms)."""
    gt_times = [h['onset'] for h in gt_hits if h['sample'] == sample_name]
    ext_times = sorted(h.onset_time for h in extracted_hits
                       if sample_name.lower() in _hit_label(h).lower())

    errors = []
    for gt_t in gt_times:
        if not ext_times:
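            # No extracted onsets for this sample type: charge the full
            # tolerance as the error.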
            errors.append((gt_t, None, tolerance_ms))
            continue
        # Find nearest extracted onset
        diffs = [abs(gt_t - et) for et in ext_times]
        min_idx = np.argmin(diffs)
        error_ms = diffs[min_idx] * 1000
        errors.append((gt_t, ext_times[min_idx], error_ms))

    return errors


def evaluate_extraction(
    extracted_clusters: list,
    gt_samples: dict,                     # {name: np.ndarray}
    gt_hit_map: list,                     # [{sample, onset, velocity}, ...]
    sr: int,
    all_hits: Optional[list] = None,      # all DrumHit objects for onset analysis
    pipeline_params: Optional[dict] = None,
) -> EvalReport:
    """
    Evaluate extracted clusters against ground truth.
    
    Args:
        extracted_clusters: list of DrumCluster from the pipeline
        gt_samples: {name: audio_array} ground-truth one-shots
        gt_hit_map: [{sample, onset, velocity}] from synthetic generator
        sr: sample rate
        all_hits: all DrumHit objects (for onset precision analysis)
        pipeline_params: the params used for this extraction run

    Returns:
        EvalReport with per-match metrics and aggregate scores.
    """
    matches = []
    matched_gt = set()
    matched_clusters = set()

    # Count GT hits per sample type
    gt_hit_counts = defaultdict(int)
    for h in gt_hit_map:
        gt_hit_counts[h['sample']] += 1

    # Count extracted hits per rough label
    ext_hit_counts = defaultdict(int)
    for cluster in extracted_clusters:
        # Extract base label (e.g. "kick" from "kick_0")
        base = cluster.label.rsplit('_', 1)[0]
        ext_hit_counts[base] += cluster.count

    # For each cluster, find the best GT match (greedy: several clusters
    # may map to the same GT sample).
    for cluster in extracted_clusters:
        best_hit = cluster.best_hit
        gt_name, match_score = match_sample_to_gt(best_hit.audio, gt_samples, sr)

        if gt_name is None:
            continue

        matched_gt.add(gt_name)
        matched_clusters.add(cluster.cluster_id)

        # Compute reference metrics
        gt_audio = gt_samples[gt_name]
        ext_audio = best_hit.audio
        n = min(len(gt_audio), len(ext_audio))

        si_sdr = compute_si_sdr(gt_audio[:n], ext_audio[:n])
        mfcc_dist = compute_mfcc_distance(gt_audio[:n], ext_audio[:n], sr)
        env_corr = compute_envelope_correlation(gt_audio[:n], ext_audio[:n])
        ref_metrics = compute_all_reference_metrics(gt_audio[:n], ext_audio[:n], sr)

        # Sample quality score
        base_label = cluster.label.rsplit('_', 1)[0]
        score = drum_sample_score(ext_audio, sr, base_label)

        # Onset precision (if we have hit data)
        onset_errors = []
        if all_hits:
            errors = compute_onset_errors(
                [h for h in all_hits if base_label in _hit_label(h).lower()],
                gt_hit_map, gt_name
            )
            onset_errors = [e[2] for e in errors if e[1] is not None]

        # Fall back to the full 50 ms tolerance when no onsets could be matched.
        mean_onset_err = np.mean(onset_errors) if onset_errors else 50.0

        matches.append(MatchResult(
            cluster_label=cluster.label,
            gt_name=gt_name,
            si_sdr=si_sdr,
            mfcc_distance=mfcc_dist,
            envelope_corr=env_corr,
            spectral_convergence=ref_metrics['Spectral Convergence'],
            sample_score=score['total'],
            n_hits_extracted=ext_hit_counts.get(base_label, 0),
            n_hits_gt=gt_hit_counts.get(gt_name, 0),
            onset_precision_ms=mean_onset_err,
        ))

    # Unmatched
    unmatched_gt = [n for n in gt_samples if n not in matched_gt]
    unmatched_clusters = [c.label for c in extracted_clusters
                          if c.cluster_id not in matched_clusters]

    # Aggregate metrics
    if matches:
        mean_si_sdr = np.mean([m.si_sdr for m in matches])
        mean_mfcc = np.mean([m.mfcc_distance for m in matches])
        mean_env = np.mean([m.envelope_corr for m in matches])
        mean_score = np.mean([m.sample_score for m in matches])
        mean_onset = np.mean([m.onset_precision_ms for m in matches])
    else:
        mean_si_sdr = -np.inf
        mean_mfcc = np.inf
        mean_env = 0.0
        mean_score = 0.0
        mean_onset = 50.0

    # Hit count accuracy: how close are extracted counts to GT counts
    total_gt = sum(gt_hit_counts.values())
    total_ext = sum(ext_hit_counts.values())
    hit_acc = 1.0 - abs(total_gt - total_ext) / (total_gt + 1e-8)
    hit_acc = max(0, hit_acc)
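    # Note: hit_count_accuracy is reported in the EvalReport but is not
    # folded into the composite score below.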

    # Overall composite score
    # Weights: SI-SDR 25%, sample_score 25%, env_corr 20%, onset 15%, coverage 15%
    coverage = len(matched_gt) / (len(gt_samples) + 1e-8)
    si_sdr_norm = np.clip((mean_si_sdr + 5) / 25, 0, 1)  # -5dB→0, 20dB→1
    env_norm = np.clip(mean_env, 0, 1)
    onset_norm = np.clip(1.0 - mean_onset / 50.0, 0, 1)   # 0ms→1, 50ms→0
    score_norm = mean_score / 100.0

    overall = (si_sdr_norm * 0.25 + score_norm * 0.25 + env_norm * 0.20 +
               onset_norm * 0.15 + coverage * 0.15) * 100
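    # e.g. si_sdr_norm=0.8, score_norm=0.7, env_norm=0.9, onset_norm=0.8,
    # coverage=1.0 -> (0.200 + 0.175 + 0.180 + 0.120 + 0.150) * 100 = 82.5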

    return EvalReport(
        matches=matches,
        unmatched_gt=unmatched_gt,
        unmatched_clusters=unmatched_clusters,
        mean_si_sdr=float(mean_si_sdr),
        mean_mfcc_dist=float(mean_mfcc),
        mean_env_corr=float(mean_env),
        mean_sample_score=float(mean_score),
        mean_onset_error_ms=float(mean_onset),
        hit_count_accuracy=float(hit_acc),
        overall_score=float(overall),
        params_used=pipeline_params or {},
    )


def report_to_dict(report: EvalReport) -> dict:
    """Convert eval report to JSON-serializable dict."""
    return {
        'overall_score': report.overall_score,
        'mean_si_sdr': report.mean_si_sdr,
        'mean_mfcc_dist': report.mean_mfcc_dist,
        'mean_env_corr': report.mean_env_corr,
        'mean_sample_score': report.mean_sample_score,
        'mean_onset_error_ms': report.mean_onset_error_ms,
        'hit_count_accuracy': report.hit_count_accuracy,
        'n_matched': len(report.matches),
        'n_unmatched_gt': len(report.unmatched_gt),
        'n_unmatched_clusters': len(report.unmatched_clusters),
        'unmatched_gt': report.unmatched_gt,
        'matches': [
            {
                'cluster': m.cluster_label,
                'gt': m.gt_name,
                'si_sdr': m.si_sdr,
                'mfcc_dist': m.mfcc_distance,
                'env_corr': m.envelope_corr,
                'score': m.sample_score,
                'onset_ms': m.onset_precision_ms,
            }
            for m in report.matches
        ],
        'params': report.params_used,
    }
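

if __name__ == '__main__':
    # Minimal smoke test, a sketch rather than part of the pipeline: one
    # synthetic kick serves as ground truth, and SimpleNamespace stand-ins
    # mimic the DrumCluster / DrumHit attributes used above (label,
    # cluster_id, count, best_hit, audio, onset_time). Real runs feed in
    # pipeline output instead.
    from types import SimpleNamespace

    sr = 22050
    t = np.linspace(0, 0.25, int(sr * 0.25), endpoint=False)
    kick = (np.sin(2 * np.pi * 60 * t) * np.exp(-30 * t)).astype(np.float32)

    hit = SimpleNamespace(audio=kick.copy(), onset_time=0.5, label='kick_0')
    cluster = SimpleNamespace(label='kick_0', cluster_id=0, count=1,
                              best_hit=hit)

    report = evaluate_extraction(
        extracted_clusters=[cluster],
        gt_samples={'kick': kick},
        gt_hit_map=[{'sample': 'kick', 'onset': 0.5, 'velocity': 1.0}],
        sr=sr,
        all_hits=[hit],
    )
    print(json.dumps(report_to_dict(report), indent=2, default=float))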