# Last commit: cadb06d (verified) by rikhoffbauer2 —
# "v8: Lock checkboxes for auto-tune — constrain any parameter during search"
"""
Gradio UI β€” Sample Extractor v8.
Auto-tune with parameter locking.
"""
import gradio as gr
import numpy as np, pandas as pd, json, sys, os, tempfile
import soundfile as sf, librosa
import matplotlib; matplotlib.use('Agg')
import matplotlib.pyplot as plt
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from sample_extractor import (
extract_stem, detect_onsets, classify_hits,
cluster_hits, select_best, synthesize_from_cluster,
sample_quality_score, export_midi, detect_bpm,
render_midi_with_samples, build_archive, cache_clear, auto_tune,
DEMUCS_MODELS, DEMUCS_STEMS,
)
from synth_generator import generate_test_song
from evaluation import evaluate_extraction
from config_store import PipelineConfig, get_leaderboard
from optimizer_v2 import run_optimization
def audio_tuple(a, sr):
    """Convert audio samples into the ``(sample_rate, samples)`` tuple ``gr.Audio`` accepts.

    The samples are copied to a new float32 array and peak-normalised to
    0.95 full scale. All-zero (silent) and empty arrays are returned
    unscaled instead of raising or dividing by zero.

    Args:
        a: Array-like of audio samples.
        sr: Sample rate in Hz, passed through unchanged.

    Returns:
        ``(sr, samples)`` where ``samples`` is a fresh float32 array.
    """
    # np.array (not asarray) always copies, so callers' buffers are never aliased.
    a = np.array(a, dtype=np.float32)
    # max() on a zero-size array raises ValueError, so skip normalisation here.
    if a.size == 0:
        return (sr, a)
    pk = np.abs(a).max()
    if pk > 0:
        a = a / pk * 0.95
    return (sr, a)
