| """ |
| Gradio UI — Sample Extractor v4. |
| NCC clustering, full parameter control, Demucs model selection. |
| """ |
|
|
| import gradio as gr |
| import numpy as np, pandas as pd, json, sys, os, tempfile |
| import soundfile as sf, librosa |
| import matplotlib; matplotlib.use('Agg') |
| import matplotlib.pyplot as plt |
|
|
| sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) |
|
|
| from sample_extractor import ( |
| extract_stem, detect_onsets, classify_hits, |
| cluster_hits, select_best, synthesize_from_cluster, |
| sample_quality_score, export_midi, detect_bpm, |
| render_midi_with_samples, build_archive, cache_clear, |
| DEMUCS_MODELS, DEMUCS_STEMS, |
| ) |
| from synth_generator import generate_test_song |
| from evaluation import evaluate_extraction |
| from config_store import PipelineConfig, get_leaderboard |
| from optimizer_v2 import run_optimization |
|
|
|
|
def audio_tuple(a, sr):
    """Peak-normalise audio and pack it as a ``(sample_rate, samples)`` pair.

    Args:
        a: Audio samples as a NumPy array or any array-like (list, tuple).
        sr: Sample rate; returned unchanged as the first tuple element.

    Returns:
        ``(sr, samples)`` where ``samples`` is a new float32 array scaled so
        its absolute peak is 0.95. Silent (all-zero) and empty inputs are
        returned unscaled.
    """
    # np.array always copies, so the caller's buffer is never aliased.
    samples = np.array(a, dtype=np.float32)
    if samples.size == 0:
        # max() of an empty array raises; there is nothing to normalise anyway.
        return (sr, samples)
    peak = np.abs(samples).max()
    if peak > 0:
        samples = samples / peak * 0.95
    return (sr, samples)
