# app_v2.py — v5: target cluster range + caching + average linkage
# (Hugging Face page chrome removed; upload by rikhoffbauer2, commit 1bdf6ad)
"""
Gradio UI β€” Sample Extractor v4.
NCC clustering, full parameter control, Demucs model selection.
"""
import gradio as gr
import numpy as np, pandas as pd, json, sys, os, tempfile
import soundfile as sf, librosa
import matplotlib; matplotlib.use('Agg')
import matplotlib.pyplot as plt
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from sample_extractor import (
extract_stem, detect_onsets, classify_hits,
cluster_hits, select_best, synthesize_from_cluster,
sample_quality_score, export_midi, detect_bpm,
render_midi_with_samples, build_archive, cache_clear,
DEMUCS_MODELS, DEMUCS_STEMS,
)
from synth_generator import generate_test_song
from evaluation import evaluate_extraction
from config_store import PipelineConfig, get_leaderboard
from optimizer_v2 import run_optimization
def audio_tuple(a, sr):
    """Return a ``(sample_rate, samples)`` tuple for a gr.Audio output.

    The samples are cast to float32 and peak-normalized to 0.95 full scale;
    an all-zero signal is passed through unchanged.
    """
    samples = a.astype(np.float32)
    peak = np.abs(samples).max()
    if peak > 0:
        samples = samples / peak * 0.95
    return (sr, samples)