# ─── Auto-tune with locks ────────────────────────────────────────────────────
def run_auto_tune(audio_in, stem_choice, demucs_model, demucs_shifts, demucs_overlap,
                  onset_mode,
                  # Current values (used when locked)
                  cur_delta, cur_energy, cur_gap, cur_tmin, cur_tmax,
                  # Lock flags
                  lock_delta, lock_energy, lock_gap, lock_targets,
                  progress=gr.Progress()):
    """Auto-tune onset/clustering parameters for the uploaded audio.

    The chosen stem is separated with ``extract_stem``, then ``auto_tune``
    searches for parameters. A parameter whose lock checkbox is ticked is
    passed to ``auto_tune`` in ``locks`` with its current UI value, and its
    control is left untouched in the returned updates.

    Args:
        audio_in: ``(sample_rate, samples)`` tuple from ``gr.Audio``, or None.
        stem_choice, demucs_model, demucs_shifts, demucs_overlap: stem
            separation settings forwarded to ``extract_stem``.
        onset_mode: onset-detection mode forwarded to ``auto_tune``.
        cur_delta, cur_energy, cur_gap, cur_tmin, cur_tmax: current UI values;
            only read for parameters that are locked.
        lock_delta, lock_energy, lock_gap, lock_targets: lock checkbox states.
            ``lock_targets`` locks both cluster-count bounds together.
        progress: Gradio progress tracker.

    Returns:
        A 7-element list for ``[od, ed, mg, tmin, tmax, tune_summary, tune_log]``:
        five ``gr.update`` objects (no-ops for locked parameters), the
        Markdown summary, and the log text (revealed once there is content).
    """
    if audio_in is None:
        return [gr.update()] * 5 + ["Upload audio first", ""]

    # Map locked UI controls onto the parameter names auto_tune expects.
    locks = {}
    if lock_delta:
        locks['onset_delta'] = float(cur_delta)
    if lock_energy:
        locks['energy_threshold_db'] = float(cur_energy)
    if lock_gap:
        locks['min_gap'] = float(cur_gap)
    if lock_targets:
        locks['target_min'] = int(cur_tmin)
        locks['target_max'] = int(cur_tmax)

    progress(0.0, desc="Loading audio...")
    sr_in, data = audio_in
    data = data.astype(np.float32)
    if data.size == 0:
        # max() below would raise on an empty array.
        return [gr.update()] * 5 + ["Uploaded audio is empty", ""]
    if data.ndim > 1:
        # Downmix to mono; assumes (samples, channels) layout from gr.Audio — TODO confirm.
        data = data.mean(axis=1)
    pk = np.abs(data).max()
    if pk > 0:
        data = data / pk

    # Reserve the path first so the finally-block removes the file even when
    # writing it fails (the original leaked it in that case).
    fd, tmp = tempfile.mkstemp(suffix='.wav')
    os.close(fd)
    try:
        sf.write(tmp, data, sr_in)
        progress(0.05, desc=f"Extracting {stem_choice} stem...")
        stem_audio, stem_sr = extract_stem(tmp, stem=stem_choice, device="cpu",
                                           model_name=demucs_model, shifts=int(demucs_shifts),
                                           overlap=float(demucs_overlap))
        lock_desc = ', '.join(f'{k}={v}' for k, v in locks.items()) if locks else 'none'
        progress(0.15, desc=f"Auto-tuning (locked: {lock_desc})...")
        best_params, best_score, log_lines = auto_tune(
            stem_audio, stem_sr, mode=onset_mode, locks=locks)
        progress(1.0, desc=f"Score: {best_score:.1f}")
        # Only the tail of the search log is shown.
        log_text = '\n'.join(log_lines[-30:])
        lock_info = f"πŸ”’ Locked: {lock_desc}" if locks else "No locks β€” all params tuned freely"
        summary = (f"**Auto-tune complete!** Score: **{best_score:.1f}/100**\n\n"
                   f"{lock_info}\n\n"
                   f"Click **Extract Samples** to run with these settings.")
        # Locked controls get a no-op update so the user's value is kept.
        return [
            gr.update() if lock_delta else gr.update(value=best_params['onset_delta']),
            gr.update() if lock_energy else gr.update(value=best_params['energy_threshold_db']),
            gr.update() if lock_gap else gr.update(value=best_params['min_gap']),
            gr.update() if lock_targets else gr.update(value=best_params.get('target_min', 5)),
            gr.update() if lock_targets else gr.update(value=best_params.get('target_max', 20)),
            summary,
            # The log textbox is created with visible=False; reveal it now it has content.
            gr.update(value=log_text, visible=True),
        ]
    finally:
        os.unlink(tmp)