|
|
|
|
| |
|
|
def run_extraction(audio_in, stem_choice, demucs_model, demucs_shifts, demucs_overlap,
                   onset_mode, onset_delta, energy_db, pre_pad, min_dur, max_dur, min_gap,
                   ncc_threshold, ncc_compare_ms, linkage, target_min, target_max,
                   do_synthesize, progress=gr.Progress()):
    """Run the extraction pipeline on an uploaded clip (Extract tab handler).

    The upload is mixed down to mono, peak-normalised and written to a
    temporary WAV. The pipeline then separates the requested stem, detects
    BPM and onsets, classifies and clusters the hits, selects the best hit
    per cluster, optionally synthesizes a sample for every cluster with at
    least two hits, exports MIDI, renders a reconstruction and builds the
    download archive.

    Args:
        audio_in: ``(sample_rate, samples)`` pair from the audio input, or
            ``None`` when nothing was uploaded.
        stem_choice, demucs_model, demucs_shifts, demucs_overlap: Forwarded
            to ``extract_stem`` (shifts as int, overlap as float).
        onset_mode, onset_delta, energy_db, pre_pad, min_dur, max_dur,
            min_gap: Forwarded to ``detect_onsets`` (numeric values as float).
        ncc_threshold, ncc_compare_ms, linkage, target_min, target_max:
            Forwarded to ``cluster_hits``.
        do_synthesize: When truthy, run ``synthesize_from_cluster`` on every
            cluster whose ``count`` is at least 2.
        progress: Gradio progress callback.

    Returns:
        Eight values in output order: stem audio, summary markdown,
        reconstruction audio, list of per-sample WAV paths, MIDI path,
        archive path, status text (always ``""``) and a metrics DataFrame.
        A list of eight ``None`` is returned when there is no (or empty)
        audio; when no hits are detected only the stem, a message and an
        empty DataFrame are filled in.
    """
    if audio_in is None:
        return [None] * 8

    progress(0.0, desc="Loading audio...")
    sr_in, data = audio_in
    data = data.astype(np.float32)
    if data.ndim > 1:
        data = data.mean(axis=1)
    if data.size == 0:
        # An empty clip has no peak to normalise; treat it like a missing upload.
        return [None] * 8
    peak = np.abs(data).max()
    if peak > 0:
        data = data / peak

    # mkstemp creates the file atomically; close our handle so the writer can
    # reopen the path on every platform.
    wav_fd, tmp = tempfile.mkstemp(suffix='.wav')
    os.close(wav_fd)

    try:
        # Written inside the try so the temp file is removed even if the write fails.
        sf.write(tmp, data, sr_in)

        progress(0.05, desc=f"Extracting {stem_choice} stem ({demucs_model})...")
        stem_audio, stem_sr = extract_stem(
            tmp, stem=stem_choice, device="cpu",
            model_name=demucs_model, shifts=int(demucs_shifts),
            overlap=float(demucs_overlap))

        progress(0.15, desc="Detecting BPM...")
        bpm = detect_bpm(stem_audio, stem_sr)

        progress(0.25, desc="Detecting onsets...")
        hits = detect_onsets(stem_audio, stem_sr, mode=onset_mode,
                             onset_delta=float(onset_delta),
                             energy_threshold_db=float(energy_db),
                             pre_pad=float(pre_pad), min_dur=float(min_dur),
                             max_dur=float(max_dur), min_gap=float(min_gap))
        if not hits:
            return (audio_tuple(stem_audio, stem_sr),
                    f"**Detected BPM: {bpm}**\n\nNo hits found β try lowering energy threshold.",
                    None, None, None, None, "", pd.DataFrame())

        progress(0.35, desc="Classifying...")
        hits = classify_hits(hits)

        progress(0.45, desc="NCC clustering...")
        clusters = cluster_hits(hits, ncc_threshold=float(ncc_threshold),
                                max_compare_ms=float(ncc_compare_ms),
                                target_min=int(target_min), target_max=int(target_max),
                                linkage=str(linkage))

        progress(0.65, desc="Scoring & selecting best...")
        select_best(clusters)

        if do_synthesize:
            progress(0.7, desc="Synthesizing...")
            for c in clusters:
                if c.count >= 2:
                    c.synthesized = synthesize_from_cluster(c)

        progress(0.75, desc="Building MIDI...")
        # mkstemp instead of the deprecated, race-prone mktemp.
        midi_fd, midi_path = tempfile.mkstemp(suffix='.mid')
        os.close(midi_fd)
        export_midi(clusters, midi_path, bpm=bpm)

        progress(0.8, desc="Rendering reconstruction...")
        rendered = render_midi_with_samples(clusters, sr=stem_sr)

        progress(0.85, desc="Packaging...")
        # One ordering (most hits first) shared by the files, table and summary.
        ranked = sorted(clusters, key=lambda x: x.count, reverse=True)

        sample_dir = tempfile.mkdtemp()
        sample_paths = []
        for c in ranked:
            sp = os.path.join(sample_dir, f"{c.label}.wav")
            c.best_hit.save(sp)
            sample_paths.append(sp)

        zip_path = build_archive(clusters, bpm, stem_sr,
                                 midi_path=midi_path, rendered_audio=rendered)

        rows = []
        for c in ranked:
            best = c.best_hit
            # The score is keyed on the label with its trailing "_<suffix>" removed.
            sc = sample_quality_score(best.audio, best.sr, c.label.rsplit('_', 1)[0])
            rows.append({
                'Sample': c.label, 'Hits': c.count, 'MIDI': c.midi_note,
                'Score': f"{sc['total']:.1f}",
                'Clean': f"{sc['cleanness']:.2f}",
                'Complete': f"{sc['completeness']:.2f}",
                'Dur (ms)': f"{best.duration*1000:.0f}",
                'First @': f"{min(h.onset_time for h in c.hits):.2f}s",
            })

        summary_parts = [
            f"**Detected BPM: {bpm}** Β· **{len(clusters)} unique samples** from {len(hits)} hits\n\n",
            f"Model: `{demucs_model}` Β· NCC threshold: `{ncc_threshold}` Β· Onset delta: `{onset_delta}`\n\n",
            "| Sample | Hits | MIDI Note |\n|---|---|---|\n",
        ]
        summary_parts.extend(
            f"| {c.label} | {c.count} | {c.midi_note} |\n" for c in ranked)
        summary = "".join(summary_parts)

        progress(1.0, desc="Done!")
        return (audio_tuple(stem_audio, stem_sr), summary,
                audio_tuple(rendered, stem_sr), sample_paths,
                midi_path, zip_path, "", pd.DataFrame(rows))
    finally:
        os.unlink(tmp)
