"""
Gradio UI — Sample Extractor v4.

NCC clustering, full parameter control, Demucs model selection.
"""
import gradio as gr
import numpy as np
import pandas as pd
import json
import sys
import os
import tempfile
import soundfile as sf
import librosa
import matplotlib
matplotlib.use('Agg')  # headless backend; must be selected before pyplot is imported
import matplotlib.pyplot as plt

# Make the sibling project modules importable regardless of the working directory.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

from sample_extractor import (
    extract_stem, detect_onsets, classify_hits, cluster_hits, select_best,
    synthesize_from_cluster, sample_quality_score, export_midi, detect_bpm,
    render_midi_with_samples, build_archive, cache_clear, DEMUCS_MODELS, DEMUCS_STEMS,
)
from synth_generator import generate_test_song
from evaluation import evaluate_extraction
from config_store import PipelineConfig, get_leaderboard
from optimizer_v2 import run_optimization

# Stem list assumed when DEMUCS_STEMS has no (or an empty) entry for a model.
_DEFAULT_STEMS = ["drums", "bass", "other", "vocals"]


def audio_tuple(a, sr):
    """Convert *a* to float32, peak-normalise it to 0.95, and return ``(sr, a)``.

    The ``(sample_rate, samples)`` tuple is the form the ``gr.Audio(type='numpy')``
    outputs in this app receive. Silent or empty input is returned un-scaled.
    """
    a = np.asarray(a).astype(np.float32)
    pk = np.abs(a).max() if a.size else 0.0  # .max() raises on an empty array
    if pk > 0:
        a = a / pk * 0.95
    return (sr, a)


def _temp_path(suffix):
    """Atomically create an empty temp file with *suffix* and return its path.

    Replaces ``tempfile.mktemp`` (deprecated: the name can be claimed by another
    process between generation and use). The caller owns the file.
    """
    fd, path = tempfile.mkstemp(suffix=suffix)
    os.close(fd)  # only the path is needed; writers reopen it themselves
    return path


# ─── Tab 1: Extract ──────────────────────────────────────────────────────────

def _cluster_metrics_rows(ranked):
    """Build one row of the "Extracted Samples" table per cluster, in *ranked* order."""
    rows = []
    for c in ranked:
        best = c.best_hit
        # The label text before the last '_' is passed as the sample category
        # (presumably the hit class — confirm against sample_extractor).
        sc = sample_quality_score(best.audio, best.sr, c.label.rsplit('_', 1)[0])
        rows.append({
            'Sample': c.label,
            'Hits': c.count,
            'MIDI': c.midi_note,
            'Score': f"{sc['total']:.1f}",
            'Clean': f"{sc['cleanness']:.2f}",
            'Complete': f"{sc['completeness']:.2f}",
            'Dur (ms)': f"{best.duration * 1000:.0f}",
            'First @': f"{min(h.onset_time for h in c.hits):.2f}s",
        })
    return rows


def _extraction_summary(ranked, hits, bpm, demucs_model, ncc_threshold, onset_delta):
    """Render the Markdown summary shown on the Extract tab after a successful run."""
    parts = [
        f"**Detected BPM: {bpm}** · **{len(ranked)} unique samples** from {len(hits)} hits\n\n",
        f"Model: `{demucs_model}` · NCC threshold: `{ncc_threshold}` · Onset delta: `{onset_delta}`\n\n",
        "| Sample | Hits | MIDI Note |\n|---|---|---|\n",
    ]
    parts.extend(f"| {c.label} | {c.count} | {c.midi_note} |\n" for c in ranked)
    return "".join(parts)