# ─── Extract ──────────────────────────────────────────────────────────────────
def run_extraction(audio_in, stem_choice, demucs_model, demucs_shifts, demucs_overlap,
                   onset_mode, onset_delta, energy_db, pre_pad, min_dur, max_dur, min_gap,
                   ncc_threshold, ncc_compare_ms, linkage, target_min, target_max,
                   do_synthesize, progress=gr.Progress()):
    """Run the full sample-extraction pipeline on the uploaded audio.

    Pipeline: stem separation, BPM detection, onset detection, hit
    classification, clustering, best-hit selection, optional synthesis,
    MIDI export, MIDI re-rendering and ZIP packaging.

    Args:
        audio_in: ``(sample_rate, samples)`` tuple from ``gr.Audio``, or None.
        stem_choice, demucs_model, demucs_shifts, demucs_overlap: stem
            separation settings forwarded to ``extract_stem``.
        onset_mode, onset_delta, energy_db, pre_pad, min_dur, max_dur, min_gap:
            onset-detection settings forwarded to ``detect_onsets``.
        ncc_threshold, ncc_compare_ms, linkage, target_min, target_max:
            clustering settings forwarded to ``cluster_hits``.
        do_synthesize: when true, clusters with at least 2 hits get a
            synthesized sample from ``synthesize_from_cluster``.
        progress: Gradio progress tracker.

    Returns:
        An 8-item sequence for ``[stem_out, summary_md, rend_out, smp, mid,
        arc, stx, met]``. Missing or empty input, or no detected hits, yields
        a partially-filled result.
    """
    if audio_in is None:
        return [None] * 8
    progress(0.0, desc="Loading...")
    sr_in, data = audio_in
    data = data.astype(np.float32)
    if data.size == 0:
        # max() below would raise on an empty array.
        return [None, "Uploaded audio is empty", None, None, None, None, "", None]
    if data.ndim > 1:
        # Downmix to mono; assumes (samples, channels) layout from gr.Audio — TODO confirm.
        data = data.mean(axis=1)
    pk = np.abs(data).max()
    if pk > 0:
        data = data / pk

    # Reserve the path first so the finally-block removes it even if writing fails.
    fd, tmp = tempfile.mkstemp(suffix='.wav')
    os.close(fd)
    try:
        sf.write(tmp, data, sr_in)
        progress(0.05, desc=f"Stem ({demucs_model})...")
        sa, ssr = extract_stem(tmp, stem=stem_choice, device="cpu",
                               model_name=demucs_model, shifts=int(demucs_shifts),
                               overlap=float(demucs_overlap))
        progress(0.15, desc="BPM...")
        bpm = detect_bpm(sa, ssr)
        progress(0.25, desc="Onsets...")
        hits = detect_onsets(sa, ssr, mode=onset_mode, onset_delta=float(onset_delta),
                             energy_threshold_db=float(energy_db), pre_pad=float(pre_pad),
                             min_dur=float(min_dur), max_dur=float(max_dur),
                             min_gap=float(min_gap))
        if not hits:
            return (audio_tuple(sa, ssr), f"**BPM: {bpm}** β€” No hits.",
                    None, None, None, None, "", pd.DataFrame())
        progress(0.35, desc="Classify...")
        hits = classify_hits(hits)
        progress(0.45, desc="Cluster...")
        cl = cluster_hits(hits, ncc_threshold=float(ncc_threshold),
                          max_compare_ms=float(ncc_compare_ms),
                          target_min=int(target_min), target_max=int(target_max),
                          linkage=str(linkage))
        progress(0.65, desc="Select...")
        select_best(cl)
        if do_synthesize:
            progress(0.7, desc="Synth...")
            for c in cl:
                # Synthesis needs at least two hits to combine.
                if c.count >= 2:
                    c.synthesized = synthesize_from_cluster(c)
        progress(0.75, desc="MIDI...")
        # mkstemp creates the file atomically; tempfile.mktemp is deprecated and
        # race-prone. The file must outlive this call because Gradio serves it.
        fd, mp = tempfile.mkstemp(suffix='.mid')
        os.close(fd)
        export_midi(cl, mp, bpm=bpm)
        progress(0.8, desc="Render...")
        rend = render_midi_with_samples(cl, sr=ssr)
        progress(0.85, desc="Package...")
        # Most frequent clusters first; reused for files, table and summary.
        ranked = sorted(cl, key=lambda x: x.count, reverse=True)
        sd = tempfile.mkdtemp()
        sp = []
        for c in ranked:
            p = os.path.join(sd, f"{c.label}.wav")
            c.best_hit.save(p)
            sp.append(p)
        zp = build_archive(cl, bpm, ssr, midi_path=mp, rendered_audio=rend)
        rows = []
        for c in ranked:
            b = c.best_hit
            # Label prefix before the last "_" is passed as the sample type.
            sc = sample_quality_score(b.audio, b.sr, c.label.rsplit('_', 1)[0])
            rows.append({'Sample': c.label, 'Hits': c.count, 'MIDI': c.midi_note,
                         'Score': f"{sc['total']:.1f}", 'Clean': f"{sc['cleanness']:.2f}",
                         'Complete': f"{sc['completeness']:.2f}",
                         'Dur': f"{b.duration*1000:.0f}ms",
                         'First': f"{min(h.onset_time for h in c.hits):.2f}s"})
        sm = f"**BPM: {bpm}** Β· **{len(cl)} samples** from {len(hits)} hits\n\n"
        sm += f"`{demucs_model}` Β· Ξ΄=`{onset_delta}` Β· E=`{energy_db}dB`"
        if int(target_min) > 0 and int(target_max) > 0:
            sm += f" Β· clusters `{int(target_min)}–{int(target_max)}`"
        sm += "\n\n| Sample | Hits | MIDI |\n|---|---|---|\n"
        sm += ''.join(f"| {c.label} | {c.count} | {c.midi_note} |\n" for c in ranked)
        progress(1.0)
        return (audio_tuple(sa, ssr), sm, audio_tuple(rend, ssr), sp, mp, zp, "",
                pd.DataFrame(rows))
    finally:
        os.unlink(tmp)