|
|
|
|
| |
|
|
def run_eval(pattern, bpm, bars, ncc_threshold, progress=gr.Progress()):
    """Generate a synthetic song, extract from it and score against ground truth.

    Args:
        pattern: Pattern name passed to ``generate_test_song``.
        bpm: Tempo of the generated song (converted to float).
        bars: Number of bars to generate (converted to int).
        ncc_threshold: Clustering threshold passed to ``cluster_hits``.
        progress: Gradio progress callback.

    Returns:
        Six values: mix audio, reconstruction audio, summary DataFrame,
        matches DataFrame (``None`` when there are no matches) and two empty
        status strings. When no onsets are detected the first four are ``None``.
    """
    progress(0.0, desc="Generating synthetic song...")
    song = generate_test_song(
        pattern_name=pattern,
        bars=int(bars),
        bpm=float(bpm),
        variation='medium',
        seed=42,
    )

    detected_bpm = detect_bpm(song.drums_only, song.sr)

    progress(0.2, desc="Extracting...")
    hits = detect_onsets(song.drums_only, song.sr)
    if not hits:
        return None, None, None, None, "", ""

    hits = classify_hits(hits)
    clusters = cluster_hits(hits, ncc_threshold=float(ncc_threshold))
    select_best(clusters)
    for cluster in clusters:
        # Single-hit clusters are left without a synthesized sample.
        if cluster.count < 2:
            continue
        cluster.synthesized = synthesize_from_cluster(cluster)

    progress(0.5, desc="Rendering...")
    reconstruction = render_midi_with_samples(clusters, sr=song.sr)

    progress(0.6, desc="Evaluating...")
    ground_truth = {}
    for sample_name, sample in song.samples.items():
        ground_truth[sample_name] = sample.audio
    truth_hits = []
    for hit in song.hits:
        truth_hits.append({
            'sample': hit.sample_name,
            'onset': hit.onset_time,
            'velocity': hit.velocity,
        })
    report = evaluate_extraction(clusters, ground_truth, truth_hits, song.sr, hits)

    mix_out = audio_tuple(song.mix, song.sr)
    rendered_out = audio_tuple(reconstruction, song.sr)

    def _row(metric, value, target):
        # One line of the summary table.
        return {'Metric': metric, 'Value': value, 'Target': target}

    summary_rows = [
        _row('Detected BPM', f"{detected_bpm}", f"{song.bpm}"),
        _row('Clusters', str(len(clusters)), str(len(ground_truth))),
        _row('Overall', f"{report.overall_score:.1f}/100", '> 70'),
        _row('SI-SDR', f"{report.mean_si_sdr:.1f} dB", '> 10'),
        _row('Env Corr', f"{report.mean_env_corr:.3f}", '> 0.9'),
    ]
    if report.unmatched_gt:
        summary_rows.append(
            _row('β Unmatched', ', '.join(report.unmatched_gt), 'None'))

    match_rows = []
    for match in report.matches:
        match_rows.append({
            'Cluster': match.cluster_label,
            'GT': match.gt_name,
            'SI-SDR': f"{match.si_sdr:.1f}",
            'Score': f"{match.sample_score:.1f}",
        })

    progress(1.0)
    matches_table = pd.DataFrame(match_rows) if match_rows else None
    return (mix_out, rendered_out, pd.DataFrame(summary_rows),
            matches_table, "", "")