def run_extraction(audio_in, stem_choice, demucs_model, demucs_shifts, demucs_overlap,
                   onset_mode, onset_delta, energy_db, pre_pad, min_dur, max_dur, min_gap,
                   ncc_threshold, ncc_compare_ms, linkage, target_min, target_max,
                   do_synthesize, progress=gr.Progress()):
    """Run the full extraction pipeline on an uploaded clip.

    Pipeline: stem separation → BPM → onset detection → classification →
    NCC clustering → best-hit selection → optional synthesis → MIDI export →
    reconstruction render → per-sample WAVs and a ZIP archive.

    The parameters mirror the widgets of the "Extract" tab in ``build_app``;
    ``audio_in`` is the ``(sample_rate, samples)`` tuple from ``gr.Audio``.

    Returns:
        An 8-item sequence matching the Extract tab outputs: stem audio,
        summary Markdown, reconstruction audio, sample WAV paths, MIDI path,
        ZIP path, status text, metrics DataFrame. All ``None`` when no audio
        was uploaded; only stem + message + empty table when no hits are found.
    """
    if audio_in is None:
        return [None] * 8

    progress(0.0, desc="Loading audio...")
    sr_in, data = audio_in
    data = data.astype(np.float32)
    if data.ndim > 1:
        data = data.mean(axis=1)  # downmix multi-channel input to mono
    pk = np.abs(data).max() if data.size else 0.0
    if pk > 0:
        data = data / pk

    # extract_stem takes a file path, so stage the clip in a temp WAV. The
    # handle is closed before writing (portable to Windows) and the write is
    # inside the try so the file is removed even if it fails.
    tmp = _temp_path('.wav')
    try:
        sf.write(tmp, data, sr_in)

        progress(0.05, desc=f"Extracting {stem_choice} stem ({demucs_model})...")
        stem_audio, stem_sr = extract_stem(
            tmp, stem=stem_choice, device="cpu", model_name=demucs_model,
            shifts=int(demucs_shifts), overlap=float(demucs_overlap))

        progress(0.15, desc="Detecting BPM...")
        bpm = detect_bpm(stem_audio, stem_sr)

        progress(0.25, desc="Detecting onsets...")
        hits = detect_onsets(
            stem_audio, stem_sr, mode=onset_mode, onset_delta=float(onset_delta),
            energy_threshold_db=float(energy_db), pre_pad=float(pre_pad),
            min_dur=float(min_dur), max_dur=float(max_dur), min_gap=float(min_gap))
        if not hits:
            return (audio_tuple(stem_audio, stem_sr),
                    f"**Detected BPM: {bpm}**\n\nNo hits found — try lowering energy threshold.",
                    None, None, None, None, "", pd.DataFrame())

        progress(0.35, desc="Classifying...")
        hits = classify_hits(hits)

        progress(0.45, desc="NCC clustering...")
        clusters = cluster_hits(
            hits, ncc_threshold=float(ncc_threshold), max_compare_ms=float(ncc_compare_ms),
            # gr.Number yields None when the field is cleared; treat that as 0 ("use threshold").
            target_min=int(target_min or 0), target_max=int(target_max or 0),
            linkage=str(linkage))

        progress(0.65, desc="Scoring & selecting best...")
        select_best(clusters)

        if do_synthesize:
            progress(0.7, desc="Synthesizing...")
            for c in clusters:
                if c.count >= 2:  # synthesis needs at least two hits to combine
                    c.synthesized = synthesize_from_cluster(c)

        progress(0.75, desc="Building MIDI...")
        midi_path = _temp_path('.mid')
        export_midi(clusters, midi_path, bpm=bpm)

        progress(0.8, desc="Rendering reconstruction...")
        rendered = render_midi_with_samples(clusters, sr=stem_sr)

        progress(0.85, desc="Packaging...")
        # Most frequent sounds first; reused for the file list, table and summary.
        ranked = sorted(clusters, key=lambda x: x.count, reverse=True)
        sample_dir = tempfile.mkdtemp()
        sample_paths = []
        for c in ranked:
            sp = os.path.join(sample_dir, f"{c.label}.wav")
            c.best_hit.save(sp)
            sample_paths.append(sp)
        zip_path = build_archive(clusters, bpm, stem_sr, midi_path=midi_path,
                                 rendered_audio=rendered)

        rows = _cluster_metrics_rows(ranked)
        summary = _extraction_summary(ranked, hits, bpm, demucs_model,
                                      ncc_threshold, onset_delta)

        progress(1.0, desc="Done!")
        return (audio_tuple(stem_audio, stem_sr), summary, audio_tuple(rendered, stem_sr),
                sample_paths, midi_path, zip_path, "", pd.DataFrame(rows))
    finally:
        os.unlink(tmp)


# ─── Tab 2: Evaluate ─────────────────────────────────────────────────────────