# ─── Evaluate ─────────────────────────────────────────────────────────────────
def run_eval(pattern, bpm, bars, ncc_threshold, target_min, target_max, progress=gr.Progress()):
    """Score the extraction pipeline on a synthetic drum song with known ground truth.

    A deterministic test song (``variation='medium'``, ``seed=42``) is
    generated. Its drum track goes through onset detection, classification,
    clustering, best-hit selection and synthesis. The result is compared
    with the song's true samples and hit list via ``evaluate_extraction``.

    Args:
        pattern: pattern name passed to ``generate_test_song``.
        bpm: tempo of the generated song.
        bars: number of bars to generate.
        ncc_threshold: clustering similarity threshold.
        target_min, target_max: desired cluster-count range (0 presumably
            means unconstrained — TODO confirm against ``cluster_hits``).
        progress: Gradio progress tracker.

    Returns:
        A 6-tuple for ``[evm, evr, evs, evm2, es1, es2]``: generated mix,
        reconstruction (None when no hits were detected), summary DataFrame,
        matches DataFrame (or None), and two empty status strings.
    """
    progress(0.0)
    song = generate_test_song(pattern_name=pattern, bars=int(bars), bpm=float(bpm),
                              variation='medium', seed=42)
    dbpm = detect_bpm(song.drums_only, song.sr)
    progress(0.2)
    hits = detect_onsets(song.drums_only, song.sr)
    if not hits:
        # Still let the user hear the mix and explain why nothing else was produced.
        summary = pd.DataFrame([
            {'Metric': 'BPM', 'Value': f"{dbpm}", 'Target': f"{song.bpm}"},
            {'Metric': 'Hits', 'Value': '0', 'Target': '> 0'},
        ])
        return audio_tuple(song.mix, song.sr), None, summary, None, "", ""
    hits = classify_hits(hits)
    cl = cluster_hits(hits, ncc_threshold=float(ncc_threshold),
                      target_min=int(target_min), target_max=int(target_max))
    select_best(cl)
    for c in cl:
        if c.count >= 2:
            c.synthesized = synthesize_from_cluster(c)
    progress(0.5)
    rend = render_midi_with_samples(cl, sr=song.sr)
    progress(0.6)
    # Ground truth: sample name -> audio, plus the list of true hits.
    gt = {name: smp.audio for name, smp in song.samples.items()}
    gh = [{'sample': h.sample_name, 'onset': h.onset_time, 'velocity': h.velocity}
          for h in song.hits]
    r = evaluate_extraction(cl, gt, gh, song.sr, hits)
    summary_rows = [
        {'Metric': 'BPM', 'Value': f"{dbpm}", 'Target': f"{song.bpm}"},
        {'Metric': 'Clusters', 'Value': str(len(cl)), 'Target': str(len(gt))},
        {'Metric': 'Score', 'Value': f"{r.overall_score:.1f}/100", 'Target': '> 70'},
    ]
    if r.unmatched_gt:
        summary_rows.append({'Metric': '⚠', 'Value': ', '.join(r.unmatched_gt), 'Target': 'None'})
    # Distinct loop variable: the original comprehension reused the list's own name.
    match_rows = [{'Cluster': mt.cluster_label, 'GT': mt.gt_name,
                   'Score': f"{mt.sample_score:.1f}"} for mt in r.matches]
    progress(1.0)
    return (audio_tuple(song.mix, song.sr), audio_tuple(rend, song.sr),
            pd.DataFrame(summary_rows), pd.DataFrame(match_rows) if match_rows else None, "", "")