# ─── Tab 1: Extract ──────────────────────────────────────────────────────────
def run_extraction(audio_in, stem_choice, demucs_model, demucs_shifts, demucs_overlap,
                   onset_mode, onset_delta, energy_db, pre_pad, min_dur, max_dur, min_gap,
                   ncc_threshold, ncc_compare_ms, linkage, target_min, target_max,
                   do_synthesize, progress=gr.Progress()):
    """Full extraction pipeline behind the Extract tab's button.

    Pipeline: stem separation -> BPM detection -> onset detection ->
    classification -> NCC clustering -> best-hit selection -> optional
    synthesis -> MIDI export -> reconstruction render -> packaging.

    Returns an 8-tuple matching the tab's outputs:
    (stem audio, summary markdown, reconstruction audio, sample wav paths,
     MIDI path, ZIP path, status text, metrics dataframe).
    """
    if audio_in is None:
        return [None] * 8
    progress(0.0, desc="Loading audio...")
    sr_in, data = audio_in
    data = data.astype(np.float32)
    if data.ndim > 1:
        data = data.mean(axis=1)  # downmix to mono
    pk = np.abs(data).max()
    if pk > 0:
        data = data / pk  # full-scale input for Demucs
    # Demucs wants a file path, so spill the normalized audio to a temp wav.
    with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as f:
        sf.write(f.name, data, sr_in)
        tmp = f.name
    try:
        progress(0.05, desc=f"Extracting {stem_choice} stem ({demucs_model})...")
        stem_audio, stem_sr = extract_stem(
            tmp, stem=stem_choice, device="cpu",
            model_name=demucs_model, shifts=int(demucs_shifts),
            overlap=float(demucs_overlap))
        progress(0.15, desc="Detecting BPM...")
        bpm = detect_bpm(stem_audio, stem_sr)
        progress(0.25, desc="Detecting onsets...")
        hits = detect_onsets(stem_audio, stem_sr, mode=onset_mode,
                             onset_delta=float(onset_delta),
                             energy_threshold_db=float(energy_db),
                             pre_pad=float(pre_pad), min_dur=float(min_dur),
                             max_dur=float(max_dur), min_gap=float(min_gap))
        if not hits:
            # Still show the stem and BPM so the user can retune thresholds.
            return (audio_tuple(stem_audio, stem_sr),
                    f"**Detected BPM: {bpm}**\n\nNo hits found — try lowering energy threshold.",
                    None, None, None, None, "", pd.DataFrame())
        progress(0.35, desc="Classifying...")
        hits = classify_hits(hits)
        progress(0.45, desc="NCC clustering...")  # was an f-string with no placeholders
        clusters = cluster_hits(hits, ncc_threshold=float(ncc_threshold),
                                max_compare_ms=float(ncc_compare_ms),
                                target_min=int(target_min), target_max=int(target_max),
                                linkage=str(linkage))
        progress(0.65, desc="Scoring & selecting best...")
        select_best(clusters)
        if do_synthesize:
            progress(0.7, desc="Synthesizing...")
            for c in clusters:
                # Synthesis needs at least two exemplars to average over.
                if c.count >= 2:
                    c.synthesized = synthesize_from_cluster(c)
        progress(0.75, desc="Building MIDI...")
        # mkstemp instead of the deprecated, race-prone tempfile.mktemp.
        fd, midi_path = tempfile.mkstemp(suffix='.mid')
        os.close(fd)  # export_midi reopens the path itself
        export_midi(clusters, midi_path, bpm=bpm)
        progress(0.8, desc="Rendering reconstruction...")
        rendered = render_midi_with_samples(clusters, sr=stem_sr)
        progress(0.85, desc="Packaging...")
        sample_dir = tempfile.mkdtemp()
        sample_paths = []
        for c in sorted(clusters, key=lambda x: x.count, reverse=True):
            sp = os.path.join(sample_dir, f"{c.label}.wav")
            c.best_hit.save(sp)
            sample_paths.append(sp)
        zip_path = build_archive(clusters, bpm, stem_sr,
                                 midi_path=midi_path, rendered_audio=rendered)
        # Per-sample quality metrics table (most-used clusters first).
        rows = []
        for c in sorted(clusters, key=lambda x: x.count, reverse=True):
            best = c.best_hit
            sc = sample_quality_score(best.audio, best.sr, c.label.rsplit('_', 1)[0])
            rows.append({
                'Sample': c.label, 'Hits': c.count, 'MIDI': c.midi_note,
                'Score': f"{sc['total']:.1f}",
                'Clean': f"{sc['cleanness']:.2f}",
                'Complete': f"{sc['completeness']:.2f}",
                'Dur (ms)': f"{best.duration*1000:.0f}",
                'First @': f"{sorted(h.onset_time for h in c.hits)[0]:.2f}s",
            })
        summary = f"**Detected BPM: {bpm}** · **{len(clusters)} unique samples** from {len(hits)} hits\n\n"
        summary += f"Model: `{demucs_model}` · NCC threshold: `{ncc_threshold}` · Onset delta: `{onset_delta}`\n\n"
        summary += "| Sample | Hits | MIDI Note |\n|---|---|---|\n"
        for c in sorted(clusters, key=lambda x: x.count, reverse=True):
            summary += f"| {c.label} | {c.count} | {c.midi_note} |\n"
        progress(1.0, desc="Done!")
        return (audio_tuple(stem_audio, stem_sr), summary,
                audio_tuple(rendered, stem_sr), sample_paths,
                midi_path, zip_path, "", pd.DataFrame(rows))
    finally:
        os.unlink(tmp)  # always remove the temp wav handed to Demucs