|
|
|
|
| |
|
|
def run_optimize(n_iters, config_name, author, save_hub, progress=gr.Progress()):
    """Run the optimizer and return its log, history table, plot and best config.

    Args:
        n_iters: Number of optimization iterations (converted to int).
        config_name: Name for the resulting config; falls back to
            ``"optimized"`` when empty.
        author: Author name; falls back to ``"anonymous"`` when empty.
        save_hub: Whether to pass ``save_to_hub=True`` to the optimizer.
        progress: Gradio progress callback.

    Returns:
        ``(log_text, history_dataframe, figure, best_config_json)``.
    """
    logs = []
    progress(0.0, desc="Starting optimization...")
    state = run_optimization(n_iterations=int(n_iters),
                             config_name=config_name or "optimized",
                             author=author or "anonymous",
                             save_to_hub=bool(save_hub), log_fn=logs.append)
    progress(1.0)

    history = state.history
    hist = [{'Iter': r.iteration, 'Score': f"{r.avg_score:.1f}",
             'Time': f"{r.duration_s:.1f}s"} for r in history]

    if history:
        fig, ax = plt.subplots(figsize=(10, 4))
        ax.plot([r.iteration for r in history], [r.avg_score for r in history], 'b-o')
        ax.set_xlabel('Iteration')
        ax.set_ylabel('Score')
        ax.grid(True, alpha=0.3)
        # Lay out this figure explicitly instead of pyplot's global "current
        # figure", which another concurrent handler may have replaced.
        fig.tight_layout()
    else:
        fig, ax = plt.subplots()
        ax.text(0.5, 0.5, "No data")

    # Detach the figure from pyplot's registry so repeated runs do not
    # accumulate open figures; the Figure object itself can still be rendered.
    plt.close(fig)

    return '\n'.join(logs), pd.DataFrame(hist), fig, json.dumps(state.best_config, indent=2)
|
|
|
|
| |
|
|
def refresh_leaderboard():
    """Fetch the config leaderboard for display.

    Returns:
        ``(table, error_text)``: the leaderboard as a DataFrame (empty when
        there are no entries) with an empty error string, or an empty
        DataFrame plus the exception message if anything fails.
    """
    try:
        entries = get_leaderboard()
        if entries:
            table = pd.DataFrame(entries)
        else:
            table = pd.DataFrame()
        return table, ""
    except Exception as err:
        # UI boundary: report the failure as text instead of raising.
        return pd.DataFrame(), str(err)
|
|
|
|
| |
|
|
def get_stems_for_model(model_name):
    """Build a dropdown update listing the stems of ``model_name``.

    Models missing from ``DEMUCS_STEMS`` fall back to the four standard
    stems. ``"all"`` is appended to the choices and the first stem becomes
    the selected value.
    """
    fallback_stems = ["drums", "bass", "other", "vocals"]
    available = DEMUCS_STEMS.get(model_name, fallback_stems)
    choices = available + ["all"]
    return gr.update(choices=choices, value=available[0])