def run_optimize(n_iters, config_name, author, save_hub, progress=gr.Progress()):
    """Run the synthetic parameter optimizer and report its history.

    Args:
        n_iters: number of optimization iterations.
        config_name: name for the resulting config ("opt" if blank).
        author: author tag ("anon" if blank).
        save_hub: whether ``run_optimization`` should save to the hub.
        progress: Gradio progress tracker.

    Returns:
        ``(log_text, history_df, figure, best_config_json)`` for
        ``[ol, oh, op, oc]``.
    """
    # Figures built with matplotlib.figure.Figure are not registered with
    # pyplot, so they are garbage-collected with the Gradio output instead of
    # piling up across calls in the long-running server.
    from matplotlib.figure import Figure

    logs = []
    progress(0.0)
    state = run_optimization(n_iterations=int(n_iters), config_name=config_name or "opt",
                             author=author or "anon", save_to_hub=bool(save_hub),
                             log_fn=lambda msg: logs.append(msg))
    progress(1.0)
    history = [{'Iter': r.iteration, 'Score': f"{r.avg_score:.1f}"} for r in state.history]
    if state.history:
        fig = Figure(figsize=(10, 4))
        ax = fig.subplots()
        ax.plot([r.iteration for r in state.history],
                [r.avg_score for r in state.history], 'b-o')
        ax.set_xlabel('Iteration')
        ax.set_ylabel('Average score')
        ax.grid(True, alpha=0.3)
        fig.tight_layout()
    else:
        fig = Figure()
        ax = fig.subplots()
        ax.text(0.5, 0.5, "No data", ha='center', va='center')
    return '\n'.join(logs), pd.DataFrame(history), fig, json.dumps(state.best_config, indent=2)
def refresh_lb():
    """Load the leaderboard for the Leaderboard tab.

    Returns:
        ``(table, error)``: the leaderboard as a DataFrame (empty when there
        are no entries) and ``""``; on any exception, an empty DataFrame and
        the exception's message.
    """
    try:
        entries = get_leaderboard()
        if not entries:
            return pd.DataFrame(), ""
        return pd.DataFrame(entries), ""
    except Exception as exc:  # UI boundary: report the failure instead of raising
        return pd.DataFrame(), str(exc)