# ─── Tab 2: Evaluate ─────────────────────────────────────────────────────────
def run_eval(pattern, bpm, bars, ncc_threshold, progress=gr.Progress()):
    """Generate a synthetic song, run the extraction pipeline on it, and score
    the result against the song's known ground truth.

    Returns a 6-tuple for the Evaluate tab:
    (mix audio, reconstruction audio, summary dataframe, matches dataframe,
     status text, status text).
    """
    progress(0.0, desc="Generating synthetic song...")
    # Fixed seed so repeated evaluations with the same settings are comparable.
    song = generate_test_song(pattern_name=pattern, bars=int(bars),
                              bpm=float(bpm), variation='medium', seed=42)
    detected_bpm = detect_bpm(song.drums_only, song.sr)
    progress(0.2, desc="Extracting...")
    # Extraction runs on the clean drum track directly — no Demucs step here.
    hits = detect_onsets(song.drums_only, song.sr)
    if not hits: return None, None, None, None, "", ""
    hits = classify_hits(hits)
    clusters = cluster_hits(hits, ncc_threshold=float(ncc_threshold))
    select_best(clusters)
    for c in clusters:
        # Synthesis needs at least two exemplars in the cluster.
        if c.count >= 2: c.synthesized = synthesize_from_cluster(c)
    progress(0.5, desc="Rendering...")
    rendered = render_midi_with_samples(clusters, sr=song.sr)
    progress(0.6, desc="Evaluating...")
    # Ground truth: the source sample audio plus the exact hit schedule
    # used to synthesize the song.
    gt = {n: s.audio for n, s in song.samples.items()}
    gt_hits = [{'sample': h.sample_name, 'onset': h.onset_time, 'velocity': h.velocity}
               for h in song.hits]
    report = evaluate_extraction(clusters, gt, gt_hits, song.sr, hits)
    mix_out = audio_tuple(song.mix, song.sr)
    rendered_out = audio_tuple(rendered, song.sr)
    summary = [
        {'Metric': 'Detected BPM', 'Value': f"{detected_bpm}", 'Target': f"{song.bpm}"},
        {'Metric': 'Clusters', 'Value': str(len(clusters)), 'Target': str(len(gt))},
        {'Metric': 'Overall', 'Value': f"{report.overall_score:.1f}/100", 'Target': '> 70'},
        {'Metric': 'SI-SDR', 'Value': f"{report.mean_si_sdr:.1f} dB", 'Target': '> 10'},
        {'Metric': 'Env Corr', 'Value': f"{report.mean_env_corr:.3f}", 'Target': '> 0.9'},
    ]
    if report.unmatched_gt:
        summary.append({'Metric': '⚠ Unmatched', 'Value': ', '.join(report.unmatched_gt), 'Target': 'None'})
    matches = [{'Cluster': m.cluster_label, 'GT': m.gt_name, 'SI-SDR': f"{m.si_sdr:.1f}",
                'Score': f"{m.sample_score:.1f}"} for m in report.matches]
    progress(1.0)
    return (mix_out, rendered_out, pd.DataFrame(summary),
            pd.DataFrame(matches) if matches else None, "", "")
# ─── Tab 3: Optimize ─────────────────────────────────────────────────────────
def run_optimize(n_iters, config_name, author, save_hub, progress=gr.Progress()):
    """Run the autonomous optimizer and report its results.

    Returns (log text, history dataframe, matplotlib figure, best-config JSON)
    for the Optimize tab.
    """
    logs = []
    progress(0.0, desc="Starting optimization...")
    state = run_optimization(n_iterations=int(n_iters),
                             config_name=config_name or "optimized",
                             author=author or "anonymous",
                             save_to_hub=bool(save_hub), log_fn=lambda m: logs.append(m))
    progress(1.0)
    hist = [{'Iter': r.iteration, 'Score': f"{r.avg_score:.1f}",
             'Time': f"{r.duration_s:.1f}s"} for r in state.history]
    if state.history:
        fig, ax = plt.subplots(figsize=(10, 4))
        ax.plot([r.iteration for r in state.history],
                [r.avg_score for r in state.history], 'b-o')
        ax.set_xlabel('Iteration')
        ax.set_ylabel('Score')
        ax.grid(True, alpha=0.3)
        plt.tight_layout()
    else:
        # Placeholder figure so gr.Plot always has something to show.
        fig, ax = plt.subplots()
        ax.text(0.5, 0.5, "No data")
    # Detach from pyplot's global registry so a long-lived server does not
    # accumulate one open figure per click; the Figure object itself stays
    # alive and renderable for Gradio.
    plt.close(fig)
    return '\n'.join(logs), pd.DataFrame(hist), fig, json.dumps(state.best_config, indent=2)
# ─── Tab 4: Leaderboard ──────────────────────────────────────────────────────
def refresh_leaderboard():
    """Fetch the config leaderboard.

    Returns (dataframe, error text); on any failure the table is empty and
    the error message is surfaced via the hidden status textbox.
    """
    try:
        entries = get_leaderboard()
    except Exception as exc:  # boundary handler: report instead of crashing the UI
        return pd.DataFrame(), str(exc)
    if entries:
        return pd.DataFrame(entries), ""
    return pd.DataFrame(), ""