def run_eval(pattern, bpm, bars, ncc_threshold, progress=gr.Progress()):
    """Generate a synthetic song, extract samples from it, and score against ground truth.

    Returns a 6-tuple for the Evaluate tab: original mix audio, reconstruction
    audio, summary DataFrame, matches DataFrame (or ``None``), and two empty
    status strings. The first four are ``None`` when no onsets are detected.
    """
    progress(0.0, desc="Generating synthetic song...")
    # Fixed seed so repeated evaluations of the same settings are comparable.
    song = generate_test_song(pattern_name=pattern, bars=int(bars), bpm=float(bpm),
                              variation='medium', seed=42)
    detected_bpm = detect_bpm(song.drums_only, song.sr)

    progress(0.2, desc="Extracting...")
    hits = detect_onsets(song.drums_only, song.sr)
    if not hits:
        return None, None, None, None, "", ""
    hits = classify_hits(hits)
    clusters = cluster_hits(hits, ncc_threshold=float(ncc_threshold))
    select_best(clusters)
    for c in clusters:
        if c.count >= 2:
            c.synthesized = synthesize_from_cluster(c)

    progress(0.5, desc="Rendering...")
    rendered = render_midi_with_samples(clusters, sr=song.sr)

    progress(0.6, desc="Evaluating...")
    gt = {n: s.audio for n, s in song.samples.items()}
    gt_hits = [{'sample': h.sample_name, 'onset': h.onset_time, 'velocity': h.velocity}
               for h in song.hits]
    report = evaluate_extraction(clusters, gt, gt_hits, song.sr, hits)

    mix_out = audio_tuple(song.mix, song.sr)
    rendered_out = audio_tuple(rendered, song.sr)
    summary = [
        {'Metric': 'Detected BPM', 'Value': f"{detected_bpm}", 'Target': f"{song.bpm}"},
        {'Metric': 'Clusters', 'Value': str(len(clusters)), 'Target': str(len(gt))},
        {'Metric': 'Overall', 'Value': f"{report.overall_score:.1f}/100", 'Target': '> 70'},
        {'Metric': 'SI-SDR', 'Value': f"{report.mean_si_sdr:.1f} dB", 'Target': '> 10'},
        {'Metric': 'Env Corr', 'Value': f"{report.mean_env_corr:.3f}", 'Target': '> 0.9'},
    ]
    if report.unmatched_gt:
        summary.append({'Metric': '⚠ Unmatched', 'Value': ', '.join(report.unmatched_gt),
                        'Target': 'None'})
    matches = [{'Cluster': m.cluster_label, 'GT': m.gt_name,
                'SI-SDR': f"{m.si_sdr:.1f}", 'Score': f"{m.sample_score:.1f}"}
               for m in report.matches]
    progress(1.0)
    return (mix_out, rendered_out, pd.DataFrame(summary),
            pd.DataFrame(matches) if matches else None, "", "")


# ─── Tab 3: Optimize ─────────────────────────────────────────────────────────

def run_optimize(n_iters, config_name, author, save_hub, progress=gr.Progress()):
    """Run the autonomous optimizer and return (log text, history table, plot, best config JSON)."""
    logs = []
    progress(0.0, desc="Starting optimization...")
    state = run_optimization(n_iterations=int(n_iters),
                             config_name=config_name or "optimized",
                             author=author or "anonymous",
                             save_to_hub=bool(save_hub),
                             log_fn=lambda m: logs.append(m))
    progress(1.0)

    hist = [{'Iter': r.iteration, 'Score': f"{r.avg_score:.1f}", 'Time': f"{r.duration_s:.1f}s"}
            for r in state.history]
    if state.history:
        fig, ax = plt.subplots(figsize=(10, 4))
        ax.plot([r.iteration for r in state.history],
                [r.avg_score for r in state.history], 'b-o')
        ax.set_xlabel('Iteration')
        ax.set_ylabel('Score')
        ax.grid(True, alpha=0.3)
        # Operate on this figure explicitly rather than pyplot's global
        # "current figure", which concurrent requests could change.
        fig.tight_layout()
    else:
        fig, ax = plt.subplots()
        ax.text(0.5, 0.5, "No data")
    # Drop the figure from pyplot's global registry so repeated runs don't
    # accumulate open figures; the Figure object stays renderable by gr.Plot.
    plt.close(fig)
    return '\n'.join(logs), pd.DataFrame(hist), fig, json.dumps(state.best_config, indent=2)


# ─── Tab 4: Leaderboard ──────────────────────────────────────────────────────

