rikhoffbauer2 committed on
Commit
cadb06d
·
verified ·
1 Parent(s): 2d1fee2

v8: Lock checkboxes for auto-tune — constrain any parameter during search

Browse files
Files changed (1) hide show
  1. app.py +174 -204
app.py CHANGED
@@ -1,6 +1,6 @@
1
  """
2
- Gradio UI — Sample Extractor v7.
3
- Auto-tuning via reconstruction quality, NCC clustering, full param control.
4
  """
5
 
6
  import gradio as gr
@@ -23,21 +23,32 @@ from evaluation import evaluate_extraction
23
  from config_store import PipelineConfig, get_leaderboard
24
  from optimizer_v2 import run_optimization
25
 
26
-
27
  def audio_tuple(a, sr):
28
- a = a.astype(np.float32)
29
- pk = np.abs(a).max()
30
  if pk > 0: a = a / pk * 0.95
31
  return (sr, a)
32
 
33
 
34
- # ─── Auto-tune handler ───────────────────────────────────────────────────────
35
 
36
  def run_auto_tune(audio_in, stem_choice, demucs_model, demucs_shifts, demucs_overlap,
37
- onset_mode, progress=gr.Progress()):
38
- """Run auto-tuner on uploaded audio. Returns optimized param values."""
 
 
 
 
39
  if audio_in is None:
40
- return [gr.update()] * 7 + ["Upload audio first"]
 
 
 
 
 
 
 
 
 
41
 
42
  progress(0.0, desc="Loading audio...")
43
  sr_in, data = audio_in