# ─── Build App ────────────────────────────────────────────────────────────────
def get_stems_for_model(model_name):
    """Return a gr.update restricting the stem dropdown to *model_name*'s
    stems plus the synthetic 'all' option, selecting the first stem.

    NOTE(review): build_app wires the model-change event to an inline lambda
    rather than this helper — presumably to avoid resetting the selection.
    """
    fallback = ["drums", "bass", "other", "vocals"]
    stems = DEMUCS_STEMS.get(model_name, fallback)
    return gr.update(choices=[*stems, "all"], value=stems[0])
def build_app():
    """Assemble the four-tab Gradio Blocks UI (Extract / Evaluate / Optimize /
    Leaderboard) and return the unlaunched app."""
    with gr.Blocks(title="🎵 Sample Extractor", theme=gr.themes.Soft(),
                   css=".gradio-container{max-width:1300px!important}") as app:
        gr.Markdown("# 🎵 Sample Extractor v4\n"
                    "Extract distinct sounds from audio using **NCC waveform matching** — "
                    "correctly groups identical samples regardless of velocity.\n"
                    "Full control over Demucs model, onset detection, and clustering parameters.")
        with gr.Tabs():
            # ── Extract ──
            with gr.Tab("🎵 Extract"):
                audio_in = gr.Audio(sources=['upload'], type='numpy', label='Upload Audio')
                with gr.Accordion("🔧 Stem Separation", open=False):
                    with gr.Row():
                        demucs_model = gr.Dropdown(DEMUCS_MODELS, value="htdemucs_ft",
                                                   label="Demucs Model")
                        stem_dd = gr.Dropdown(['drums','bass','other','vocals','all'],
                                              value='drums', label='Stem')
                        demucs_shifts = gr.Slider(0, 5, value=1, step=1,
                                                  label='Shifts (TTA, 0=fastest)')
                        demucs_overlap = gr.Slider(0.0, 0.5, value=0.25, step=0.05,
                                                   label='Overlap')
                with gr.Accordion("🎯 Onset Detection", open=False):
                    with gr.Row():
                        onset_mode = gr.Dropdown(['auto','percussive','harmonic','broadband'],
                                                 value='auto', label='Mode')
                        onset_delta = gr.Slider(0.01, 0.5, value=0.07, step=0.01,
                                                label='Delta (sensitivity)')
                        energy_db = gr.Slider(-70, -10, value=-45, step=1,
                                              label='Energy threshold (dB)')
                    with gr.Row():
                        pre_pad = gr.Slider(0.0, 0.05, value=0.005, step=0.001,
                                            label='Pre-pad (s)')
                        min_dur = gr.Slider(0.005, 0.2, value=0.02, step=0.005,
                                            label='Min duration (s)')
                        max_dur = gr.Slider(0.1, 5.0, value=1.5, step=0.1,
                                            label='Max duration (s)')
                        min_gap = gr.Slider(0.005, 0.2, value=0.015, step=0.005,
                                            label='Min gap (s)')
                with gr.Accordion("🔗 Clustering", open=False):
                    with gr.Row():
                        ncc_thresh = gr.Slider(0.3, 0.99, value=0.80, step=0.01,
                                               label='NCC threshold (higher = stricter)')
                        ncc_ms = gr.Slider(50, 1000, value=200, step=50,
                                           label='Compare window (ms)')
                        linkage_dd = gr.Dropdown(['average', 'complete', 'single'],
                                                 value='average', label='Linkage')
                    with gr.Row():
                        # 0 disables the target-range search and uses the raw threshold.
                        target_min = gr.Number(value=0, label='Target min clusters (0 = use threshold)',
                                               precision=0)
                        target_max = gr.Number(value=0, label='Target max clusters (0 = use threshold)',
                                               precision=0)
                    gr.Markdown("*Set both target min/max > 0 to auto-search for the right threshold. "
                                "Leave at 0 to use the NCC threshold directly.*")
                with gr.Accordion("⚙️ Post-processing", open=False):
                    do_synth = gr.Checkbox(value=True, label='Synthesize optimal samples from clusters')
                extract_btn = gr.Button("🔬 Extract Samples", variant="primary", size="lg")
                summary_md = gr.Markdown("*Upload audio and click Extract*")
                with gr.Row():
                    stem_out = gr.Audio(type='numpy', label='Stem', interactive=False)
                    rendered_out = gr.Audio(type='numpy', label='🔊 Reconstruction', interactive=False)
                gr.Markdown("### Downloads")
                with gr.Row():
                    archive_file = gr.File(label="📦 ZIP Archive", interactive=False)
                    midi_file = gr.File(label="🎹 MIDI", interactive=False)
                    sample_files = gr.File(label="Individual WAV samples", file_count="multiple",
                                           interactive=False)
                metrics_tbl = gr.Dataframe(label="Extracted Samples")
                # Hidden sink for the pipeline's unused status output.
                status_txt = gr.Textbox(visible=False)
                # Update available stems when model changes
                demucs_model.change(
                    fn=lambda m: gr.update(choices=DEMUCS_STEMS.get(m, ["drums","bass","other","vocals"]) + ["all"]),
                    inputs=[demucs_model], outputs=[stem_dd])
                extract_btn.click(
                    run_extraction,
                    [audio_in, stem_dd, demucs_model, demucs_shifts, demucs_overlap,
                     onset_mode, onset_delta, energy_db, pre_pad, min_dur, max_dur, min_gap,
                     ncc_thresh, ncc_ms, linkage_dd, target_min, target_max, do_synth],
                    [stem_out, summary_md, rendered_out, sample_files,
                     midi_file, archive_file, status_txt, metrics_tbl])
            # ── Evaluate ──
            with gr.Tab("📊 Evaluate"):
                gr.Markdown("Generate synthetic song → extract → compare to ground truth.")
                with gr.Row():
                    ev_pat = gr.Dropdown(['rock','funk','halftime'], value='rock', label='Pattern')
                    ev_bpm = gr.Slider(80, 200, value=120, step=2, label='BPM')
                    ev_bars = gr.Slider(2, 8, value=4, step=1, label='Bars')
                    ev_ncc = gr.Slider(0.5, 0.99, value=0.80, step=0.01, label='NCC threshold')
                ev_btn = gr.Button("🧪 Evaluate", variant="primary", size="lg")
                with gr.Row():
                    ev_mix = gr.Audio(type='numpy', label='Original', interactive=False)
                    ev_rendered = gr.Audio(type='numpy', label='Reconstruction', interactive=False)
                ev_summary = gr.Dataframe(label="Summary")
                ev_matches = gr.Dataframe(label="Matches")
                # Hidden sinks for run_eval's two status outputs.
                ev_s1 = gr.Textbox(visible=False); ev_s2 = gr.Textbox(visible=False)
                ev_btn.click(run_eval, [ev_pat, ev_bpm, ev_bars, ev_ncc],
                             [ev_mix, ev_rendered, ev_summary, ev_matches, ev_s1, ev_s2])
            # ── Optimize ──
            with gr.Tab("🔄 Optimize"):
                gr.Markdown("### Autonomous Optimization\nTests across 6 diverse songs, saves best config to Hub.")
                with gr.Row():
                    opt_n = gr.Slider(2, 30, value=5, step=1, label='Iterations')
                    opt_name = gr.Textbox(value="optimized", label='Config name')
                    opt_author = gr.Textbox(value="", label='Author')
                    opt_save = gr.Checkbox(value=True, label='Save to Hub')
                opt_btn = gr.Button("🚀 Optimize", variant="primary", size="lg")
                opt_log = gr.Textbox(label="Log", lines=20, max_lines=40)
                opt_hist = gr.Dataframe(label="History")
                opt_plot = gr.Plot(label="Progress")
                opt_params = gr.Code(label="Best Config", language="json")
                opt_btn.click(run_optimize, [opt_n, opt_name, opt_author, opt_save],
                              [opt_log, opt_hist, opt_plot, opt_params])
            # ── Leaderboard ──
            with gr.Tab("🏆 Leaderboard"):
                gr.Markdown("### Config Leaderboard")
                lb_btn = gr.Button("🔄 Refresh"); lb_tbl = gr.Dataframe()
                lb_s = gr.Textbox(visible=False)
                lb_btn.click(refresh_leaderboard, [], [lb_tbl, lb_s])
    return app
# Script entry point: bind to all interfaces on the standard Spaces port.
if __name__ == "__main__":
    demo = build_app()
    demo.launch(server_name="0.0.0.0", server_port=7860)