# ─── App ──────────────────────────────────────────────────────────────────────
def build_app() -> gr.Blocks:
    """Build the Gradio Blocks app with Extract, Evaluate, Optimize and Leaderboard tabs.

    Component creation order inside each ``with`` context defines the page
    layout, so statements here must stay in sequence.

    Returns:
        The constructed ``gr.Blocks`` instance (not yet launched).
    """
    with gr.Blocks(title="🎡 Sample Extractor",theme=gr.themes.Soft(),
                   css=".gradio-container{max-width:1300px!important} .lock-row{align-items:center}") as app:
        gr.Markdown("# 🎡 Sample Extractor v8\n"
                    "**Auto-Tune** finds optimal parameters for your audio. "
                    "πŸ”’ **Lock** any parameter to constrain the search.")
        with gr.Tabs():
            # ── Extract tab: upload, parameters, auto-tune and extraction ──
            with gr.Tab("🎡 Extract"):
                audio_in = gr.Audio(sources=['upload'], type='numpy', label='Upload Audio')
                with gr.Accordion("πŸ”§ Stem Separation", open=False):
                    with gr.Row():
                        dm=gr.Dropdown(DEMUCS_MODELS,value="htdemucs_ft",label="Model")
                        st=gr.Dropdown(['drums','bass','other','vocals','all'],value='drums',label='Stem')
                        dsh=gr.Slider(0,5,value=1,step=1,label='Shifts')
                        dov=gr.Slider(0.0,0.5,value=0.25,step=0.05,label='Overlap')
                with gr.Accordion("🎯 Onset Detection", open=False):
                    with gr.Row():
                        om=gr.Dropdown(['auto','percussive','harmonic','broadband'],value='auto',label='Mode')
                    # Each "lock-row" pairs a tunable control with its lock checkbox;
                    # when ticked, run_auto_tune keeps that control's current value.
                    with gr.Row(elem_classes="lock-row"):
                        od=gr.Slider(0.01,0.5,value=0.12,step=0.01,label='Delta')
                        lock_od=gr.Checkbox(value=False,label='πŸ”’',scale=0)
                    with gr.Row(elem_classes="lock-row"):
                        ed=gr.Slider(-70,-10,value=-35,step=1,label='Energy (dB)')
                        lock_ed=gr.Checkbox(value=False,label='πŸ”’',scale=0)
                    with gr.Row(elem_classes="lock-row"):
                        mg=gr.Slider(0.005,0.2,value=0.03,step=0.005,label='Min gap (s)')
                        lock_mg=gr.Checkbox(value=False,label='πŸ”’',scale=0)
                    # Not auto-tuned: these are only used by run_extraction.
                    with gr.Row():
                        pp=gr.Slider(0.0,0.05,value=0.005,step=0.001,label='Pre-pad (s)')
                        mnd=gr.Slider(0.005,0.2,value=0.02,step=0.005,label='Min dur (s)')
                        mxd=gr.Slider(0.1,5.0,value=1.5,step=0.1,label='Max dur (s)')
                with gr.Accordion("πŸ”— Clustering", open=True):
                    # One checkbox locks both cluster-count bounds; locked by default.
                    with gr.Row(elem_classes="lock-row"):
                        tmin=gr.Number(value=5,label='Target min clusters',precision=0)
                        tmax=gr.Number(value=20,label='Target max clusters',precision=0)
                        lock_tgt=gr.Checkbox(value=True,label='πŸ”’ Lock range',scale=0)
                    gr.Markdown("*πŸ”’ = auto-tune will respect this value. Unchecked = auto-tune will change it.*")
                    with gr.Row():
                        nt=gr.Slider(0.3,0.99,value=0.80,step=0.01,label='NCC threshold')
                        nms=gr.Slider(0,1000,value=0,step=50,label='Compare ms (0=auto)')
                        lnk=gr.Dropdown(['average','complete','single'],value='average',label='Linkage')
                with gr.Accordion("βš™οΈ Post-processing", open=False):
                    syn=gr.Checkbox(value=True,label='Synthesize optimal samples')
                with gr.Row():
                    tune_btn=gr.Button("πŸŽ›οΈ Auto-Tune",variant="secondary",size="lg")
                    extract_btn=gr.Button("πŸ”¬ Extract Samples",variant="primary",size="lg")
                tune_summary=gr.Markdown("")
                # Created hidden; NOTE(review): only becomes visible if the auto-tune
                # handler returns an update with visible=True.
                tune_log=gr.Textbox(label="Auto-tune log",lines=8,max_lines=15,visible=False)
                summary_md=gr.Markdown("*Upload audio β†’ Auto-Tune or Extract*")
                with gr.Row():
                    stem_out=gr.Audio(type='numpy',label='Stem',interactive=False)
                    rend_out=gr.Audio(type='numpy',label='πŸ”Š Reconstruction',interactive=False)
                gr.Markdown("### Downloads")
                with gr.Row():
                    arc=gr.File(label="πŸ“¦ ZIP",interactive=False)
                    mid=gr.File(label="🎹 MIDI",interactive=False)
                    smp=gr.File(label="WAV samples",file_count="multiple",interactive=False)
                met=gr.Dataframe(label="Samples")
                # Hidden sink for the status string run_extraction returns.
                stx=gr.Textbox(visible=False)
                # Refresh stem choices for the selected Demucs model; unknown models
                # fall back to the four standard stems, and "all" is always appended.
                dm.change(fn=lambda m:gr.update(choices=DEMUCS_STEMS.get(m,["drums","bass","other","vocals"])+["all"]),
                          inputs=[dm],outputs=[st])
                # Input order must match run_auto_tune's positional parameters.
                tune_btn.click(run_auto_tune,
                               [audio_in, st, dm, dsh, dov, om,
                                od, ed, mg, tmin, tmax,  # current values
                                lock_od, lock_ed, lock_mg, lock_tgt],  # lock flags
                               [od, ed, mg, tmin, tmax, tune_summary, tune_log])
                # Input order must match run_extraction's positional parameters.
                extract_btn.click(run_extraction,
                                  [audio_in,st,dm,dsh,dov,om,od,ed,pp,mnd,mxd,mg,nt,nms,lnk,tmin,tmax,syn],
                                  [stem_out,summary_md,rend_out,smp,mid,arc,stx,met])
            # ── Evaluate tab: score the pipeline on a generated test song ──
            with gr.Tab("πŸ“Š Evaluate"):
                gr.Markdown("Synthetic evaluation.")
                with gr.Row():
                    ep=gr.Dropdown(['rock','funk','halftime'],value='rock',label='Pattern')
                    eb=gr.Slider(80,200,value=120,step=2,label='BPM')
                    ebs=gr.Slider(2,8,value=4,step=1,label='Bars')
                with gr.Row():
                    en=gr.Slider(0.3,0.99,value=0.80,step=0.01,label='NCC')
                    etm=gr.Number(value=0,label='Min',precision=0)
                    etx=gr.Number(value=0,label='Max',precision=0)
                evb=gr.Button("πŸ§ͺ Evaluate",variant="primary",size="lg")
                with gr.Row():
                    evm=gr.Audio(type='numpy',label='Original',interactive=False)
                    evr=gr.Audio(type='numpy',label='Reconstruction',interactive=False)
                evs=gr.Dataframe(label="Summary"); evm2=gr.Dataframe(label="Matches")
                # Hidden sinks for the two trailing strings run_eval returns.
                es1=gr.Textbox(visible=False); es2=gr.Textbox(visible=False)
                evb.click(run_eval,[ep,eb,ebs,en,etm,etx],[evm,evr,evs,evm2,es1,es2])
            # ── Optimize tab: iterative optimizer with log, history and plot ──
            with gr.Tab("πŸ”„ Optimize"):
                gr.Markdown("### Synthetic optimization")
                with gr.Row():
                    on=gr.Slider(2,30,value=5,step=1,label='Iters')
                    ocn=gr.Textbox(value="opt",label='Name')
                    oa=gr.Textbox(value="",label='Author')
                    osv=gr.Checkbox(value=True,label='Save')
                ob=gr.Button("πŸš€ Run",variant="primary",size="lg")
                ol=gr.Textbox(label="Log",lines=20,max_lines=40)
                oh=gr.Dataframe(label="History"); op=gr.Plot()
                oc=gr.Code(label="Config",language="json")
                ob.click(run_optimize,[on,ocn,oa,osv],[ol,oh,op,oc])
            # ── Leaderboard tab ──
            with gr.Tab("πŸ† Leaderboard"):
                # NOTE(review): `ls` receives refresh_lb's error text but is hidden,
                # so leaderboard load failures are not shown to the user.
                lbb=gr.Button("πŸ”„ Refresh"); lt=gr.Dataframe(); ls=gr.Textbox(visible=False)
                lbb.click(refresh_lb,[],[lt,ls])
    return app
if __name__ == "__main__":
build_app().launch(server_name="0.0.0.0", server_port=7860)