def refresh_leaderboard():
    """Fetch the config leaderboard as a DataFrame plus an error message ('' on success).

    Any failure is caught here (UI boundary) and reported via the second value
    instead of propagating into Gradio.
    """
    try:
        lb = get_leaderboard()
        return pd.DataFrame(lb) if lb else pd.DataFrame(), ""
    except Exception as e:
        return pd.DataFrame(), str(e)


# ─── Build App ───────────────────────────────────────────────────────────────

def get_stems_for_model(model_name, current=None):
    """Return a Dropdown update listing the stems *model_name* can produce, plus "all".

    Args:
        model_name: Demucs model key looked up in ``DEMUCS_STEMS``; unknown
            models (or models with an empty entry) fall back to the four
            standard stems.
        current: Optional current selection. It is kept when it is still a
            valid choice; otherwise the first stem is selected, so the
            dropdown never holds a value outside its choices.
    """
    stems = list(DEMUCS_STEMS.get(model_name) or _DEFAULT_STEMS)
    choices = stems + ["all"]
    value = current if current in choices else stems[0]
    return gr.update(choices=choices, value=value)


def build_app():
    """Construct the four-tab Gradio Blocks app (Extract, Evaluate, Optimize, Leaderboard)."""
    with gr.Blocks(title="🎵 Sample Extractor", theme=gr.themes.Soft(),
                   css=".gradio-container{max-width:1300px!important}") as app:
        gr.Markdown("# 🎵 Sample Extractor v4\n"
                    "Extract distinct sounds from audio using **NCC waveform matching** — "
                    "correctly groups identical samples regardless of velocity.\n"
                    "Full control over Demucs model, onset detection, and clustering parameters.")
        with gr.Tabs():
            # ── Extract ──
            with gr.Tab("🎵 Extract"):
                audio_in = gr.Audio(sources=['upload'], type='numpy', label='Upload Audio')
                with gr.Accordion("🔧 Stem Separation", open=False):
                    with gr.Row():
                        demucs_model = gr.Dropdown(DEMUCS_MODELS, value="htdemucs_ft",
                                                   label="Demucs Model")
                        stem_dd = gr.Dropdown(['drums', 'bass', 'other', 'vocals', 'all'],
                                              value='drums', label='Stem')
                        demucs_shifts = gr.Slider(0, 5, value=1, step=1,
                                                  label='Shifts (TTA, 0=fastest)')
                        demucs_overlap = gr.Slider(0.0, 0.5, value=0.25, step=0.05,
                                                   label='Overlap')
                with gr.Accordion("🎯 Onset Detection", open=False):
                    with gr.Row():
                        onset_mode = gr.Dropdown(['auto', 'percussive', 'harmonic', 'broadband'],
                                                 value='auto', label='Mode')
                        onset_delta = gr.Slider(0.01, 0.5, value=0.07, step=0.01,
                                                label='Delta (sensitivity)')
                        energy_db = gr.Slider(-70, -10, value=-45, step=1,
                                              label='Energy threshold (dB)')
                    with gr.Row():
                        pre_pad = gr.Slider(0.0, 0.05, value=0.005, step=0.001,
                                            label='Pre-pad (s)')
                        min_dur = gr.Slider(0.005, 0.2, value=0.02, step=0.005,
                                            label='Min duration (s)')
                        max_dur = gr.Slider(0.1, 5.0, value=1.5, step=0.1,
                                            label='Max duration (s)')
                        min_gap = gr.Slider(0.005, 0.2, value=0.015, step=0.005,
                                            label='Min gap (s)')
                with gr.Accordion("🔗 Clustering", open=False):
                    with gr.Row():
                        ncc_thresh = gr.Slider(0.3, 0.99, value=0.80, step=0.01,
                                               label='NCC threshold (higher = stricter)')
                        ncc_ms = gr.Slider(50, 1000, value=200, step=50,
                                           label='Compare window (ms)')
                        linkage_dd = gr.Dropdown(['average', 'complete', 'single'],
                                                 value='average', label='Linkage')
                    with gr.Row():
                        target_min = gr.Number(value=0, label='Target min clusters (0 = use threshold)',
                                               precision=0)
                        target_max = gr.Number(value=0, label='Target max clusters (0 = use threshold)',
                                               precision=0)
                    gr.Markdown("*Set both target min/max > 0 to auto-search for the right threshold. "
                                "Leave at 0 to use the NCC threshold directly.*")
                with gr.Accordion("⚙️ Post-processing", open=False):
                    do_synth = gr.Checkbox(value=True,
                                           label='Synthesize optimal samples from clusters')
                extract_btn = gr.Button("🔬 Extract Samples", variant="primary", size="lg")
                summary_md = gr.Markdown("*Upload audio and click Extract*")
                with gr.Row():
                    stem_out = gr.Audio(type='numpy', label='Stem', interactive=False)
                    rendered_out = gr.Audio(type='numpy', label='🔊 Reconstruction',
                                            interactive=False)
                gr.Markdown("### Downloads")
                with gr.Row():
                    archive_file = gr.File(label="📦 ZIP Archive", interactive=False)
                    midi_file = gr.File(label="🎹 MIDI", interactive=False)
                sample_files = gr.File(label="Individual WAV samples", file_count="multiple",
                                       interactive=False)
                metrics_tbl = gr.Dataframe(label="Extracted Samples")
                status_txt = gr.Textbox(visible=False)

                # Update available stems when the model changes, keeping the
                # current stem if the new model still offers it.
                demucs_model.change(fn=get_stems_for_model,
                                    inputs=[demucs_model, stem_dd], outputs=[stem_dd])
                extract_btn.click(
                    run_extraction,
                    [audio_in, stem_dd, demucs_model, demucs_shifts, demucs_overlap,
                     onset_mode, onset_delta, energy_db, pre_pad, min_dur, max_dur, min_gap,
                     ncc_thresh, ncc_ms, linkage_dd, target_min, target_max, do_synth],
                    [stem_out, summary_md, rendered_out, sample_files, midi_file,
                     archive_file, status_txt, metrics_tbl])

            # ── Evaluate ──
            with gr.Tab("📊 Evaluate"):
                gr.Markdown("Generate synthetic song → extract → compare to ground truth.")
                with gr.Row():
                    ev_pat = gr.Dropdown(['rock', 'funk', 'halftime'], value='rock',
                                         label='Pattern')
                    ev_bpm = gr.Slider(80, 200, value=120, step=2, label='BPM')
                    ev_bars = gr.Slider(2, 8, value=4, step=1, label='Bars')
                    ev_ncc = gr.Slider(0.5, 0.99, value=0.80, step=0.01, label='NCC threshold')
                ev_btn = gr.Button("🧪 Evaluate", variant="primary", size="lg")
                with gr.Row():
                    ev_mix = gr.Audio(type='numpy', label='Original', interactive=False)
                    ev_rendered = gr.Audio(type='numpy', label='Reconstruction',
                                           interactive=False)
                ev_summary = gr.Dataframe(label="Summary")
                ev_matches = gr.Dataframe(label="Matches")
                ev_s1 = gr.Textbox(visible=False)
                ev_s2 = gr.Textbox(visible=False)
                ev_btn.click(run_eval, [ev_pat, ev_bpm, ev_bars, ev_ncc],
                             [ev_mix, ev_rendered, ev_summary, ev_matches, ev_s1, ev_s2])

            # ── Optimize ──
            with gr.Tab("🔄 Optimize"):
                gr.Markdown("### Autonomous Optimization\n"
                            "Tests across 6 diverse songs, saves best config to Hub.")
                with gr.Row():
                    opt_n = gr.Slider(2, 30, value=5, step=1, label='Iterations')
                    opt_name = gr.Textbox(value="optimized", label='Config name')
                    opt_author = gr.Textbox(value="", label='Author')
                    opt_save = gr.Checkbox(value=True, label='Save to Hub')
                opt_btn = gr.Button("🚀 Optimize", variant="primary", size="lg")
                opt_log = gr.Textbox(label="Log", lines=20, max_lines=40)
                opt_hist = gr.Dataframe(label="History")
                opt_plot = gr.Plot(label="Progress")
                opt_params = gr.Code(label="Best Config", language="json")
                opt_btn.click(run_optimize, [opt_n, opt_name, opt_author, opt_save],
                              [opt_log, opt_hist, opt_plot, opt_params])

            # ── Leaderboard ──
            with gr.Tab("🏆 Leaderboard"):
                gr.Markdown("### Config Leaderboard")
                lb_btn = gr.Button("🔄 Refresh")
                lb_tbl = gr.Dataframe()
                lb_s = gr.Textbox(visible=False)
                lb_btn.click(refresh_leaderboard, [], [lb_tbl, lb_s])
    return app


if __name__ == "__main__":
    build_app().launch(server_name="0.0.0.0", server_port=7860)