|
|
def build_app():
    """Assemble the Gradio Blocks UI and wire each tab to its handler.

    Tabs and their handlers: Extract -> ``run_extraction``, Evaluate ->
    ``run_eval``, Optimize -> ``run_optimize``, Leaderboard ->
    ``refresh_leaderboard``.

    Returns:
        The constructed ``gr.Blocks`` app (not yet launched).
    """
    # NOTE(review): container nesting (which widgets sit inside each Row /
    # Accordion) was reconstructed from the order of the ``with`` statements;
    # confirm against the intended layout.

    def _sync_stem_choices(model_name, current_stem):
        # Offer the stems of the newly selected model. Keep the current
        # selection when the new model still provides it; otherwise fall back
        # to the first choice so the dropdown never holds a value that is not
        # among its choices.
        choices = DEMUCS_STEMS.get(model_name, ["drums", "bass", "other", "vocals"]) + ["all"]
        value = current_stem if current_stem in choices else choices[0]
        return gr.update(choices=choices, value=value)

    with gr.Blocks(title="π΅ Sample Extractor", theme=gr.themes.Soft(),
                   css=".gradio-container{max-width:1300px!important}") as app:
        gr.Markdown("# π΅ Sample Extractor v4\n"
                    "Extract distinct sounds from audio using **NCC waveform matching** β "
                    "correctly groups identical samples regardless of velocity.\n"
                    "Full control over Demucs model, onset detection, and clustering parameters.")

        with gr.Tabs():
            # ---- Extract tab ------------------------------------------------
            with gr.Tab("π΅ Extract"):
                audio_in = gr.Audio(sources=['upload'], type='numpy', label='Upload Audio')

                with gr.Accordion("π§ Stem Separation", open=False):
                    with gr.Row():
                        demucs_model = gr.Dropdown(DEMUCS_MODELS, value="htdemucs_ft",
                                                   label="Demucs Model")
                        stem_dd = gr.Dropdown(['drums', 'bass', 'other', 'vocals', 'all'],
                                              value='drums', label='Stem')
                        demucs_shifts = gr.Slider(0, 5, value=1, step=1,
                                                  label='Shifts (TTA, 0=fastest)')
                        demucs_overlap = gr.Slider(0.0, 0.5, value=0.25, step=0.05,
                                                   label='Overlap')

                with gr.Accordion("π― Onset Detection", open=False):
                    with gr.Row():
                        onset_mode = gr.Dropdown(['auto', 'percussive', 'harmonic', 'broadband'],
                                                 value='auto', label='Mode')
                        onset_delta = gr.Slider(0.01, 0.5, value=0.07, step=0.01,
                                                label='Delta (sensitivity)')
                        energy_db = gr.Slider(-70, -10, value=-45, step=1,
                                              label='Energy threshold (dB)')
                    with gr.Row():
                        pre_pad = gr.Slider(0.0, 0.05, value=0.005, step=0.001,
                                            label='Pre-pad (s)')
                        min_dur = gr.Slider(0.005, 0.2, value=0.02, step=0.005,
                                            label='Min duration (s)')
                        max_dur = gr.Slider(0.1, 5.0, value=1.5, step=0.1,
                                            label='Max duration (s)')
                        min_gap = gr.Slider(0.005, 0.2, value=0.015, step=0.005,
                                            label='Min gap (s)')

                with gr.Accordion("π Clustering", open=False):
                    with gr.Row():
                        ncc_thresh = gr.Slider(0.3, 0.99, value=0.80, step=0.01,
                                               label='NCC threshold (higher = stricter)')
                        ncc_ms = gr.Slider(50, 1000, value=200, step=50,
                                           label='Compare window (ms)')
                        linkage_dd = gr.Dropdown(['average', 'complete', 'single'],
                                                 value='average', label='Linkage')
                    with gr.Row():
                        target_min = gr.Number(value=0, label='Target min clusters (0 = use threshold)',
                                               precision=0)
                        target_max = gr.Number(value=0, label='Target max clusters (0 = use threshold)',
                                               precision=0)
                    gr.Markdown("*Set both target min/max > 0 to auto-search for the right threshold. "
                                "Leave at 0 to use the NCC threshold directly.*")

                with gr.Accordion("βοΈ Post-processing", open=False):
                    do_synth = gr.Checkbox(value=True, label='Synthesize optimal samples from clusters')

                extract_btn = gr.Button("π¬ Extract Samples", variant="primary", size="lg")

                summary_md = gr.Markdown("*Upload audio and click Extract*")
                with gr.Row():
                    stem_out = gr.Audio(type='numpy', label='Stem', interactive=False)
                    rendered_out = gr.Audio(type='numpy', label='π Reconstruction', interactive=False)

                gr.Markdown("### Downloads")
                with gr.Row():
                    archive_file = gr.File(label="π¦ ZIP Archive", interactive=False)
                    midi_file = gr.File(label="πΉ MIDI", interactive=False)
                sample_files = gr.File(label="Individual WAV samples", file_count="multiple",
                                       interactive=False)
                metrics_tbl = gr.Dataframe(label="Extracted Samples")
                # Hidden sink for the handler's status-text output slot.
                status_txt = gr.Textbox(visible=False)

                # Keep the stem dropdown consistent with the selected model.
                demucs_model.change(
                    fn=_sync_stem_choices,
                    inputs=[demucs_model, stem_dd], outputs=[stem_dd])

                # Input order must match run_extraction's positional parameters;
                # output order must match its 8-value return.
                extract_btn.click(
                    run_extraction,
                    [audio_in, stem_dd, demucs_model, demucs_shifts, demucs_overlap,
                     onset_mode, onset_delta, energy_db, pre_pad, min_dur, max_dur, min_gap,
                     ncc_thresh, ncc_ms, linkage_dd, target_min, target_max, do_synth],
                    [stem_out, summary_md, rendered_out, sample_files,
                     midi_file, archive_file, status_txt, metrics_tbl])

            # ---- Evaluate tab -----------------------------------------------
            with gr.Tab("π Evaluate"):
                gr.Markdown("Generate synthetic song β extract β compare to ground truth.")
                with gr.Row():
                    ev_pat = gr.Dropdown(['rock', 'funk', 'halftime'], value='rock', label='Pattern')
                    ev_bpm = gr.Slider(80, 200, value=120, step=2, label='BPM')
                    ev_bars = gr.Slider(2, 8, value=4, step=1, label='Bars')
                    ev_ncc = gr.Slider(0.5, 0.99, value=0.80, step=0.01, label='NCC threshold')
                ev_btn = gr.Button("π§ͺ Evaluate", variant="primary", size="lg")
                with gr.Row():
                    ev_mix = gr.Audio(type='numpy', label='Original', interactive=False)
                    ev_rendered = gr.Audio(type='numpy', label='Reconstruction', interactive=False)
                ev_summary = gr.Dataframe(label="Summary")
                ev_matches = gr.Dataframe(label="Matches")
                # Hidden sinks for run_eval's two trailing string outputs.
                ev_s1 = gr.Textbox(visible=False)
                ev_s2 = gr.Textbox(visible=False)
                ev_btn.click(run_eval, [ev_pat, ev_bpm, ev_bars, ev_ncc],
                             [ev_mix, ev_rendered, ev_summary, ev_matches, ev_s1, ev_s2])

            # ---- Optimize tab -----------------------------------------------
            with gr.Tab("π Optimize"):
                gr.Markdown("### Autonomous Optimization\nTests across 6 diverse songs, saves best config to Hub.")
                with gr.Row():
                    opt_n = gr.Slider(2, 30, value=5, step=1, label='Iterations')
                    opt_name = gr.Textbox(value="optimized", label='Config name')
                    opt_author = gr.Textbox(value="", label='Author')
                    opt_save = gr.Checkbox(value=True, label='Save to Hub')
                opt_btn = gr.Button("π Optimize", variant="primary", size="lg")
                opt_log = gr.Textbox(label="Log", lines=20, max_lines=40)
                opt_hist = gr.Dataframe(label="History")
                opt_plot = gr.Plot(label="Progress")
                opt_params = gr.Code(label="Best Config", language="json")
                opt_btn.click(run_optimize, [opt_n, opt_name, opt_author, opt_save],
                              [opt_log, opt_hist, opt_plot, opt_params])

            # ---- Leaderboard tab --------------------------------------------
            with gr.Tab("π Leaderboard"):
                gr.Markdown("### Config Leaderboard")
                lb_btn = gr.Button("π Refresh")
                lb_tbl = gr.Dataframe()
                # Hidden sink for refresh_leaderboard's error-text output.
                lb_s = gr.Textbox(visible=False)
                lb_btn.click(refresh_leaderboard, [], [lb_tbl, lb_s])

    return app
|
|
# Script entry point: build the UI and serve it on all interfaces, port 7860.
if __name__ == "__main__":
    demo = build_app()
    demo.launch(server_name="0.0.0.0", server_port=7860)
|
|