rikhoffbauer2 commited on
Commit
2a334ed
·
verified ·
1 Parent(s): d34b37f

Add app.py

Browse files
Files changed (1) hide show
  1. app.py +570 -0
app.py ADDED
@@ -0,0 +1,570 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Gradio UI for Drum Sample Extractor.
3
+
4
+ Three tabs:
5
+ 1. Extract — Upload audio, run the pipeline, listen to extracted samples
6
+ 2. Evaluate — Generate synthetic songs, compare extraction to ground truth
7
+ 3. Auto-Optimize — Run autonomous improvement loop with live progress
8
+ """
9
+
10
+ import gradio as gr
11
+ import numpy as np
12
+ import pandas as pd
13
+ import matplotlib
14
+ matplotlib.use('Agg')
15
+ import matplotlib.pyplot as plt
16
+ import json
17
+ import time
18
+ import sys
19
+ import os
20
+ import io
21
+ import tempfile
22
+ import soundfile as sf
23
+ import librosa
24
+ import traceback
25
+
26
+ sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
27
+
28
+ from drum_extractor import (
29
+ extract_drums_demucs, detect_onsets, classify_and_separate_hits,
30
+ compute_librosa_embeddings, cluster_hits, select_best_representatives,
31
+ synthesize_from_cluster, DrumCluster,
32
+ )
33
+ from quality_metrics import drum_sample_score, compute_all_reference_metrics
34
+ from synth_generator import generate_test_song
35
+ from evaluation import evaluate_extraction, report_to_dict
36
+ from optimizer import run_optimization_loop, PipelineParams, OptimizerState
37
+
38
+
39
+ # ─────────────────────────────────────────────────────────────────────────────
40
+ # Helper functions
41
+ # ─────────────────────────────────────────────────────────────────────────────
42
+
43
def audio_to_tuple(audio: np.ndarray, sr: int) -> tuple:
    """Convert audio samples to a Gradio-compatible ``(sr, data)`` tuple.

    The samples are cast to float32 and peak-normalized to 0.95 so that
    playback does not clip.

    Args:
        audio: Audio samples (any array-like accepted by ``np.asarray``).
        sr: Sample rate, passed through unchanged.

    Returns:
        ``(sr, data)`` where ``data`` is a float32 array. Empty and all-zero
        input is returned unscaled.
    """
    audio = np.asarray(audio, dtype=np.float32)

    # ``max()`` of a zero-size array raises ValueError, so bail out early.
    if audio.size == 0:
        return (sr, audio)

    peak = np.abs(audio).max()
    if peak > 0:
        # Normalize to prevent clipping.
        audio = audio / peak * 0.95
    return (sr, audio)
52
+
53
+
54
def make_waveform_plot(audio_dict: dict, sr: int, title: str = "Waveforms") -> plt.Figure:
    """Render one stacked waveform panel per entry of ``audio_dict``.

    Args:
        audio_dict: Mapping of panel label to the audio samples to draw.
        sr: Sample rate used to convert sample indices to seconds.
        title: Figure-level heading.

    Returns:
        The Matplotlib figure. When ``audio_dict`` is empty, a small
        placeholder figure containing only a text message.
    """
    panel_count = len(audio_dict)
    if panel_count == 0:
        placeholder, ax = plt.subplots(figsize=(10, 2))
        ax.text(0.5, 0.5, "No audio to display", ha='center', va='center')
        return placeholder

    height = max(2, panel_count * 1.5)
    figure, grid = plt.subplots(panel_count, 1, figsize=(10, height), squeeze=False)
    figure.suptitle(title, fontsize=12, fontweight='bold')

    bottom_row = panel_count - 1
    for row, (label, samples) in enumerate(audio_dict.items()):
        panel = grid[row, 0]
        seconds = np.arange(len(samples)) / sr
        panel.plot(seconds, samples, linewidth=0.3, color='#2196F3')
        panel.set_ylabel(label, fontsize=8)
        panel.set_xlim(0, len(samples) / sr)
        panel.set_ylim(-1, 1)
        # Only the bottom panel carries the time axis labelling.
        if row == bottom_row:
            panel.set_xlabel("Time (s)")
        else:
            panel.set_xticklabels([])

    plt.tight_layout()
    return figure
79
+
80
+
81
def make_metrics_plot(history: list) -> plt.Figure:
    """Draw a 2x2 dashboard of the optimization history.

    Panels: overall score (best iteration highlighted), mean SI-SDR, mean
    sample score, and the ``energy_threshold_db`` parameter per iteration.

    Args:
        history: Iteration records exposing ``iteration``, ``overall_score``,
            ``eval_report`` and ``params``.

    Returns:
        The Matplotlib figure, or a placeholder figure when ``history`` is
        empty.
    """
    if not history:
        placeholder, ax = plt.subplots(figsize=(10, 4))
        ax.text(0.5, 0.5, "No data yet", ha='center', va='center')
        return placeholder

    def report_value(record, key, fallback):
        # Records whose report is not a dict contribute the fallback value.
        report = record.eval_report
        if isinstance(report, dict):
            return report.get(key, fallback)
        return fallback

    def draw_series(panel, values, fmt, ylabel, heading, **plot_kwargs):
        panel.plot(iterations, values, fmt, linewidth=2, markersize=4, **plot_kwargs)
        panel.set_ylabel(ylabel)
        panel.set_title(heading)

    iterations = [record.iteration for record in history]
    overall = [record.overall_score for record in history]

    figure, grid = plt.subplots(2, 2, figsize=(12, 8))
    figure.suptitle("Optimization Progress", fontsize=14, fontweight='bold')

    # Overall score, with the best iteration marked in red.
    panel = grid[0, 0]
    draw_series(panel, overall, 'b-o', "Overall Score", "Overall Score (/100)")
    panel.grid(True, alpha=0.3)
    top = np.argmax(overall)
    panel.scatter(
        [iterations[top]], [overall[top]],
        color='red', s=100, zorder=5, label=f'Best: {overall[top]:.1f}',
    )
    panel.legend()

    # SI-SDR
    panel = grid[0, 1]
    sdr_series = [report_value(record, 'mean_si_sdr', -50) for record in history]
    draw_series(panel, sdr_series, 'g-o', "SI-SDR (dB)", "Mean SI-SDR")
    panel.grid(True, alpha=0.3)

    # Sample score
    panel = grid[1, 0]
    quality_series = [report_value(record, 'mean_sample_score', 0) for record in history]
    draw_series(panel, quality_series, 'r-o', "Sample Score (/100)", "Mean Sample Quality Score")
    panel.grid(True, alpha=0.3)

    # Parameter evolution
    panel = grid[1, 1]
    threshold_series = [record.params.get('energy_threshold_db', -40) for record in history]
    draw_series(
        panel, threshold_series, 'm-o', "Value", "Parameter Evolution",
        label='energy_thresh (dB)',
    )
    panel.legend(fontsize=8)
    panel.grid(True, alpha=0.3)

    plt.tight_layout()
    return figure
131
+
132
+
133
+ # ─────────────────────────────────────────────────────────────────────────────
134
+ # Tab 1: Extract
135
+ # ─────────────────────────────────────────────────────────────────────────────
136
+
137
def run_extraction(audio_input, progress=gr.Progress()):
    """Run the drum extraction pipeline on uploaded audio.

    Stages: Demucs drum-stem separation, onset detection, hit
    classification, embedding + clustering, quality-based selection of each
    cluster's best hit, and synthesis for clusters with at least two hits.

    Args:
        audio_input: ``(sample_rate, samples)`` tuple from the upload widget,
            or ``None`` when nothing was uploaded.
        progress: Gradio progress tracker, called with a fraction and a
            description at each stage.

    Returns:
        An 11-tuple matching the Extract tab's outputs: the isolated drum
        stem, eight sample slots (padded with ``None``), the metrics
        DataFrame and the waveform figure. Every early exit returns the same
        number of items so Gradio can map them onto the output components.
    """
    max_samples = 8  # number of sample players shown in the UI
    n_outputs = 1 + max_samples + 2  # drum stem + samples + table + plot

    if audio_input is None:
        return (None,) * n_outputs

    progress(0.0, desc="Loading audio...")
    sr_in, data = audio_input
    data = data.astype(np.float32)
    if data.ndim > 1:
        data = data.mean(axis=1)  # down-mix to mono
    if data.size == 0:
        # Nothing to process; also avoids max() on a zero-size array.
        return (None,) * n_outputs
    peak = np.abs(data).max()
    if peak > 0:
        data = data / peak

    # Reserve a temp path for Demucs. The handle is closed before writing so
    # the file can be reopened by name on every platform.
    with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as f:
        tmp_path = f.name

    try:
        # Written inside the try so the temp file is removed even if the
        # write itself fails.
        sf.write(tmp_path, data, sr_in)

        # Stage 1: Demucs
        progress(0.1, desc="Extracting drum stem (Demucs)...")
        drums, drums_sr = extract_drums_demucs(tmp_path, device="cpu")

        # Stage 2: Onsets
        progress(0.4, desc="Detecting onsets...")
        hits = detect_onsets(drums, drums_sr)

        if len(hits) == 0:
            return (audio_to_tuple(drums, drums_sr),) + (None,) * (n_outputs - 1)

        # Stage 3: Classify & separate
        progress(0.5, desc="Classifying hits...")
        hits = classify_and_separate_hits(hits, separate_overlaps=True)

        # Stage 4: Embed & cluster
        progress(0.6, desc="Clustering similar hits...")
        embeddings = compute_librosa_embeddings(hits)
        clusters = cluster_hits(hits, embeddings)

        # Stage 5: Select best (with quality scoring)
        progress(0.7, desc="Selecting best representatives...")
        for cluster in clusters:
            if cluster.count == 1:
                cluster.best_hit_idx = 0
                continue
            # Drop the trailing "_<suffix>" from the cluster label.
            base_label = cluster.label.rsplit('_', 1)[0]
            scores = [
                drum_sample_score(hit.audio, hit.sr, base_label)['total']
                for hit in cluster.hits
            ]
            cluster.best_hit_idx = int(np.argmax(scores))

        # Stage 6: Synthesis
        progress(0.8, desc="Synthesizing optimal samples...")
        for cluster in clusters:
            if cluster.count >= 2:
                cluster.synthesized = synthesize_from_cluster(cluster)

        progress(0.9, desc="Building results...")

        # Build outputs
        drums_out = audio_to_tuple(drums, drums_sr)

        # Keep the largest clusters, one per available sample player.
        sorted_clusters = sorted(
            clusters, key=lambda c: c.count, reverse=True
        )[:max_samples]
        sample_outputs = [
            audio_to_tuple(c.best_hit.audio, c.best_hit.sr)
            for c in sorted_clusters
        ]
        # Pad so the number of returned values always matches the UI.
        sample_outputs.extend([None] * (max_samples - len(sample_outputs)))

        # Metrics table
        rows = []
        for c in sorted_clusters:
            best = c.best_hit
            base_label = c.label.rsplit('_', 1)[0]
            score = drum_sample_score(best.audio, best.sr, base_label)
            rows.append({
                'Cluster': c.label,
                'Hits': c.count,
                'Score': f"{score['total']:.1f}",
                'Completeness': f"{score['completeness']:.2f}",
                'Cleanness': f"{score['cleanness']:.2f}",
                'Onset': f"{score['onset_quality']:.2f}",
                'Duration (ms)': f"{best.duration * 1000:.0f}",
            })
        metrics_df = pd.DataFrame(rows)

        # Waveform plot (first six clusters only)
        waveforms = {c.label: c.best_hit.audio for c in sorted_clusters[:6]}
        fig = make_waveform_plot(waveforms, drums_sr, "Extracted Samples")

        progress(1.0, desc="Done!")
        return (drums_out,) + tuple(sample_outputs) + (metrics_df, fig)

    finally:
        os.unlink(tmp_path)
237
+
238
+
239
+ # ─────────────────────────────────────────────────────────────────────────────
240
+ # Tab 2: Evaluate
241
+ # ─────────────────────────────────────────────────────────────────────────────
242
+
243
def run_evaluation(pattern, bpm, bars, progress=gr.Progress()):
    """Generate a synthetic song, run extraction on it and score the result.

    Args:
        pattern: Name of the drum pattern to synthesize.
        bpm: Tempo; converted to ``float``.
        bars: Number of bars; converted to ``int``.
        progress: Gradio progress tracker.

    Returns:
        A 6-tuple: synthetic mix audio, drums-only audio, summary metrics
        DataFrame, per-match DataFrame, comparison figure and a status
        string. When no hits are detected the first five items are ``None``
        and the status is ``"No hits detected"``.
    """
    progress(0.0, desc="Generating synthetic song...")
    song = generate_test_song(
        pattern_name=pattern,
        bars=int(bars),
        bpm=float(bpm),
        variation='medium',
        seed=42,
    )

    progress(0.2, desc="Running extraction pipeline...")
    detected = detect_onsets(song.drums_only, song.sr)
    if len(detected) == 0:
        return None, None, None, None, None, "No hits detected"

    detected = classify_and_separate_hits(detected, separate_overlaps=True)
    features = compute_librosa_embeddings(detected)
    clusters = cluster_hits(detected, features)

    # Pick each cluster's representative by quality score.
    for group in clusters:
        if group.count == 1:
            group.best_hit_idx = 0
            continue
        kind = group.label.rsplit('_', 1)[0]
        totals = [
            drum_sample_score(candidate.audio, candidate.sr, kind)['total']
            for candidate in group.hits
        ]
        group.best_hit_idx = int(np.argmax(totals))

    # Synthesize only where there are at least two hits to combine.
    for group in clusters:
        if group.count >= 2:
            group.synthesized = synthesize_from_cluster(group)

    progress(0.6, desc="Evaluating against ground truth...")
    gt_samples = {name: sample.audio for name, sample in song.samples.items()}
    gt_hit_map = []
    for truth in song.hits:
        gt_hit_map.append({
            'sample': truth.sample_name,
            'onset': truth.onset_time,
            'velocity': truth.velocity,
        })

    report = evaluate_extraction(
        extracted_clusters=clusters,
        gt_samples=gt_samples,
        gt_hit_map=gt_hit_map,
        sr=song.sr,
        all_hits=detected,
    )

    progress(0.8, desc="Building report...")

    mix_audio = audio_to_tuple(song.mix, song.sr)
    drums_audio = audio_to_tuple(song.drums_only, song.sr)

    # Summary table: (metric, formatted value, target) triples.
    summary_spec = [
        ('Overall Score', f"{report.overall_score:.1f}/100", '> 70'),
        ('SI-SDR', f"{report.mean_si_sdr:.1f} dB", '> 10 dB'),
        ('Sample Score', f"{report.mean_sample_score:.1f}/100", '> 60'),
        ('Envelope Corr', f"{report.mean_env_corr:.3f}", '> 0.9'),
        ('Onset Error', f"{report.mean_onset_error_ms:.1f} ms", '< 10 ms'),
        ('Hit Count Acc', f"{report.hit_count_accuracy:.2f}", '> 0.9'),
        ('Coverage', f"{len(report.matches)}/{len(gt_samples)}", 'All matched'),
    ]
    if report.unmatched_gt:
        summary_spec.append(
            ('⚠ Unmatched GT', ', '.join(report.unmatched_gt), 'None')
        )
    summary_table = pd.DataFrame([
        {'Metric': metric, 'Value': value, 'Target': target}
        for metric, value, target in summary_spec
    ])

    # Per-match detail table.
    detail_rows = [
        {
            'Cluster': match.cluster_label,
            'Matched GT': match.gt_name,
            'SI-SDR (dB)': f"{match.si_sdr:.1f}",
            'MFCC Dist': f"{match.mfcc_distance:.2f}",
            'Env Corr': f"{match.envelope_corr:.3f}",
            'Score': f"{match.sample_score:.1f}",
            'Onset (ms)': f"{match.onset_precision_ms:.1f}",
        }
        for match in report.matches
    ]
    detail_table = pd.DataFrame(detail_rows) if detail_rows else pd.DataFrame()

    # Side-by-side waveforms: ground truth (left) vs. extracted (right).
    n_rows = len(gt_samples)
    figure, grid = plt.subplots(n_rows, 2, figsize=(12, n_rows * 2), squeeze=False)
    figure.suptitle("Ground Truth vs Best Extracted", fontsize=12, fontweight='bold')

    for row, (truth_name, truth_audio) in enumerate(gt_samples.items()):
        left, right = grid[row, 0], grid[row, 1]

        left.plot(
            np.arange(len(truth_audio)) / song.sr, truth_audio,
            color='#4CAF50', linewidth=0.5,
        )
        left.set_ylabel(truth_name, fontsize=8)
        left.set_ylim(-1, 1)
        if row == 0:
            left.set_title("Ground Truth")

        # First match for this ground-truth sample, then the first cluster
        # carrying that match's label.
        first_match = next(
            (m for m in report.matches if m.gt_name == truth_name), None
        )
        if first_match is not None:
            source = next(
                (c for c in clusters if c.label == first_match.cluster_label),
                None,
            )
            if source is not None:
                extracted = source.best_hit.audio
                right.plot(
                    np.arange(len(extracted)) / song.sr, extracted,
                    color='#FF9800', linewidth=0.5,
                )
        right.set_ylim(-1, 1)
        if row == 0:
            right.set_title("Extracted")

    plt.tight_layout()

    progress(1.0, desc="Done!")
    return mix_audio, drums_audio, summary_table, detail_table, figure, ""
370
+
371
+
372
+ # ─────────────────────────────────────────────────────────────────────────────
373
+ # Tab 3: Auto-Optimize
374
+ # ─────────────────────────────────────────────────────────────────────────────
375
+
376
# Module-level optimizer state (outlives individual calls)
_optimizer_state = None
_optimizer_log = []


def run_auto_optimize(n_iterations, progress=gr.Progress()):
    """Run the optimization loop and package its results for the UI.

    Resets the module-level log, runs ``run_optimization_loop`` over the
    'rock', 'funk' and 'halftime' patterns with default ``PipelineParams``
    and seed 42, then stores the returned state in ``_optimizer_state``.

    Args:
        n_iterations: Number of iterations; converted to ``int``.
        progress: Gradio progress tracker (updated at start and end only).

    Returns:
        A 4-tuple: newline-joined log text, iteration history DataFrame,
        optimization progress figure and the best parameters as indented
        JSON text.
    """
    global _optimizer_state, _optimizer_log
    _optimizer_log = []

    def record_message(message):
        _optimizer_log.append(message)

    progress(0.0, desc="Starting optimization...")

    result = run_optimization_loop(
        n_iterations=int(n_iterations),
        patterns=['rock', 'funk', 'halftime'],
        initial_params=PipelineParams(),
        seed=42,
        log_callback=record_message,
    )

    _optimizer_state = result
    progress(1.0, desc="Done!")

    def report_cell(record, key):
        # Records whose report is not a dict are shown as 'err'.
        report = record.eval_report
        if isinstance(report, dict):
            return f"{report.get(key, 0):.1f}"
        return 'err'

    log_text = '\n'.join(_optimizer_log)

    history_table = pd.DataFrame([
        {
            'Iter': record.iteration,
            'Pattern': record.test_config.get('pattern', '?'),
            'BPM': record.test_config.get('bpm', '?'),
            'Score': f"{record.overall_score:.1f}",
            'SI-SDR': report_cell(record, 'mean_si_sdr'),
            'Sample': report_cell(record, 'mean_sample_score'),
            'Time (s)': f"{record.duration_seconds:.1f}",
        }
        for record in result.history
    ])

    figure = make_metrics_plot(result.history)
    params_json = json.dumps(result.best_params, indent=2)

    return log_text, history_table, figure, params_json
426
+
427
+
428
+ # ─────────────────────────────────────────────────────────────────────────────
429
+ # App layout
430
+ # ─────────────────────────────────────────────────────────────────────────────
431
+
432
def build_app():
    """Assemble the three-tab Gradio interface and wire up its callbacks.

    Tabs: Extract (``run_extraction``), Evaluate (``run_evaluation``) and
    Auto-Optimize (``run_auto_optimize``).

    Returns:
        The ``gr.Blocks`` application, not yet launched.
    """
    with gr.Blocks(
        title="🥁 Drum Sample Extractor",
        theme=gr.themes.Soft(),
        css="""
        .gradio-container { max-width: 1200px !important; }
        .sample-audio { min-height: 60px; }
        """
    ) as app:
        gr.Markdown("""
        # 🥁 Drum Sample Extractor

        Extract individual drum samples from audio files using **HTDemucs** stem separation,
        **multi-band onset detection**, **spectral overlap decomposition**, and
        **quality-aware clustering**.

        Includes a synthetic evaluation framework with autonomous parameter optimization.
        """)

        with gr.Tabs():
            # ── Tab 1: Extract ──
            with gr.Tab("🎵 Extract", id=0):
                gr.Markdown("Upload an audio file to extract drum samples.")

                audio_in = gr.Audio(
                    sources=['upload'],
                    type='numpy',
                    label='Upload Audio (MP3, WAV, FLAC)',
                )
                extract_btn = gr.Button("🔬 Extract Drum Samples", variant="primary", size="lg")

                with gr.Row():
                    drums_out = gr.Audio(type='numpy', label='🥁 Isolated Drum Stem', interactive=False)

                gr.Markdown("### Extracted Samples")
                gr.Markdown("*Best representative from each cluster, ranked by hit count:*")

                # Eight sample players laid out as two rows of four.
                sample_players = []
                for row_start in (0, 4):
                    with gr.Row():
                        for slot in range(row_start, row_start + 4):
                            sample_players.append(
                                gr.Audio(
                                    type='numpy',
                                    label=f'Sample {slot + 1}',
                                    interactive=False,
                                )
                            )

                gr.Markdown("### Quality Metrics")
                metrics_table = gr.Dataframe(label="Cluster Quality Scores")
                waveform_plot = gr.Plot(label="Waveforms")

                # 11 outputs: drum stem + 8 samples + table + plot.
                extract_btn.click(
                    fn=run_extraction,
                    inputs=[audio_in],
                    outputs=[drums_out, *sample_players,
                             metrics_table, waveform_plot],
                )

            # ── Tab 2: Evaluate ──
            with gr.Tab("📊 Evaluate", id=1):
                gr.Markdown("""
                ### Synthetic Evaluation
                Generate a synthetic drum song with known ground-truth samples, run the extraction
                pipeline, and compare results. This tells us exactly how well the system works.
                """)

                with gr.Row():
                    pattern_dd = gr.Dropdown(
                        choices=['rock', 'funk', 'halftime'],
                        value='rock',
                        label='Drum Pattern'
                    )
                    bpm_slider = gr.Slider(80, 200, value=120, step=2, label='BPM')
                    bars_slider = gr.Slider(2, 8, value=4, step=1, label='Bars')

                eval_btn = gr.Button("🧪 Generate & Evaluate", variant="primary", size="lg")

                with gr.Row():
                    eval_mix = gr.Audio(type='numpy', label='Synthetic Mix', interactive=False)
                    eval_drums = gr.Audio(type='numpy', label='Drums Only', interactive=False)

                gr.Markdown("### Evaluation Results")
                eval_summary = gr.Dataframe(label="Summary Metrics")
                eval_matches = gr.Dataframe(label="Cluster → Ground Truth Matches")
                eval_plot = gr.Plot(label="GT vs Extracted Comparison")
                eval_status = gr.Textbox(label="Status", visible=False)

                eval_btn.click(
                    fn=run_evaluation,
                    inputs=[pattern_dd, bpm_slider, bars_slider],
                    outputs=[eval_mix, eval_drums, eval_summary, eval_matches,
                             eval_plot, eval_status],
                )

            # ── Tab 3: Auto-Optimize ──
            with gr.Tab("🔄 Auto-Optimize", id=2):
                gr.Markdown("""
                ### Autonomous Parameter Optimization

                Runs a loop: **generate** synthetic song → **extract** → **evaluate** against ground truth →
                **diagnose** issues → **tune** parameters → repeat.

                The optimizer reads evaluation metrics and makes targeted adjustments:
                - High onset error → tighten `pre_pad` and `min_gap`
                - Missing hits → lower `energy_threshold`
                - Poor SI-SDR → adjust overlap separation
                - Low sample score → rebalance selection weights
                """)

                with gr.Row():
                    n_iters = gr.Slider(2, 30, value=5, step=1,
                                        label='Number of Iterations')
                    opt_btn = gr.Button("🚀 Run Optimization", variant="primary", size="lg")

                opt_log = gr.Textbox(label="Optimization Log", lines=20,
                                     max_lines=40)

                gr.Markdown("### Results")
                opt_table = gr.Dataframe(label="Iteration History")
                opt_plot = gr.Plot(label="Optimization Progress")
                opt_params = gr.Code(label="Best Parameters (JSON)", language="json")

                opt_btn.click(
                    fn=run_auto_optimize,
                    inputs=[n_iters],
                    outputs=[opt_log, opt_table, opt_plot, opt_params],
                )

    return app
562
+
563
+
564
+ # ─────────────────────────────────────────────────────────────────────────────
565
+ # Entry point
566
+ # ─────────────────────────────────────────────────────────────────────────────
567
+
568
if __name__ == "__main__":
    # Script entry point: build the UI and serve it on all interfaces.
    app = build_app()
    app.launch(server_port=7860, server_name="0.0.0.0")