@@ -54,25 +65,26 @@ def run_auto_tune(audio_in, stem_choice, demucs_model, demucs_shifts, demucs_ove
54
  stem_audio, stem_sr = extract_stem(tmp, stem=stem_choice, device="cpu",
55
  model_name=demucs_model, shifts=int(demucs_shifts), overlap=float(demucs_overlap))
56
 
57
- progress(0.15, desc="Auto-tuning parameters (this takes a moment)...")
58
- log_lines = []
59
  best_params, best_score, log_lines = auto_tune(
60
- stem_audio, stem_sr, mode=onset_mode,
61
- log_fn=lambda m: log_lines.append(m))
62
 
63
- progress(1.0, desc=f"Done! Score: {best_score:.1f}")
64
 
65
- # Return updated slider/number values + log
66
- log_text = '\n'.join(log_lines[-30:]) # last 30 lines
67
- summary = (f"**Auto-tune complete!** Best reconstruction score: **{best_score:.1f}/100**\n\n"
68
- f"Parameters have been updated. Click **Extract Samples** to run with these settings.")
 
69
 
 
70
  return [
71
- gr.update(value=best_params.get('onset_delta', 0.12)),
72
- gr.update(value=best_params.get('energy_threshold_db', -35)),
73
- gr.update(value=best_params.get('min_gap', 0.03)),
74
- gr.update(value=best_params.get('target_min', 5)),
75
- gr.update(value=best_params.get('target_max', 20)),
76
  summary,
77
  log_text,
78
  ]
@@ -80,163 +92,115 @@ def run_auto_tune(audio_in, stem_choice, demucs_model, demucs_shifts, demucs_ove
80
  os.unlink(tmp)
81
 
82
 
83
- # ─── Tab 1: Extract ──────────────────────────────────────────────────────────
84
 
85
  def run_extraction(audio_in, stem_choice, demucs_model, demucs_shifts, demucs_overlap,
86
  onset_mode, onset_delta, energy_db, pre_pad, min_dur, max_dur, min_gap,
87
  ncc_threshold, ncc_compare_ms, linkage, target_min, target_max,
88
  do_synthesize, progress=gr.Progress()):
89
- if audio_in is None:
90
- return [None] * 8
91
-
92
- progress(0.0, desc="Loading audio...")
93
- sr_in, data = audio_in
94
- data = data.astype(np.float32)
95
  if data.ndim > 1: data = data.mean(axis=1)
96
  pk = np.abs(data).max()
97
  if pk > 0: data = data / pk
98
-
99
  with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as f:
100
  sf.write(f.name, data, sr_in); tmp = f.name
101
-
102
  try:
103
- progress(0.05, desc=f"Extracting {stem_choice} stem ({demucs_model})...")
104
- stem_audio, stem_sr = extract_stem(tmp, stem=stem_choice, device="cpu",
105
  model_name=demucs_model, shifts=int(demucs_shifts), overlap=float(demucs_overlap))
106
-
107
- progress(0.15, desc="Detecting BPM...")
108
- bpm = detect_bpm(stem_audio, stem_sr)
109
-
110
- progress(0.25, desc="Detecting onsets...")
111
- hits = detect_onsets(stem_audio, stem_sr, mode=onset_mode,
112
- onset_delta=float(onset_delta), energy_threshold_db=float(energy_db),
113
- pre_pad=float(pre_pad), min_dur=float(min_dur),
114
- max_dur=float(max_dur), min_gap=float(min_gap))
115
  if not hits:
116
- return (audio_tuple(stem_audio, stem_sr),
117
- f"**BPM: {bpm}** — No hits found. Lower energy threshold or delta.",
118
- None, None, None, None, "", pd.DataFrame())
119
-
120
- progress(0.35, desc="Classifying...")
121
- hits = classify_hits(hits)
122
-
123
- progress(0.45, desc="Clustering...")
124
- clusters = cluster_hits(hits, ncc_threshold=float(ncc_threshold),
125
- max_compare_ms=float(ncc_compare_ms),
126
- target_min=int(target_min), target_max=int(target_max),
127
- linkage=str(linkage))
128
-
129
- progress(0.65, desc="Selecting best...")
130
- select_best(clusters)
131
-
132
  if do_synthesize:
133
- progress(0.7, desc="Synthesizing...")
134
- for c in clusters:
135
- if c.count >= 2: c.synthesized = synthesize_from_cluster(c)
136
-
137
- progress(0.75, desc="MIDI...")
138
- midi_path = tempfile.mktemp(suffix='.mid')
139
- export_midi(clusters, midi_path, bpm=bpm)
140
-
141
- progress(0.8, desc="Rendering...")
142
- rendered = render_midi_with_samples(clusters, sr=stem_sr)
143
-
144
- progress(0.85, desc="Packaging...")
145
- sd = tempfile.mkdtemp(); sp = []
146
- for c in sorted(clusters, key=lambda x: x.count, reverse=True):
147
- p = os.path.join(sd, f"{c.label}.wav"); c.best_hit.save(p); sp.append(p)
148
- zp = build_archive(clusters, bpm, stem_sr, midi_path=midi_path, rendered_audio=rendered)
149
-
150
- rows = []
151
- for c in sorted(clusters, key=lambda x: x.count, reverse=True):
152
- best = c.best_hit
153
- sc = sample_quality_score(best.audio, best.sr, c.label.rsplit('_',1)[0])
154
- rows.append({'Sample': c.label, 'Hits': c.count, 'MIDI': c.midi_note,
155
- 'Score': f"{sc['total']:.1f}", 'Clean': f"{sc['cleanness']:.2f}",
156
- 'Complete': f"{sc['completeness']:.2f}",
157
- 'Dur (ms)': f"{best.duration*1000:.0f}",
158
- 'First @': f"{sorted(h.onset_time for h in c.hits)[0]:.2f}s"})
159
-
160
- sm = f"**BPM: {bpm}** · **{len(clusters)} samples** from {len(hits)} hits\n\n"
161
- sm += f"Model: `{demucs_model}` · Delta: `{onset_delta}` · Energy: `{energy_db}dB`\n\n"
162
- if int(target_min)>0 and int(target_max)>0:
163
- sm += f"Target: `{int(target_min)}–{int(target_max)}` clusters\n\n"
164
- sm += "| Sample | Hits | MIDI |\n|---|---|---|\n"
165
- for c in sorted(clusters, key=lambda x: x.count, reverse=True):
166
- sm += f"| {c.label} | {c.count} | {c.midi_note} |\n"
167
-
168
- progress(1.0, desc="Done!")
169
- return (audio_tuple(stem_audio, stem_sr), sm, audio_tuple(rendered, stem_sr),
170
- sp, midi_path, zp, "", pd.DataFrame(rows))
171
- finally:
172
- os.unlink(tmp)
173
-
174
-
175
- # ─── Tab 2: Evaluate ─────────────────────────────────────────────────────────
176
 
177
  def run_eval(pattern, bpm, bars, ncc_threshold, target_min, target_max, progress=gr.Progress()):
178
- progress(0.0, desc="Generating...")
179
- song = generate_test_song(pattern_name=pattern, bars=int(bars),
180
- bpm=float(bpm), variation='medium', seed=42)
181
- dbpm = detect_bpm(song.drums_only, song.sr)
182
- progress(0.2, desc="Extracting...")
183
- hits = detect_onsets(song.drums_only, song.sr)
184
- if not hits: return None, None, None, None, "", ""
185
- hits = classify_hits(hits)
186
- cl = cluster_hits(hits, ncc_threshold=float(ncc_threshold),
187
- target_min=int(target_min), target_max=int(target_max))
188
  select_best(cl)
189
  for c in cl:
190
  if c.count>=2: c.synthesized=synthesize_from_cluster(c)
191
- progress(0.5, desc="Rendering...")
192
- rendered = render_midi_with_samples(cl, sr=song.sr)
193
- progress(0.6, desc="Evaluating...")
194
- gt = {n: s.audio for n, s in song.samples.items()}
195
- gh = [{'sample': h.sample_name, 'onset': h.onset_time, 'velocity': h.velocity} for h in song.hits]
196
- r = evaluate_extraction(cl, gt, gh, song.sr, hits)
197
- s = [{'Metric':'BPM','Value':f"{dbpm}",'Target':f"{song.bpm}"},
198
- {'Metric':'Clusters','Value':str(len(cl)),'Target':str(len(gt))},
199
- {'Metric':'Score','Value':f"{r.overall_score:.1f}/100",'Target':'> 70'},
200
- {'Metric':'SI-SDR','Value':f"{r.mean_si_sdr:.1f} dB",'Target':'> 10'},
201
- {'Metric':'Env Corr','Value':f"{r.mean_env_corr:.3f}",'Target':'> 0.9'}]
202
- if r.unmatched_gt: s.append({'Metric':'⚠ Missed','Value':', '.join(r.unmatched_gt),'Target':'None'})
203
- m = [{'Cluster':m.cluster_label,'GT':m.gt_name,'SI-SDR':f"{m.si_sdr:.1f}",
204
- 'Score':f"{m.sample_score:.1f}"} for m in r.matches]
205
  progress(1.0)
206
- return (audio_tuple(song.mix,song.sr), audio_tuple(rendered,song.sr),
207
- pd.DataFrame(s), pd.DataFrame(m) if m else None, "", "")
208
-
209
 
210
- # ─── Tab 3+4: Optimize + Leaderboard ─────────────────────────────────────────
211
-
212
- def run_optimize(n_iters, config_name, author, save_hub, progress=gr.Progress()):
213
- logs = []; progress(0.0)
214
- state = run_optimization(n_iterations=int(n_iters), config_name=config_name or "opt",
215
- author=author or "anon", save_to_hub=bool(save_hub), log_fn=lambda m: logs.append(m))
216
  progress(1.0)
217
- h = [{'Iter':r.iteration,'Score':f"{r.avg_score:.1f}",'Time':f"{r.duration_s:.1f}s"} for r in state.history]
218
  if state.history:
219
- fig,ax=plt.subplots(figsize=(10,4))
220
- ax.plot([r.iteration for r in state.history],[r.avg_score for r in state.history],'b-o')
221
- ax.set_xlabel('Iter'); ax.set_ylabel('Score'); ax.grid(True,alpha=0.3); plt.tight_layout()
222
  else: fig,ax=plt.subplots(); ax.text(0.5,0.5,"No data")
223
- return '\n'.join(logs), pd.DataFrame(h), fig, json.dumps(state.best_config, indent=2)
224
 
225
  def refresh_lb():
226
  try:
227
- lb = get_leaderboard()
228
- return pd.DataFrame(lb) if lb else pd.DataFrame(), ""
229
- except Exception as e: return pd.DataFrame(), str(e)
230
 
231
 
232
- # ─── Build App ────────────────────────────────────────────────────────────────
233
 
234
  def build_app():
235
- with gr.Blocks(title="🎵 Sample Extractor", theme=gr.themes.Soft(),
236
- css=".gradio-container{max-width:1300px!important}") as app:
237
- gr.Markdown("# 🎵 Sample Extractor v7\n"
238
- "Extract distinct sounds from audio. **Auto-Tune** finds the best parameters "
239
- "for your specific audio by measuring reconstruction quality.")
240
 
241
  with gr.Tabs():
242
  with gr.Tab("🎵 Extract"):
@@ -244,87 +208,94 @@ def build_app():
244
 
245
  with gr.Accordion("🔧 Stem Separation", open=False):
246
  with gr.Row():
247
- dm = gr.Dropdown(DEMUCS_MODELS, value="htdemucs_ft", label="Model")
248
- st = gr.Dropdown(['drums','bass','other','vocals','all'], value='drums', label='Stem')
249
- dsh = gr.Slider(0,5,value=1,step=1,label='Shifts')
250
- dov = gr.Slider(0.0,0.5,value=0.25,step=0.05,label='Overlap')
251
 
252
  with gr.Accordion("🎯 Onset Detection", open=False):
253
  with gr.Row():
254
- om = gr.Dropdown(['auto','percussive','harmonic','broadband'],value='auto',label='Mode')
255
- od = gr.Slider(0.01,0.5,value=0.12,step=0.01,label='Delta')
256
- ed = gr.Slider(-70,-10,value=-35,step=1,label='Energy (dB)')
 
 
 
 
 
 
 
257
  with gr.Row():
258
- pp = gr.Slider(0.0,0.05,value=0.005,step=0.001,label='Pre-pad (s)')
259
- mnd = gr.Slider(0.005,0.2,value=0.02,step=0.005,label='Min dur (s)')
260
- mxd = gr.Slider(0.1,5.0,value=1.5,step=0.1,label='Max dur (s)')
261
- mg = gr.Slider(0.005,0.2,value=0.03,step=0.005,label='Min gap (s)')
262
 
263
  with gr.Accordion("🔗 Clustering", open=True):
264
- gr.Markdown("**Target cluster range** — or set both to 0 for manual threshold:")
265
- with gr.Row():
266
- tmin = gr.Number(value=5, label='Target min clusters', precision=0)
267
- tmax = gr.Number(value=20, label='Target max clusters', precision=0)
 
268
  with gr.Row():
269
- nt = gr.Slider(0.3,0.99,value=0.80,step=0.01,label='NCC threshold')
270
- nms = gr.Slider(0,1000,value=0,step=50,label='Compare ms (0=auto)')
271
- lnk = gr.Dropdown(['average','complete','single'],value='average',label='Linkage')
272
 
273
  with gr.Accordion("⚙️ Post-processing", open=False):
274
- syn = gr.Checkbox(value=True, label='Synthesize optimal samples')
275
 
276
  with gr.Row():
277
- tune_btn = gr.Button("🎛️ Auto-Tune Parameters", variant="secondary", size="lg")
278
- extract_btn = gr.Button("🔬 Extract Samples", variant="primary", size="lg")
279
 
280
- tune_summary = gr.Markdown("")
281
- tune_log = gr.Textbox(label="Auto-tune log", lines=8, max_lines=15, visible=False)
282
 
283
- summary_md = gr.Markdown("*Upload audio → click Auto-Tune or Extract*")
284
  with gr.Row():
285
- stem_out = gr.Audio(type='numpy', label='Stem', interactive=False)
286
- rend_out = gr.Audio(type='numpy', label='🔊 Reconstruction', interactive=False)
287
  gr.Markdown("### Downloads")
288
  with gr.Row():
289
- arc = gr.File(label="📦 ZIP", interactive=False)
290
- mid = gr.File(label="🎹 MIDI", interactive=False)
291
- smp = gr.File(label="WAV samples", file_count="multiple", interactive=False)
292
- met = gr.Dataframe(label="Samples")
293
- stx = gr.Textbox(visible=False)
294
 
295
- dm.change(fn=lambda m: gr.update(choices=DEMUCS_STEMS.get(m,["drums","bass","other","vocals"])+["all"]),
296
- inputs=[dm], outputs=[st])
297
 
298
- # Auto-tune updates the onset/clustering sliders
299
  tune_btn.click(run_auto_tune,
300
- [audio_in, st, dm, dsh, dov, om],
 
 
301
  [od, ed, mg, tmin, tmax, tune_summary, tune_log])
302
 
303
  extract_btn.click(run_extraction,
304
- [audio_in, st, dm, dsh, dov, om, od, ed, pp, mnd, mxd, mg,
305
- nt, nms, lnk, tmin, tmax, syn],
306
- [stem_out, summary_md, rend_out, smp, mid, arc, stx, met])
307
 
308
  with gr.Tab("📊 Evaluate"):
309
- gr.Markdown("Synthetic ground truth evaluation.")
310
  with gr.Row():
311
- ep = gr.Dropdown(['rock','funk','halftime'],value='rock',label='Pattern')
312
- eb = gr.Slider(80,200,value=120,step=2,label='BPM')
313
- ebs = gr.Slider(2,8,value=4,step=1,label='Bars')
314
  with gr.Row():
315
- en = gr.Slider(0.3,0.99,value=0.80,step=0.01,label='NCC')
316
- etm = gr.Number(value=0,label='Target min',precision=0)
317
- etx = gr.Number(value=0,label='Target max',precision=0)
318
- evb = gr.Button("🧪 Evaluate", variant="primary", size="lg")
319
  with gr.Row():
320
- evm = gr.Audio(type='numpy',label='Original',interactive=False)
321
- evr = gr.Audio(type='numpy',label='Reconstruction',interactive=False)
322
- evs = gr.Dataframe(label="Summary"); evm2 = gr.Dataframe(label="Matches")
323
- es1 = gr.Textbox(visible=False); es2 = gr.Textbox(visible=False)
324
  evb.click(run_eval,[ep,eb,ebs,en,etm,etx],[evm,evr,evs,evm2,es1,es2])
325
 
326
  with gr.Tab("🔄 Optimize"):
327
- gr.Markdown("### Synthetic test optimization\nTests across 6 songs.")
328
  with gr.Row():
329
  on=gr.Slider(2,30,value=5,step=1,label='Iters')
330
  ocn=gr.Textbox(value="opt",label='Name')
@@ -332,14 +303,13 @@ def build_app():
332
  osv=gr.Checkbox(value=True,label='Save')
333
  ob=gr.Button("🚀 Run",variant="primary",size="lg")
334
  ol=gr.Textbox(label="Log",lines=20,max_lines=40)
335
- oh=gr.Dataframe(label="History"); op=gr.Plot(label="Progress")
336
  oc=gr.Code(label="Config",language="json")
337
  ob.click(run_optimize,[on,ocn,oa,osv],[ol,oh,op,oc])
338
 
339
  with gr.Tab("🏆 Leaderboard"):
340
- gr.Markdown("### Configs ranked by score")
341
- lb=gr.Button("🔄 Refresh"); lt=gr.Dataframe(); ls=gr.Textbox(visible=False)
342
- lb.click(refresh_lb,[],[lt,ls])
343
 
344
  return app
345
 
 
1
  """
2
+ Gradio UI — Sample Extractor v8.
3
+ Auto-tune with parameter locking.
4
  """
5
 
6
  import gradio as gr
 
23
  from config_store import PipelineConfig, get_leaderboard
24
  from optimizer_v2 import run_optimization
25
 
 
26
  def audio_tuple(a, sr):
27
+ a = a.astype(np.float32); pk = np.abs(a).max()
 
28
  if pk > 0: a = a / pk * 0.95
29
  return (sr, a)
30
 
31
 
32
+ # ─── Auto-tune with locks ────────────────────────────────────────────────────
33
 
34
  def run_auto_tune(audio_in, stem_choice, demucs_model, demucs_shifts, demucs_overlap,
35
+ onset_mode,
36
+ # Current values (used when locked)
37
+ cur_delta, cur_energy, cur_gap, cur_tmin, cur_tmax,
38
+ # Lock flags
39
+ lock_delta, lock_energy, lock_gap, lock_targets,
40
+ progress=gr.Progress()):
41
  if audio_in is None:
42
+ return [gr.update()] * 5 + ["Upload audio first", ""]
43
+
44
+ # Build locks dict from checkboxes
45
+ locks = {}
46
+ if lock_delta: locks['onset_delta'] = float(cur_delta)
47
+ if lock_energy: locks['energy_threshold_db'] = float(cur_energy)
48
+ if lock_gap: locks['min_gap'] = float(cur_gap)
49
+ if lock_targets:
50
+ locks['target_min'] = int(cur_tmin)
51
+ locks['target_max'] = int(cur_tmax)
52
 
53
  progress(0.0, desc="Loading audio...")
54
  sr_in, data = audio_in
 
65
  stem_audio, stem_sr = extract_stem(tmp, stem=stem_choice, device="cpu",
66
  model_name=demucs_model, shifts=int(demucs_shifts), overlap=float(demucs_overlap))
67
 
68
+ lock_desc = ', '.join(f'{k}={v}' for k, v in locks.items()) if locks else 'none'
69
+ progress(0.15, desc=f"Auto-tuning (locked: {lock_desc})...")
70
  best_params, best_score, log_lines = auto_tune(
71
+ stem_audio, stem_sr, mode=onset_mode, locks=locks)
 
72
 
73
+ progress(1.0, desc=f"Score: {best_score:.1f}")
74
 
75
+ log_text = '\n'.join(log_lines[-30:])
76
+ lock_info = f"🔒 Locked: {lock_desc}" if locks else "No locks — all params tuned freely"
77
+ summary = (f"**Auto-tune complete!** Score: **{best_score:.1f}/100**\n\n"
78
+ f"{lock_info}\n\n"
79
+ f"Click **Extract Samples** to run with these settings.")
80
 
81
+ # Return updated values — only update unlocked params
82
  return [
83
+ gr.update(value=best_params['onset_delta']) if not lock_delta else gr.update(),
84
+ gr.update(value=best_params['energy_threshold_db']) if not lock_energy else gr.update(),
85
+ gr.update(value=best_params['min_gap']) if not lock_gap else gr.update(),
86
+ gr.update(value=best_params.get('target_min', 5)) if not lock_targets else gr.update(),
87
+ gr.update(value=best_params.get('target_max', 20)) if not lock_targets else gr.update(),
88
  summary,
89
  log_text,
90
  ]
 
92
  os.unlink(tmp)
93
 
94
 
95
+ # ─── Extract ──────────────────────────────────────────────────────────────────
96
 
97
  def run_extraction(audio_in, stem_choice, demucs_model, demucs_shifts, demucs_overlap,
98
  onset_mode, onset_delta, energy_db, pre_pad, min_dur, max_dur, min_gap,
99
  ncc_threshold, ncc_compare_ms, linkage, target_min, target_max,
100
  do_synthesize, progress=gr.Progress()):
101
+ if audio_in is None: return [None]*8
102
+ progress(0.0, desc="Loading...")
103
+ sr_in, data = audio_in; data = data.astype(np.float32)
 
 
 
104
  if data.ndim > 1: data = data.mean(axis=1)
105
  pk = np.abs(data).max()
106
  if pk > 0: data = data / pk
 
107
  with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as f:
108
  sf.write(f.name, data, sr_in); tmp = f.name
 
109
  try:
110
+ progress(0.05, desc=f"Stem ({demucs_model})...")
111
+ sa, ssr = extract_stem(tmp, stem=stem_choice, device="cpu",
112
  model_name=demucs_model, shifts=int(demucs_shifts), overlap=float(demucs_overlap))
113
+ progress(0.15, desc="BPM..."); bpm = detect_bpm(sa, ssr)
114
+ progress(0.25, desc="Onsets...")
115
+ hits = detect_onsets(sa, ssr, mode=onset_mode, onset_delta=float(onset_delta),
116
+ energy_threshold_db=float(energy_db), pre_pad=float(pre_pad),
117
+ min_dur=float(min_dur), max_dur=float(max_dur), min_gap=float(min_gap))
 
 
 
 
118
  if not hits:
119
+ return (audio_tuple(sa,ssr), f"**BPM: {bpm}** — No hits.", None,None,None,None,"",pd.DataFrame())
120
+ progress(0.35, desc="Classify..."); hits = classify_hits(hits)
121
+ progress(0.45, desc="Cluster...")
122
+ cl = cluster_hits(hits, ncc_threshold=float(ncc_threshold), max_compare_ms=float(ncc_compare_ms),
123
+ target_min=int(target_min), target_max=int(target_max), linkage=str(linkage))
124
+ progress(0.65, desc="Select..."); select_best(cl)
 
 
 
 
 
 
 
 
 
 
125
  if do_synthesize:
126
+ progress(0.7, desc="Synth...")
127
+ for c in cl:
128
+ if c.count>=2: c.synthesized=synthesize_from_cluster(c)
129
+ progress(0.75, desc="MIDI..."); mp=tempfile.mktemp(suffix='.mid'); export_midi(cl,mp,bpm=bpm)
130
+ progress(0.8, desc="Render..."); rend=render_midi_with_samples(cl,sr=ssr)
131
+ progress(0.85, desc="Package...")
132
+ sd=tempfile.mkdtemp(); sp=[]
133
+ for c in sorted(cl,key=lambda x:x.count,reverse=True):
134
+ p=os.path.join(sd,f"{c.label}.wav"); c.best_hit.save(p); sp.append(p)
135
+ zp=build_archive(cl,bpm,ssr,midi_path=mp,rendered_audio=rend)
136
+ rows=[]
137
+ for c in sorted(cl,key=lambda x:x.count,reverse=True):
138
+ b=c.best_hit; sc=sample_quality_score(b.audio,b.sr,c.label.rsplit('_',1)[0])
139
+ rows.append({'Sample':c.label,'Hits':c.count,'MIDI':c.midi_note,
140
+ 'Score':f"{sc['total']:.1f}",'Clean':f"{sc['cleanness']:.2f}",
141
+ 'Complete':f"{sc['completeness']:.2f}",
142
+ 'Dur':f"{b.duration*1000:.0f}ms",
143
+ 'First':f"{sorted(h.onset_time for h in c.hits)[0]:.2f}s"})
144
+ sm=f"**BPM: {bpm}** · **{len(cl)} samples** from {len(hits)} hits\n\n"
145
+ sm+=f"`{demucs_model}` · δ=`{onset_delta}` · E=`{energy_db}dB`"
146
+ if int(target_min)>0 and int(target_max)>0: sm+=f" · clusters `{int(target_min)}–{int(target_max)}`"
147
+ sm+="\n\n| Sample | Hits | MIDI |\n|---|---|---|\n"
148
+ for c in sorted(cl,key=lambda x:x.count,reverse=True): sm+=f"| {c.label} | {c.count} | {c.midi_note} |\n"
149
+ progress(1.0)
150
+ return (audio_tuple(sa,ssr),sm,audio_tuple(rend,ssr),sp,mp,zp,"",pd.DataFrame(rows))
151
+ finally: os.unlink(tmp)
152
+
153
+
154
+ # ─── Evaluate ─────────────────────────────────────────────────────────────────
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
 
156
  def run_eval(pattern, bpm, bars, ncc_threshold, target_min, target_max, progress=gr.Progress()):
157
+ progress(0.0); song=generate_test_song(pattern_name=pattern,bars=int(bars),bpm=float(bpm),variation='medium',seed=42)
158
+ dbpm=detect_bpm(song.drums_only,song.sr); progress(0.2)
159
+ hits=detect_onsets(song.drums_only,song.sr)
160
+ if not hits: return None,None,None,None,"",""
161
+ hits=classify_hits(hits)
162
+ cl=cluster_hits(hits,ncc_threshold=float(ncc_threshold),target_min=int(target_min),target_max=int(target_max))
 
 
 
 
163
  select_best(cl)
164
  for c in cl:
165
  if c.count>=2: c.synthesized=synthesize_from_cluster(c)
166
+ progress(0.5); rend=render_midi_with_samples(cl,sr=song.sr); progress(0.6)
167
+ gt={n:s.audio for n,s in song.samples.items()}
168
+ gh=[{'sample':h.sample_name,'onset':h.onset_time,'velocity':h.velocity} for h in song.hits]
169
+ r=evaluate_extraction(cl,gt,gh,song.sr,hits)
170
+ s=[{'Metric':'BPM','Value':f"{dbpm}",'Target':f"{song.bpm}"},
171
+ {'Metric':'Clusters','Value':str(len(cl)),'Target':str(len(gt))},
172
+ {'Metric':'Score','Value':f"{r.overall_score:.1f}/100",'Target':'> 70'}]
173
+ if r.unmatched_gt: s.append({'Metric':'','Value':', '.join(r.unmatched_gt),'Target':'None'})
174
+ m=[{'Cluster':m.cluster_label,'GT':m.gt_name,'Score':f"{m.sample_score:.1f}"} for m in r.matches]
 
 
 
 
 
175
  progress(1.0)
176
+ return (audio_tuple(song.mix,song.sr),audio_tuple(rend,song.sr),pd.DataFrame(s),pd.DataFrame(m) if m else None,"","")
 
 
177
 
178
+ def run_optimize(n_iters,config_name,author,save_hub,progress=gr.Progress()):
179
+ logs=[]; progress(0.0)
180
+ state=run_optimization(n_iterations=int(n_iters),config_name=config_name or "opt",
181
+ author=author or "anon",save_to_hub=bool(save_hub),log_fn=lambda m:logs.append(m))
 
 
182
  progress(1.0)
183
+ h=[{'Iter':r.iteration,'Score':f"{r.avg_score:.1f}"} for r in state.history]
184
  if state.history:
185
+ fig,ax=plt.subplots(figsize=(10,4)); ax.plot([r.iteration for r in state.history],[r.avg_score for r in state.history],'b-o')
186
+ ax.grid(True,alpha=0.3); plt.tight_layout()
 
187
  else: fig,ax=plt.subplots(); ax.text(0.5,0.5,"No data")
188
+ return '\n'.join(logs),pd.DataFrame(h),fig,json.dumps(state.best_config,indent=2)
189
 
190
  def refresh_lb():
191
  try:
192
+ lb=get_leaderboard(); return pd.DataFrame(lb) if lb else pd.DataFrame(),""
193
+ except Exception as e: return pd.DataFrame(),str(e)
 
194
 
195
 
196
+ # ─── App ──────────────────────────────────────────────────────────────────────
197
 
198
  def build_app():
199
+ with gr.Blocks(title="🎵 Sample Extractor",theme=gr.themes.Soft(),
200
+ css=".gradio-container{max-width:1300px!important} .lock-row{align-items:center}") as app:
201
+ gr.Markdown("# 🎵 Sample Extractor v8\n"
202
+ "**Auto-Tune** finds optimal parameters for your audio. "
203
+ "🔒 **Lock** any parameter to constrain the search.")
204
 
205
  with gr.Tabs():
206
  with gr.Tab("🎵 Extract"):
 
208
 
209
  with gr.Accordion("🔧 Stem Separation", open=False):
210
  with gr.Row():
211
+ dm=gr.Dropdown(DEMUCS_MODELS,value="htdemucs_ft",label="Model")
212
+ st=gr.Dropdown(['drums','bass','other','vocals','all'],value='drums',label='Stem')
213
+ dsh=gr.Slider(0,5,value=1,step=1,label='Shifts')
214
+ dov=gr.Slider(0.0,0.5,value=0.25,step=0.05,label='Overlap')
215
 
216
  with gr.Accordion("🎯 Onset Detection", open=False):
217
  with gr.Row():
218
+ om=gr.Dropdown(['auto','percussive','harmonic','broadband'],value='auto',label='Mode')
219
+ with gr.Row(elem_classes="lock-row"):
220
+ od=gr.Slider(0.01,0.5,value=0.12,step=0.01,label='Delta')
221
+ lock_od=gr.Checkbox(value=False,label='🔒',scale=0)
222
+ with gr.Row(elem_classes="lock-row"):
223
+ ed=gr.Slider(-70,-10,value=-35,step=1,label='Energy (dB)')
224
+ lock_ed=gr.Checkbox(value=False,label='🔒',scale=0)
225
+ with gr.Row(elem_classes="lock-row"):
226
+ mg=gr.Slider(0.005,0.2,value=0.03,step=0.005,label='Min gap (s)')
227
+ lock_mg=gr.Checkbox(value=False,label='🔒',scale=0)
228
  with gr.Row():
229
+ pp=gr.Slider(0.0,0.05,value=0.005,step=0.001,label='Pre-pad (s)')
230
+ mnd=gr.Slider(0.005,0.2,value=0.02,step=0.005,label='Min dur (s)')
231
+ mxd=gr.Slider(0.1,5.0,value=1.5,step=0.1,label='Max dur (s)')
 
232
 
233
  with gr.Accordion("🔗 Clustering", open=True):
234
+ with gr.Row(elem_classes="lock-row"):
235
+ tmin=gr.Number(value=5,label='Target min clusters',precision=0)
236
+ tmax=gr.Number(value=20,label='Target max clusters',precision=0)
237
+ lock_tgt=gr.Checkbox(value=True,label='🔒 Lock range',scale=0)
238
+ gr.Markdown("*🔒 = auto-tune will respect this value. Unchecked = auto-tune will change it.*")
239
  with gr.Row():
240
+ nt=gr.Slider(0.3,0.99,value=0.80,step=0.01,label='NCC threshold')
241
+ nms=gr.Slider(0,1000,value=0,step=50,label='Compare ms (0=auto)')
242
+ lnk=gr.Dropdown(['average','complete','single'],value='average',label='Linkage')
243
 
244
  with gr.Accordion("⚙️ Post-processing", open=False):
245
+ syn=gr.Checkbox(value=True,label='Synthesize optimal samples')
246
 
247
  with gr.Row():
248
+ tune_btn=gr.Button("🎛️ Auto-Tune",variant="secondary",size="lg")
249
+ extract_btn=gr.Button("🔬 Extract Samples",variant="primary",size="lg")
250
 
251
+ tune_summary=gr.Markdown("")
252
+ tune_log=gr.Textbox(label="Auto-tune log",lines=8,max_lines=15,visible=False)
253
 
254
+ summary_md=gr.Markdown("*Upload audio → Auto-Tune or Extract*")
255
  with gr.Row():
256
+ stem_out=gr.Audio(type='numpy',label='Stem',interactive=False)
257
+ rend_out=gr.Audio(type='numpy',label='🔊 Reconstruction',interactive=False)
258
  gr.Markdown("### Downloads")
259
  with gr.Row():
260
+ arc=gr.File(label="📦 ZIP",interactive=False)
261
+ mid=gr.File(label="🎹 MIDI",interactive=False)
262
+ smp=gr.File(label="WAV samples",file_count="multiple",interactive=False)
263
+ met=gr.Dataframe(label="Samples")
264
+ stx=gr.Textbox(visible=False)
265
 
266
+ dm.change(fn=lambda m:gr.update(choices=DEMUCS_STEMS.get(m,["drums","bass","other","vocals"])+["all"]),
267
+ inputs=[dm],outputs=[st])
268
 
 
269
  tune_btn.click(run_auto_tune,
270
+ [audio_in, st, dm, dsh, dov, om,
271
+ od, ed, mg, tmin, tmax, # current values
272
+ lock_od, lock_ed, lock_mg, lock_tgt], # lock flags
273
  [od, ed, mg, tmin, tmax, tune_summary, tune_log])
274
 
275
  extract_btn.click(run_extraction,
276
+ [audio_in,st,dm,dsh,dov,om,od,ed,pp,mnd,mxd,mg,nt,nms,lnk,tmin,tmax,syn],
277
+ [stem_out,summary_md,rend_out,smp,mid,arc,stx,met])
 
278
 
279
  with gr.Tab("📊 Evaluate"):
280
+ gr.Markdown("Synthetic evaluation.")
281
  with gr.Row():
282
+ ep=gr.Dropdown(['rock','funk','halftime'],value='rock',label='Pattern')
283
+ eb=gr.Slider(80,200,value=120,step=2,label='BPM')
284
+ ebs=gr.Slider(2,8,value=4,step=1,label='Bars')
285
  with gr.Row():
286
+ en=gr.Slider(0.3,0.99,value=0.80,step=0.01,label='NCC')
287
+ etm=gr.Number(value=0,label='Min',precision=0)
288
+ etx=gr.Number(value=0,label='Max',precision=0)
289
+ evb=gr.Button("🧪 Evaluate",variant="primary",size="lg")
290
  with gr.Row():
291
+ evm=gr.Audio(type='numpy',label='Original',interactive=False)
292
+ evr=gr.Audio(type='numpy',label='Reconstruction',interactive=False)
293
+ evs=gr.Dataframe(label="Summary"); evm2=gr.Dataframe(label="Matches")
294
+ es1=gr.Textbox(visible=False); es2=gr.Textbox(visible=False)
295
  evb.click(run_eval,[ep,eb,ebs,en,etm,etx],[evm,evr,evs,evm2,es1,es2])
296
 
297
  with gr.Tab("🔄 Optimize"):
298
+ gr.Markdown("### Synthetic optimization")
299
  with gr.Row():
300
  on=gr.Slider(2,30,value=5,step=1,label='Iters')
301
  ocn=gr.Textbox(value="opt",label='Name')
 
303
  osv=gr.Checkbox(value=True,label='Save')
304
  ob=gr.Button("🚀 Run",variant="primary",size="lg")
305
  ol=gr.Textbox(label="Log",lines=20,max_lines=40)
306
+ oh=gr.Dataframe(label="History"); op=gr.Plot()
307
  oc=gr.Code(label="Config",language="json")
308
  ob.click(run_optimize,[on,ocn,oa,osv],[ol,oh,op,oc])
309
 
310
  with gr.Tab("🏆 Leaderboard"):
311
+ lbb=gr.Button("🔄 Refresh"); lt=gr.Dataframe(); ls=gr.Textbox(visible=False)
312
+ lbb.click(refresh_lb,[],[lt,ls])
 
313
 
314
  return app
315