rikhoffbauer2 committed on
Commit
39d4239
·
verified ·
1 Parent(s): 63565aa

v7: Add auto-tuner — self-supervised param optimization using reconstruction quality

Browse files
Files changed (1) hide show
  1. sample_extractor.py +342 -254
sample_extractor.py CHANGED
@@ -1,12 +1,11 @@
1
  #!/usr/bin/env python3
2
  """
3
- Sample Extractor v6 — Tested on real hardstyle tracks.
4
 
5
- Fixes from v5:
6
- - NCC compare window auto-scales to median hit length (no more zero-pad inflation)
7
- - Target range uses n_clusters directly when binary search hits a cliff
8
- - Better defaults for real music (delta=0.12, energy=-35, min_gap=0.03)
9
- - Caching for stem separation, BPM, NCC distance matrix
10
  """
11
 
12
  import argparse, json, os, sys, warnings, hashlib
@@ -64,88 +63,74 @@ def cache_clear(): _cache.clear()
64
 
65
  # ─── Stage 1: Stem separation ────────────────────────────────────────────────
66
 
67
- def extract_stem(audio_path: str, stem: str = "drums", device: str = "cpu",
68
- model_name: str = "htdemucs_ft", shifts: int = 1,
69
- overlap: float = 0.25) -> tuple:
70
  if stem == "all":
71
  y, sr = librosa.load(audio_path, sr=44100, mono=True)
72
  return y.astype(np.float32), sr
73
-
74
  with open(audio_path, 'rb') as f:
75
- file_hash = hashlib.md5(f.read(200000)).hexdigest()
76
- ck = ("stem", file_hash, stem, model_name, shifts, overlap)
77
- cached = cache_get(ck)
78
- if cached is not None:
79
- print(f"[Stage 1] Using cached {stem} stem"); return cached
80
-
81
  from demucs.pretrained import get_model
82
  from demucs.apply import apply_model
83
  print(f"[Stage 1] Extracting '{stem}' with {model_name}...")
84
- model = get_model(model_name); model.eval().to(device)
85
- sr = model.samplerate
86
- if stem not in model.sources:
87
- raise ValueError(f"'{stem}' not in {model.sources}")
88
- audio_np, _ = librosa.load(audio_path, sr=sr, mono=False)
89
- if audio_np.ndim == 1: audio_np = np.stack([audio_np, audio_np])
90
- elif audio_np.shape[0] > 2: audio_np = audio_np[:2]
91
- elif audio_np.shape[0] == 1: audio_np = np.concatenate([audio_np, audio_np], axis=0)
92
- wav = torch.from_numpy(audio_np).float().unsqueeze(0).to(device)
93
  with torch.no_grad():
94
- sources = apply_model(model, wav, device=device, shifts=shifts, split=True, overlap=overlap)
95
- result = sources[0, model.sources.index(stem)].mean(dim=0).cpu().numpy()
96
- print(f" ✓ {stem}: {len(result)/sr:.1f}s")
97
- return cache_set(ck, (result.astype(np.float32), sr))
98
 
99
 
100
  # ─── Stage 2: Onset detection ────────────────────────────────────────────────
101
 
102
- def detect_onsets(y: np.ndarray, sr: int, pre_pad: float = 0.005,
103
- min_dur: float = 0.02, max_dur: float = 1.5,
104
- min_gap: float = 0.03, energy_threshold_db: float = -35.0,
105
- mode: str = "auto", backtrack: bool = True,
106
- onset_delta: float = 0.12) -> list:
107
- print(f"[Stage 2] Detecting onsets (mode={mode}, delta={onset_delta}, energy≥{energy_threshold_db}dB)...")
108
  if mode == "percussive":
109
- onset_env = librosa.onset.onset_strength(y=y, sr=sr, aggregate=np.median, fmax=8000)
110
  elif mode == "harmonic":
111
- y_harm, _ = librosa.effects.hpss(y)
112
- onset_env = librosa.onset.onset_strength(y=y_harm, sr=sr, fmax=8000, lag=2, max_size=3)
113
  elif mode == "broadband":
114
- onset_env = librosa.onset.onset_strength(y=y, sr=sr)
115
  else:
116
- y_harm, y_perc = librosa.effects.hpss(y)
117
- envs = [
118
- librosa.onset.onset_strength(y=y, sr=sr, fmin=20, fmax=250, aggregate=np.median),
119
- librosa.onset.onset_strength(y=y, sr=sr, fmin=250, fmax=4000, aggregate=np.median),
120
- librosa.onset.onset_strength(y=y, sr=sr, fmin=4000, fmax=min(sr//2,20000), aggregate=np.median),
121
- librosa.onset.onset_strength(y=y_harm, sr=sr, lag=2),
122
- ]
123
  def _n(x): m=x.max(); return x/m if m>0 else x
124
- onset_env = np.maximum.reduce([_n(e) for e in envs])
125
-
126
- wait = max(1, int(min_gap * sr / 512))
127
- frames = librosa.onset.onset_detect(onset_envelope=onset_env, sr=sr, wait=wait,
128
  pre_avg=3, post_avg=3, pre_max=3, post_max=5,
129
  delta=onset_delta, backtrack=backtrack, units='frames')
130
- times = librosa.frames_to_time(frames, sr=sr)
131
- print(f" Raw onsets: {len(times)}")
132
-
133
- threshold = 10 ** (energy_threshold_db / 20)
134
- hits = []
135
  for i, t in enumerate(times):
136
- s = max(0, int((t - pre_pad) * sr))
137
  e = min(int(times[i+1]*sr) if i+1<len(times) else len(y), s+int(max_dur*sr))
138
  seg = y[s:e]
139
- if len(seg) < int(min_dur*sr): continue
140
  rms = np.sqrt(np.mean(seg**2))
141
- if rms < threshold: continue
142
  fl = min(int(0.005*sr), len(seg)//4)
143
- if fl > 0: seg = seg.copy(); seg[-fl:] *= np.linspace(1, 0, fl)
144
- sc = float(librosa.feature.spectral_centroid(y=seg, sr=sr).mean())
145
- hits.append(Hit(audio=seg, sr=sr, onset_time=t, duration=len(seg)/sr,
146
- index=len(hits), rms_energy=float(rms), spectral_centroid=sc))
147
- print(f" ✓ Valid hits: {len(hits)}")
148
- return hits
149
 
150
 
151
  # ─── Stage 3: Classification ─────────────────────────────────────────────────
@@ -163,20 +148,19 @@ LABEL_RULES = [
163
  ("mid", lambda lr,mr,hr,c,zcr,d: c>800),
164
  ]
165
 
166
- def classify_hit(hit: Hit) -> str:
167
- y, sr = hit.audio, hit.sr
168
  D = np.abs(librosa.stft(y, n_fft=2048))
169
- freqs = librosa.fft_frequencies(sr=sr, n_fft=2048)
170
- le = np.sum(D[(freqs>=20)&(freqs<200)]**2)
171
- me = np.sum(D[(freqs>=200)&(freqs<4000)]**2)
172
- he = np.sum(D[(freqs>=4000)]**2)
173
- total = le+me+he+1e-10; lr,mr,hr = le/total,me/total,he/total
174
  zcr = float(librosa.feature.zero_crossing_rate(y=y).mean())
175
  for name, fn in LABEL_RULES:
176
  if fn(lr,mr,hr,hit.spectral_centroid,zcr,hit.duration): return name
177
  return "other"
178
 
179
- def classify_hits(hits: list) -> list:
180
  print(f"[Stage 3] Classifying {len(hits)} hits...")
181
  for h in hits: h.label = classify_hit(h)
182
  counts = defaultdict(int)
@@ -187,212 +171,167 @@ def classify_hits(hits: list) -> list:
187
 
188
  # ─── Stage 4: NCC-based clustering ───────────────────────────────────────────
189
 
190
- def ncc_max(a: np.ndarray, b: np.ndarray) -> float:
191
- """NCC peak. Amplitude-invariant. Compares the shorter length only."""
192
- # Use the shorter clip's length — no zero-padding inflation
193
- n = min(len(a), len(b))
194
- a, b = a[:n].copy(), b[:n].copy()
195
- a -= a.mean(); b -= b.mean()
196
- norm = np.sqrt(np.dot(a,a) * np.dot(b,b))
197
- if norm < 1e-10: return 0.0
198
- cc = fftconvolve(a, b[::-1], mode='full')
199
- return float(np.max(np.abs(cc))) / norm
200
-
201
 
202
- def build_ncc_distance_matrix(hits: list, max_compare_samples: int = 8820) -> np.ndarray:
203
- """Cached NCC distance matrix. Auto-scales compare window to hit lengths."""
204
  key = ("ncc_dist", tuple(_audio_hash(h.audio) for h in hits), max_compare_samples)
205
- cached = cache_get(key)
206
- if cached is not None:
207
- print(f" Using cached NCC distance matrix"); return cached
208
- N = len(hits)
209
- D = np.zeros((N, N), dtype=np.float32)
210
  for i in range(N):
211
- ai = hits[i].audio[:max_compare_samples]
212
- for j in range(i+1, N):
213
- bj = hits[j].audio[:max_compare_samples]
214
- D[i,j] = D[j,i] = max(0.0, 1.0 - ncc_max(ai, bj))
215
  return cache_set(key, D)
216
 
217
-
218
  def _labels_to_clusters(labels, hits):
219
- cluster_map = defaultdict(list)
220
- for i, lab in enumerate(labels): cluster_map[lab].append(i)
221
  clusters = []
222
- for _, indices in sorted(cluster_map.items()):
223
- votes = defaultdict(int)
224
- for idx in indices: votes[hits[idx].label] += 1
225
- majority = max(votes, key=votes.get)
226
- existing = sum(1 for c in clusters if c.label.rsplit('_',1)[0] == majority)
227
- clusters.append(Cluster(cluster_id=len(clusters), label=f"{majority}_{existing}",
228
- hits=[hits[i] for i in indices]))
229
  clusters.sort(key=lambda c: c.count, reverse=True)
230
- for i, c in enumerate(clusters): c.cluster_id = i
231
  return clusters
232
 
233
-
234
- def cluster_hits(hits: list, ncc_threshold: float = 0.80,
235
- max_compare_ms: float = 0,
236
- target_min: int = 0, target_max: int = 0,
237
- linkage: str = 'average') -> list:
238
- """NCC clustering with target range support.
239
-
240
- max_compare_ms: 0 = auto (use median hit length). Otherwise milliseconds.
241
- target_min/max: if both > 0, find a cluster count in this range.
242
- """
243
  from sklearn.cluster import AgglomerativeClustering
244
-
245
  if not hits: return []
246
- N = len(hits); sr = hits[0].sr
247
- if N == 1: return [Cluster(cluster_id=0, label=f"{hits[0].label}_0", hits=[hits[0]])]
248
-
249
- # Auto-scale compare window to median hit length
250
- if max_compare_ms <= 0:
251
- median_len = int(np.median([len(h.audio) for h in hits]))
252
- max_samples = max(int(0.03 * sr), median_len) # at least 30ms
253
- else:
254
- max_samples = int(max_compare_ms / 1000.0 * sr)
255
-
256
- print(f"[Stage 4] NCC clustering ({N} hits, compare={max_samples/sr*1000:.0f}ms, linkage={linkage})...")
257
- print(f" Computing {N*(N-1)//2} pairwise NCC distances...")
258
- D = build_ncc_distance_matrix(hits, max_compare_samples=max_samples)
259
-
260
- use_target = target_min > 0 and target_max > 0 and target_max >= target_min
261
- target_min = max(1, min(target_min, N))
262
- target_max = max(target_min, min(target_max, N))
263
-
264
- if use_target:
265
- print(f" Target: {target_min}–{target_max} clusters")
266
-
267
- # Strategy 1: Binary search on distance threshold
268
- lo, hi = 0.001, 1.0
269
- best_labels, best_n, best_dist = None, -1, 0.5
270
  for _ in range(30):
271
- mid = (lo + hi) / 2
272
- agg = AgglomerativeClustering(n_clusters=None, distance_threshold=max(0.001, mid),
273
- metric='precomputed', linkage=linkage)
274
- labels = agg.fit_predict(D)
275
- n = len(set(labels))
276
- if target_min <= n <= target_max:
277
- best_labels, best_n, best_dist = labels, n, mid; break
278
- elif n > target_max: lo = mid
279
- else: hi = mid
280
- if best_labels is None or abs(n-(target_min+target_max)/2) < abs(best_n-(target_min+target_max)/2):
281
- best_labels, best_n, best_dist = labels, n, mid
282
-
283
- # Strategy 2: If binary search didn't land in range, use n_clusters directly
284
- if best_n < target_min or best_n > target_max:
285
- target_mid = (target_min + target_max) // 2
286
- target_mid = min(target_mid, N - 1)
287
- print(f" Binary search got {best_n}, falling back to n_clusters={target_mid}")
288
  try:
289
- agg = AgglomerativeClustering(n_clusters=target_mid, metric='precomputed', linkage=linkage)
290
- best_labels = agg.fit_predict(D)
291
- best_n = target_mid
292
- except Exception as e:
293
- print(f" n_clusters fallback failed: {e}")
294
-
295
- labels = best_labels
296
- print(f" → {best_n} clusters (dist_threshold={best_dist:.4f})")
297
  else:
298
- dist_threshold = max(0.001, 1.0 - ncc_threshold)
299
- print(f" Fixed: NCC≥{ncc_threshold} (dist≤{dist_threshold:.3f})")
300
- agg = AgglomerativeClustering(n_clusters=None, distance_threshold=dist_threshold,
301
- metric='precomputed', linkage=linkage)
302
- labels = agg.fit_predict(D)
303
-
304
  print(f" ✓ {len(set(labels))} clusters")
305
- clusters = _labels_to_clusters(labels, hits)
306
- for c in clusters: print(f" {c.label}: {c.count} hits")
307
- return clusters
308
 
309
 
310
  # ─── Stage 5: Quality scoring ────────────────────────────────────────────────
311
 
312
  def sample_quality_score(y, sr, label="other"):
313
  import scipy.stats
314
- rms_env = librosa.feature.rms(y=y, frame_length=512, hop_length=128)[0]
315
- if len(rms_env) >= 10:
316
- pk = np.argmax(rms_env); post = rms_env[pk:]
317
- tail_r = np.mean(post[-max(3,len(post)//5):])/(rms_env[pk]+1e-8)
318
- c1 = max(0, 1.0-tail_r*5)
319
  if len(post)>=5:
320
- slope,_,r,_,_ = scipy.stats.linregress(np.arange(len(post)), np.log(post+1e-8))
321
- c2 = max(0,r**2) if slope<0 else r**2*0.3
322
  else: c2=0.0
323
- else: c1,c2 = 0.5,0.0
324
- completeness = c1*0.6+c2*0.4
325
- snr = 10*np.log10(np.percentile(y**2,99)/(np.percentile(y**2,10)+1e-12))
326
- n_snr = np.clip((snr-10)/40,0,1)
327
- onsets = librosa.onset.onset_detect(y=y, sr=sr, units='samples', backtrack=True)
328
- if len(onsets)>0:
329
- os_s=int(onsets[0]); pre=y[max(0,os_s-int(sr*.02)):os_s]; sig=y[os_s:os_s+int(sr*.1)]
330
- n_pre = np.clip((-10*np.log10(np.mean(pre**2+1e-12)/np.mean(sig**2+1e-12))-5)/30,0,1) \
331
- if len(pre)>10 and len(sig)>10 else 0.5
332
- else: n_pre=0.5
333
- cleanness = n_snr*0.5+n_pre*0.5
334
- oe = librosa.onset.onset_strength(y=y, sr=sr)
335
- sharpness = float(np.max(oe)/(np.mean(oe)+1e-8)) if len(oe)>1 else 1.0
336
- onset_q = float(np.clip((sharpness-1.0)/5.0,0,1))
337
- total = (completeness*0.30+cleanness*0.40+onset_q*0.20+0.5*0.10)*100
338
- return {'total':float(total),'completeness':float(completeness),
339
- 'cleanness':float(cleanness),'onset_quality':float(onset_q)}
340
 
341
  def select_best(clusters):
342
- print(f"[Stage 5] Selecting best representatives...")
343
  for c in clusters:
344
  if c.count<=1: c.best_hit_idx=0; continue
345
- scores = [sample_quality_score(h.audio,h.sr,c.label.rsplit('_',1)[0])['total'] for h in c.hits]
346
- c.best_hit_idx = int(np.argmax(scores))
347
 
348
 
349
  # ─── Stage 6: Synthesis ──────────────────────────────────────────────────────
350
 
351
  def synthesize_from_cluster(cluster):
352
  if cluster.count<2: return None
353
- tl = int(np.median([len(h.audio) for h in cluster.hits]))
354
- aligned, weights = [], []
355
- pp_target = None
356
- for i, h in enumerate(cluster.hits):
357
- a = h.audio.copy(); pp = np.argmax(np.abs(a))
358
- if pp_target is None: pp_target = pp
359
- shift = pp_target-pp
360
- if shift>0: a=np.pad(a,(shift,0))
361
- elif shift<0: a=a[-shift:]
362
- a = a[:tl] if len(a)>=tl else np.pad(a,(0,tl-len(a)))
363
- pk = np.abs(a).max()
364
  if pk>0: a=a/pk
365
- aligned.append(a); weights.append(2.0 if i==cluster.best_hit_idx else 1.0)
366
- aligned=np.array(aligned); w=np.array(weights); w/=w.sum()
367
- synth=np.average(aligned,axis=0,weights=w); pk=np.abs(synth).max()
368
- return (synth*0.95/pk).astype(np.float32) if pk>0 else synth.astype(np.float32)
369
 
370
 
371
  # ─── Stage 7: MIDI + rendering ───────────────────────────────────────────────
372
 
373
  def build_midi(clusters, bpm=120.0):
374
  import pretty_midi
375
- pm = pretty_midi.PrettyMIDI(initial_tempo=bpm)
376
  for i,c in enumerate(clusters): c.midi_note=min(36+i,127)
377
- inst = pretty_midi.Instrument(program=0, is_drum=True, name='Extracted Samples')
378
  pm.instruments.append(inst)
379
  for c in clusters:
380
  for h in c.hits:
381
- vel=max(1,min(127,int(h.rms_energy/0.3*127)))
382
- inst.notes.append(pretty_midi.Note(velocity=vel,pitch=c.midi_note,
383
  start=h.onset_time,end=h.onset_time+max(h.duration,0.05)))
384
  inst.notes.sort(key=lambda n: n.start); return pm
385
 
386
- def export_midi(clusters, output_path, bpm=120.0):
387
- pm=build_midi(clusters,bpm); pm.write(output_path)
388
- print(f" ✓ MIDI: {output_path} ({len(pm.instruments[0].notes)} notes)"); return pm
389
 
390
  def detect_bpm(y, sr):
391
- ck=("bpm",_audio_hash(y),sr); cached=cache_get(ck)
392
- if cached is not None: print(f" Cached BPM: {cached}"); return cached
393
- onset_env=librosa.onset.onset_strength(y=y,sr=sr,aggregate=np.median)
394
- bpm=float(librosa.feature.tempo(onset_envelope=onset_env,sr=sr).item())
395
- _,beats=librosa.beat.beat_track(onset_envelope=onset_env,sr=sr,units='time')
396
  if len(beats)>2:
397
  ibi=60.0/float(np.median(np.diff(beats)))
398
  for c in [bpm,ibi]:
@@ -403,16 +342,16 @@ def detect_bpm(y, sr):
403
  return cache_set(ck, round(bpm,1))
404
 
405
  def render_midi_with_samples(clusters, sr=44100):
406
- max_end=max((h.onset_time+h.duration for c in clusters for h in c.hits),default=1.0)
407
- buf=np.zeros(int((max_end+1.0)*sr),dtype=np.float64)
408
  for c in clusters:
409
- sample=c.best_hit.audio.astype(np.float64)
410
- ref_e=c.best_hit.rms_energy if c.best_hit.rms_energy>0 else 0.1
411
  for h in c.hits:
412
- vs=min(2.0,h.rms_energy/(ref_e+1e-8))**0.5
413
- s=int(h.onset_time*sr); e=s+len(sample)
414
  if e>len(buf): buf=np.concatenate([buf,np.zeros(e-len(buf))])
415
- buf[s:e]+=sample*vs
416
  pk=np.abs(buf).max()
417
  return (buf/pk*0.9).astype(np.float32) if pk>1e-8 else buf.astype(np.float32)
418
 
@@ -422,31 +361,180 @@ def build_sample_map(clusters):
422
 
423
  def build_archive(clusters, bpm, sr, midi_path=None, rendered_audio=None):
424
  import zipfile, tempfile, io
425
- zip_path=tempfile.mktemp(suffix='.zip')
426
- index={'bpm':round(bpm,1),'sample_rate':sr,'total_clusters':len(clusters),
427
- 'total_hits':sum(c.count for c in clusters),'samples':{}}
428
- with zipfile.ZipFile(zip_path,'w',compression=zipfile.ZIP_STORED) as zf:
429
  for c in clusters:
430
- best=c.best_hit; fname=f"samples/{c.label}.wav"
431
- buf=io.BytesIO(); sf.write(buf,best.audio,sr,format='WAV',subtype='PCM_24')
432
- zf.writestr(fname,buf.getvalue())
433
- onset_times=sorted([h.onset_time for h in c.hits])
434
- index['samples'][c.label]={
435
- 'file':fname,'classification':c.label.rsplit('_',1)[0],
436
  'midi_note':c.midi_note,'occurrences':c.count,
437
- 'onset_times_sec':[round(t,4) for t in onset_times],
438
- 'duration_sec':round(best.duration,4),
439
- 'rms_energy':round(best.rms_energy,6),
440
- 'spectral_centroid_hz':round(best.spectral_centroid,1),
441
- }
442
  if c.synthesized is not None:
443
  sf2=f"samples/{c.label}__synthesized.wav"; b2=io.BytesIO()
444
  sf.write(b2,c.synthesized,sr,format='WAV',subtype='PCM_24')
445
- zf.writestr(sf2,b2.getvalue())
446
- index['samples'][c.label]['synthesized_file']=sf2
447
- zf.writestr('index.json',json.dumps(index,indent=2))
448
  if midi_path and os.path.exists(midi_path): zf.write(midi_path,'reconstruction.mid')
449
  if rendered_audio is not None:
450
  rb=io.BytesIO(); sf.write(rb,rendered_audio,sr,format='WAV',subtype='PCM_16')
451
  zf.writestr('rendered_reconstruction.wav',rb.getvalue())
452
- return zip_path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  #!/usr/bin/env python3
2
  """
3
+ Sample Extractor v7 — Auto-tuning via reconstruction quality.
4
 
5
+ New: auto_tune() optimizes parameters against the uploaded audio itself.
6
+ No ground truth needed — measures spectral envelope correlation between
7
+ the rendered reconstruction and the original stem. Sweeps onset detection
8
+ params and cluster counts, using cached NCC matrices for speed.
 
9
  """
10
 
11
  import argparse, json, os, sys, warnings, hashlib
 
63
 
64
  # ─── Stage 1: Stem separation ────────────────────────────────────────────────
65
 
66
+ def extract_stem(audio_path, stem="drums", device="cpu",
67
+ model_name="htdemucs_ft", shifts=1, overlap=0.25):
 
68
  if stem == "all":
69
  y, sr = librosa.load(audio_path, sr=44100, mono=True)
70
  return y.astype(np.float32), sr
 
71
  with open(audio_path, 'rb') as f:
72
+ fh = hashlib.md5(f.read(200000)).hexdigest()
73
+ ck = ("stem", fh, stem, model_name, shifts, overlap)
74
+ c = cache_get(ck)
75
+ if c is not None: print(f"[Stage 1] Cached {stem} stem"); return c
 
 
76
  from demucs.pretrained import get_model
77
  from demucs.apply import apply_model
78
  print(f"[Stage 1] Extracting '{stem}' with {model_name}...")
79
+ model = get_model(model_name); model.eval().to(device); sr = model.samplerate
80
+ if stem not in model.sources: raise ValueError(f"'{stem}' not in {model.sources}")
81
+ a, _ = librosa.load(audio_path, sr=sr, mono=False)
82
+ if a.ndim==1: a=np.stack([a,a])
83
+ elif a.shape[0]>2: a=a[:2]
84
+ elif a.shape[0]==1: a=np.concatenate([a,a],axis=0)
85
+ wav = torch.from_numpy(a).float().unsqueeze(0).to(device)
 
 
86
  with torch.no_grad():
87
+ src = apply_model(model, wav, device=device, shifts=shifts, split=True, overlap=overlap)
88
+ r = src[0, model.sources.index(stem)].mean(dim=0).cpu().numpy()
89
+ print(f" ✓ {stem}: {len(r)/sr:.1f}s")
90
+ return cache_set(ck, (r.astype(np.float32), sr))
91
 
92
 
93
  # ─── Stage 2: Onset detection ────────────────────────────────────────────────
94
 
95
+ def detect_onsets(y, sr, pre_pad=0.005, min_dur=0.02, max_dur=1.5,
96
+ min_gap=0.03, energy_threshold_db=-35.0,
97
+ mode="auto", backtrack=True, onset_delta=0.12):
98
+ print(f"[Stage 2] Onsets (mode={mode}, delta={onset_delta}, energy≥{energy_threshold_db}dB)...")
 
 
99
  if mode == "percussive":
100
+ oe = librosa.onset.onset_strength(y=y, sr=sr, aggregate=np.median, fmax=8000)
101
  elif mode == "harmonic":
102
+ yh, _ = librosa.effects.hpss(y)
103
+ oe = librosa.onset.onset_strength(y=yh, sr=sr, fmax=8000, lag=2, max_size=3)
104
  elif mode == "broadband":
105
+ oe = librosa.onset.onset_strength(y=y, sr=sr)
106
  else:
107
+ yh, yp = librosa.effects.hpss(y)
108
+ envs = [librosa.onset.onset_strength(y=y,sr=sr,fmin=20,fmax=250,aggregate=np.median),
109
+ librosa.onset.onset_strength(y=y,sr=sr,fmin=250,fmax=4000,aggregate=np.median),
110
+ librosa.onset.onset_strength(y=y,sr=sr,fmin=4000,fmax=min(sr//2,20000),aggregate=np.median),
111
+ librosa.onset.onset_strength(y=yh,sr=sr,lag=2)]
 
 
112
  def _n(x): m=x.max(); return x/m if m>0 else x
113
+ oe = np.maximum.reduce([_n(e) for e in envs])
114
+ w = max(1, int(min_gap*sr/512))
115
+ fr = librosa.onset.onset_detect(onset_envelope=oe, sr=sr, wait=w,
 
116
  pre_avg=3, post_avg=3, pre_max=3, post_max=5,
117
  delta=onset_delta, backtrack=backtrack, units='frames')
118
+ times = librosa.frames_to_time(fr, sr=sr)
119
+ print(f" Raw: {len(times)}")
120
+ thr = 10**(energy_threshold_db/20); hits = []
 
 
121
  for i, t in enumerate(times):
122
+ s = max(0, int((t-pre_pad)*sr))
123
  e = min(int(times[i+1]*sr) if i+1<len(times) else len(y), s+int(max_dur*sr))
124
  seg = y[s:e]
125
+ if len(seg)<int(min_dur*sr): continue
126
  rms = np.sqrt(np.mean(seg**2))
127
+ if rms<thr: continue
128
  fl = min(int(0.005*sr), len(seg)//4)
129
+ if fl>0: seg=seg.copy(); seg[-fl:]*=np.linspace(1,0,fl)
130
+ sc = float(librosa.feature.spectral_centroid(y=seg,sr=sr).mean())
131
+ hits.append(Hit(audio=seg,sr=sr,onset_time=t,duration=len(seg)/sr,
132
+ index=len(hits),rms_energy=float(rms),spectral_centroid=sc))
133
+ print(f" ✓ {len(hits)} hits"); return hits
 
134
 
135
 
136
  # ─── Stage 3: Classification ─────────────────────────────────────────────────
 
148
  ("mid", lambda lr,mr,hr,c,zcr,d: c>800),
149
  ]
150
 
151
+ def classify_hit(hit):
152
+ y,sr = hit.audio, hit.sr
153
  D = np.abs(librosa.stft(y, n_fft=2048))
154
+ f = librosa.fft_frequencies(sr=sr, n_fft=2048)
155
+ le=np.sum(D[(f>=20)&(f<200)]**2); me=np.sum(D[(f>=200)&(f<4000)]**2)
156
+ he=np.sum(D[(f>=4000)]**2); t=le+me+he+1e-10
157
+ lr,mr,hr = le/t,me/t,he/t
 
158
  zcr = float(librosa.feature.zero_crossing_rate(y=y).mean())
159
  for name, fn in LABEL_RULES:
160
  if fn(lr,mr,hr,hit.spectral_centroid,zcr,hit.duration): return name
161
  return "other"
162
 
163
+ def classify_hits(hits):
164
  print(f"[Stage 3] Classifying {len(hits)} hits...")
165
  for h in hits: h.label = classify_hit(h)
166
  counts = defaultdict(int)
 
171
 
172
  # ─── Stage 4: NCC-based clustering ───────────────────────────────────────────
173
 
174
+ def ncc_max(a, b):
175
+ n = min(len(a), len(b)); a,b = a[:n].copy(), b[:n].copy()
176
+ a-=a.mean(); b-=b.mean()
177
+ norm = np.sqrt(np.dot(a,a)*np.dot(b,b))
178
+ if norm<1e-10: return 0.0
179
+ return float(np.max(np.abs(fftconvolve(a,b[::-1],mode='full'))))/norm
 
 
 
 
 
180
 
181
+ def build_ncc_distance_matrix(hits, max_compare_samples=8820):
 
182
  key = ("ncc_dist", tuple(_audio_hash(h.audio) for h in hits), max_compare_samples)
183
+ c = cache_get(key)
184
+ if c is not None: print(f" Cached NCC matrix"); return c
185
+ N=len(hits); D=np.zeros((N,N),dtype=np.float32)
 
 
186
  for i in range(N):
187
+ ai=hits[i].audio[:max_compare_samples]
188
+ for j in range(i+1,N):
189
+ D[i,j]=D[j,i]=max(0.0, 1.0-ncc_max(ai,hits[j].audio[:max_compare_samples]))
 
190
  return cache_set(key, D)
191
 
 
192
  def _labels_to_clusters(labels, hits):
193
+ cm = defaultdict(list)
194
+ for i,l in enumerate(labels): cm[l].append(i)
195
  clusters = []
196
+ for _, idx in sorted(cm.items()):
197
+ v = defaultdict(int)
198
+ for i in idx: v[hits[i].label]+=1
199
+ maj = max(v, key=v.get)
200
+ ex = sum(1 for c in clusters if c.label.rsplit('_',1)[0]==maj)
201
+ clusters.append(Cluster(cluster_id=len(clusters), label=f"{maj}_{ex}",
202
+ hits=[hits[i] for i in idx]))
203
  clusters.sort(key=lambda c: c.count, reverse=True)
204
+ for i,c in enumerate(clusters): c.cluster_id=i
205
  return clusters
206
 
207
+ def cluster_hits(hits, ncc_threshold=0.80, max_compare_ms=0,
208
+ target_min=0, target_max=0, linkage='average'):
 
 
 
 
 
 
 
 
209
  from sklearn.cluster import AgglomerativeClustering
 
210
  if not hits: return []
211
+ N=len(hits); sr=hits[0].sr
212
+ if N==1: return [Cluster(cluster_id=0, label=f"{hits[0].label}_0", hits=[hits[0]])]
213
+ if max_compare_ms<=0:
214
+ ms = max(int(0.03*sr), int(np.median([len(h.audio) for h in hits])))
215
+ else: ms = int(max_compare_ms/1000.0*sr)
216
+ print(f"[Stage 4] NCC clustering ({N} hits, {ms/sr*1000:.0f}ms, {linkage})...")
217
+ print(f" Computing {N*(N-1)//2} pairwise distances...")
218
+ D = build_ncc_distance_matrix(hits, max_compare_samples=ms)
219
+ use_t = target_min>0 and target_max>0 and target_max>=target_min
220
+ tmin=max(1,min(target_min or 1,N)); tmax=max(tmin,min(target_max or N,N))
221
+ if use_t:
222
+ print(f" Target: {tmin}–{tmax}")
223
+ lo,hi=0.001,1.0; bl,bn,bd=None,-1,0.5
 
 
 
 
 
 
 
 
 
 
 
224
  for _ in range(30):
225
+ mid=(lo+hi)/2
226
+ agg=AgglomerativeClustering(n_clusters=None,distance_threshold=max(0.001,mid),
227
+ metric='precomputed',linkage=linkage)
228
+ lb=agg.fit_predict(D); n=len(set(lb))
229
+ if tmin<=n<=tmax: bl,bn,bd=lb,n,mid; break
230
+ elif n>tmax: lo=mid
231
+ else: hi=mid
232
+ if bl is None or abs(n-(tmin+tmax)/2)<abs(bn-(tmin+tmax)/2): bl,bn,bd=lb,n,mid
233
+ if bn<tmin or bn>tmax:
234
+ tm=min((tmin+tmax)//2, N-1)
235
+ print(f" Fallback n_clusters={tm}")
 
 
 
 
 
 
236
  try:
237
+ agg=AgglomerativeClustering(n_clusters=tm,metric='precomputed',linkage=linkage)
238
+ bl=agg.fit_predict(D); bn=tm
239
+ except: pass
240
+ labels=bl; print(f" → {bn} clusters")
 
 
 
 
241
  else:
242
+ dt=max(0.001, 1.0-ncc_threshold); print(f" Fixed: dist≤{dt:.3f}")
243
+ labels=AgglomerativeClustering(n_clusters=None,distance_threshold=dt,
244
+ metric='precomputed',linkage=linkage).fit_predict(D)
 
 
 
245
  print(f" ✓ {len(set(labels))} clusters")
246
+ cl = _labels_to_clusters(labels, hits)
247
+ for c in cl: print(f" {c.label}: {c.count} hits")
248
+ return cl
249
 
250
 
251
  # ─── Stage 5: Quality scoring ────────────────────────────────────────────────
252
 
253
  def sample_quality_score(y, sr, label="other"):
254
  import scipy.stats
255
+ rms_env=librosa.feature.rms(y=y,frame_length=512,hop_length=128)[0]
256
+ if len(rms_env)>=10:
257
+ pk=np.argmax(rms_env); post=rms_env[pk:]
258
+ c1=max(0,1.0-np.mean(post[-max(3,len(post)//5):])/( rms_env[pk]+1e-8)*5)
 
259
  if len(post)>=5:
260
+ sl,_,r,_,_=scipy.stats.linregress(np.arange(len(post)),np.log(post+1e-8))
261
+ c2=max(0,r**2) if sl<0 else r**2*0.3
262
  else: c2=0.0
263
+ else: c1,c2=0.5,0.0
264
+ comp=c1*0.6+c2*0.4
265
+ snr=10*np.log10(np.percentile(y**2,99)/(np.percentile(y**2,10)+1e-12))
266
+ ns=np.clip((snr-10)/40,0,1)
267
+ ons=librosa.onset.onset_detect(y=y,sr=sr,units='samples',backtrack=True)
268
+ if len(ons)>0:
269
+ o=int(ons[0]); pre=y[max(0,o-int(sr*.02)):o]; sig=y[o:o+int(sr*.1)]
270
+ np2=np.clip((-10*np.log10(np.mean(pre**2+1e-12)/np.mean(sig**2+1e-12))-5)/30,0,1) \
271
+ if len(pre)>10 and len(sig)>10 else 0.5
272
+ else: np2=0.5
273
+ clean=ns*0.5+np2*0.5
274
+ oe=librosa.onset.onset_strength(y=y,sr=sr)
275
+ sh=float(np.max(oe)/(np.mean(oe)+1e-8)) if len(oe)>1 else 1.0
276
+ oq=float(np.clip((sh-1.0)/5.0,0,1))
277
+ tot=(comp*0.30+clean*0.40+oq*0.20+0.5*0.10)*100
278
+ return {'total':float(tot),'completeness':float(comp),'cleanness':float(clean),'onset_quality':float(oq)}
 
279
 
280
  def select_best(clusters):
281
+ print(f"[Stage 5] Selecting best...")
282
  for c in clusters:
283
  if c.count<=1: c.best_hit_idx=0; continue
284
+ sc=[sample_quality_score(h.audio,h.sr,c.label.rsplit('_',1)[0])['total'] for h in c.hits]
285
+ c.best_hit_idx=int(np.argmax(sc))
286
 
287
 
288
  # ─── Stage 6: Synthesis ──────────────────────────────────────────────────────
289
 
290
  def synthesize_from_cluster(cluster):
291
  if cluster.count<2: return None
292
+ tl=int(np.median([len(h.audio) for h in cluster.hits]))
293
+ al,wt=[],[]
294
+ pp=None
295
+ for i,h in enumerate(cluster.hits):
296
+ a=h.audio.copy(); p=np.argmax(np.abs(a))
297
+ if pp is None: pp=p
298
+ s=pp-p
299
+ if s>0: a=np.pad(a,(s,0))
300
+ elif s<0: a=a[-s:]
301
+ a=a[:tl] if len(a)>=tl else np.pad(a,(0,tl-len(a)))
302
+ pk=np.abs(a).max()
303
  if pk>0: a=a/pk
304
+ al.append(a); wt.append(2.0 if i==cluster.best_hit_idx else 1.0)
305
+ al=np.array(al); w=np.array(wt); w/=w.sum()
306
+ sy=np.average(al,axis=0,weights=w); pk=np.abs(sy).max()
307
+ return (sy*0.95/pk).astype(np.float32) if pk>0 else sy.astype(np.float32)
308
 
309
 
310
  # ─── Stage 7: MIDI + rendering ───────────────────────────────────────────────
311
 
312
  def build_midi(clusters, bpm=120.0):
313
  import pretty_midi
314
+ pm=pretty_midi.PrettyMIDI(initial_tempo=bpm)
315
  for i,c in enumerate(clusters): c.midi_note=min(36+i,127)
316
+ inst=pretty_midi.Instrument(program=0,is_drum=True,name='Extracted Samples')
317
  pm.instruments.append(inst)
318
  for c in clusters:
319
  for h in c.hits:
320
+ v=max(1,min(127,int(h.rms_energy/0.3*127)))
321
+ inst.notes.append(pretty_midi.Note(velocity=v,pitch=c.midi_note,
322
  start=h.onset_time,end=h.onset_time+max(h.duration,0.05)))
323
  inst.notes.sort(key=lambda n: n.start); return pm
324
 
325
+ def export_midi(clusters, path, bpm=120.0):
326
+ pm=build_midi(clusters,bpm); pm.write(path)
327
+ print(f" ✓ MIDI: {path} ({len(pm.instruments[0].notes)} notes)"); return pm
328
 
329
  def detect_bpm(y, sr):
330
+ ck=("bpm",_audio_hash(y),sr); c=cache_get(ck)
331
+ if c is not None: return c
332
+ oe=librosa.onset.onset_strength(y=y,sr=sr,aggregate=np.median)
333
+ bpm=float(librosa.feature.tempo(onset_envelope=oe,sr=sr).item())
334
+ _,beats=librosa.beat.beat_track(onset_envelope=oe,sr=sr,units='time')
335
  if len(beats)>2:
336
  ibi=60.0/float(np.median(np.diff(beats)))
337
  for c in [bpm,ibi]:
 
342
  return cache_set(ck, round(bpm,1))
343
 
344
  def render_midi_with_samples(clusters, sr=44100):
345
+ me=max((h.onset_time+h.duration for c in clusters for h in c.hits),default=1.0)
346
+ buf=np.zeros(int((me+1.0)*sr),dtype=np.float64)
347
  for c in clusters:
348
+ s=c.best_hit.audio.astype(np.float64)
349
+ re=c.best_hit.rms_energy if c.best_hit.rms_energy>0 else 0.1
350
  for h in c.hits:
351
+ vs=min(2.0,h.rms_energy/(re+1e-8))**0.5
352
+ i=int(h.onset_time*sr); e=i+len(s)
353
  if e>len(buf): buf=np.concatenate([buf,np.zeros(e-len(buf))])
354
+ buf[i:e]+=s*vs
355
  pk=np.abs(buf).max()
356
  return (buf/pk*0.9).astype(np.float32) if pk>1e-8 else buf.astype(np.float32)
357
 
 
361
 
362
def build_archive(clusters, bpm, sr, midi_path=None, rendered_audio=None):
    """Bundle extracted samples and metadata into a temporary ZIP archive.

    Archive layout:
      * ``samples/<label>.wav`` - best hit of each cluster (24-bit PCM).
      * ``samples/<label>__synthesized.wav`` - only for clusters whose
        ``synthesized`` attribute is not ``None``.
      * ``index.json`` - BPM, sample rate, totals and per-sample metadata.
      * ``reconstruction.mid`` - copied from ``midi_path`` when that file exists.
      * ``rendered_reconstruction.wav`` - ``rendered_audio`` as 16-bit PCM,
        when provided.

    Args:
        clusters: iterable of cluster objects exposing ``label``, ``count``,
            ``midi_note``, ``hits``, ``best_hit`` and ``synthesized``.
        bpm: tempo stored (rounded to one decimal) in the index.
        sr: sample rate used for every WAV written.
        midi_path: optional path of a MIDI file to include.
        rendered_audio: optional audio array to include as a WAV preview.

    Returns:
        Path of the ZIP file that was created. The caller owns the file and
        is responsible for deleting it.
    """
    import io
    import tempfile
    import zipfile

    # mkstemp creates the file atomically with private permissions, whereas
    # the deprecated mktemp only returns a name that another process could
    # claim before we open it.
    fd, zip_path = tempfile.mkstemp(suffix='.zip')
    os.close(fd)

    index = {'bpm': round(bpm, 1), 'sample_rate': sr, 'total_clusters': len(clusters),
             'total_hits': sum(c.count for c in clusters), 'samples': {}}
    try:
        with zipfile.ZipFile(zip_path, 'w', compression=zipfile.ZIP_STORED) as zf:
            for c in clusters:
                best = c.best_hit
                wav_name = f"samples/{c.label}.wav"
                wav_buf = io.BytesIO()
                sf.write(wav_buf, best.audio, sr, format='WAV', subtype='PCM_24')
                zf.writestr(wav_name, wav_buf.getvalue())

                onset_times = sorted(h.onset_time for h in c.hits)
                index['samples'][c.label] = {
                    # Label with its trailing "_<suffix>" removed.
                    'file': wav_name, 'classification': c.label.rsplit('_', 1)[0],
                    'midi_note': c.midi_note, 'occurrences': c.count,
                    'onset_times_sec': [round(t, 4) for t in onset_times],
                    'duration_sec': round(best.duration, 4),
                    'rms_energy': round(best.rms_energy, 6),
                    'spectral_centroid_hz': round(best.spectral_centroid, 1)}

                if c.synthesized is not None:
                    synth_name = f"samples/{c.label}__synthesized.wav"
                    synth_buf = io.BytesIO()
                    sf.write(synth_buf, c.synthesized, sr, format='WAV', subtype='PCM_24')
                    zf.writestr(synth_name, synth_buf.getvalue())
                    index['samples'][c.label]['synthesized_file'] = synth_name

            zf.writestr('index.json', json.dumps(index, indent=2))

            if midi_path and os.path.exists(midi_path):
                zf.write(midi_path, 'reconstruction.mid')
            if rendered_audio is not None:
                render_buf = io.BytesIO()
                sf.write(render_buf, rendered_audio, sr, format='WAV', subtype='PCM_16')
                zf.writestr('rendered_reconstruction.wav', render_buf.getvalue())
    except Exception:
        # Do not leave a truncated archive behind in the temp directory.
        if os.path.exists(zip_path):
            os.remove(zip_path)
        raise
    return zip_path
390
+
391
+
392
+ # ─── Auto-tuner ──────────────────────────────────────────────────────────────
393
+
394
+ def _spectral_envelope_corr(original, rendered, sr, n_fft=4096, hop=2048):
395
+ """Spectral envelope correlation between two signals. Phase-insensitive.
396
+ Returns correlation coefficient [−1, 1]. Higher = better reconstruction."""
397
+ n = min(len(original), len(rendered))
398
+ if n < n_fft: return 0.0
399
+ S_orig = np.abs(librosa.stft(original[:n], n_fft=n_fft, hop_length=hop))
400
+ S_rend = np.abs(librosa.stft(rendered[:n], n_fft=n_fft, hop_length=hop))
401
+ # Average over time to get spectral envelope
402
+ env_o = S_orig.mean(axis=1)
403
+ env_r = S_rend.mean(axis=1)
404
+ if env_o.std() < 1e-10 or env_r.std() < 1e-10: return 0.0
405
+ return float(np.corrcoef(env_o, env_r)[0, 1])
406
+
407
+
408
+ def _rms_envelope_corr(original, rendered, sr, hop=1024):
409
+ """RMS amplitude envelope correlation. Measures timing/dynamics match."""
410
+ n = min(len(original), len(rendered))
411
+ if n < hop * 4: return 0.0
412
+ rms_o = librosa.feature.rms(y=original[:n], hop_length=hop)[0]
413
+ rms_r = librosa.feature.rms(y=rendered[:n], hop_length=hop)[0]
414
+ n2 = min(len(rms_o), len(rms_r))
415
+ if n2 < 4: return 0.0
416
+ rms_o, rms_r = rms_o[:n2], rms_r[:n2]
417
+ if rms_o.std() < 1e-10 or rms_r.std() < 1e-10: return 0.0
418
+ return float(np.corrcoef(rms_o, rms_r)[0, 1])
419
+
420
+
421
def _reconstruction_score(original, rendered, sr):
    """Score how closely ``rendered`` reproduces ``original``.

    Weighted blend, scaled to a 0-100 range and floored at 0.0:
      * 50 % spectral-envelope correlation,
      * 40 % RMS-envelope correlation,
      * 10 % length ratio (shorter / (longer + 1)), which penalises a
        reconstruction that is much shorter or longer than the source.
    """
    spectral = _spectral_envelope_corr(original, rendered, sr)
    dynamics = _rms_envelope_corr(original, rendered, sr)

    shorter, longer = sorted((len(rendered), len(original)))
    length_ratio = shorter / (longer + 1)

    blended = (spectral * 0.5 + dynamics * 0.4 + length_ratio * 0.1) * 100
    return max(0.0, blended)
429
+
430
+
431
def _evaluate_cluster_target(hits, n_target, stem_audio, sr):
    """Cluster ``hits`` into ``n_target`` groups, render them and score the result.

    Returns ``(n_clusters, score)``, or ``None`` when clustering yields nothing.
    """
    clusters = cluster_hits(hits, target_min=n_target, target_max=n_target,
                            linkage='average')
    if not clusters:
        return None
    for c in clusters:
        # Fast representative choice for the search: take the loudest hit
        # instead of running a full best-hit selection.
        c.best_hit_idx = (0 if c.count <= 1
                          else int(np.argmax([h.rms_energy for h in c.hits])))
    rendered = render_midi_with_samples(clusters, sr=sr)
    return len(clusters), _reconstruction_score(stem_audio, rendered, sr)


def auto_tune(stem_audio, sr, mode="auto", log_fn=None):
    """Search for the extraction parameters that best reconstruct ``stem_audio``.

    A coarse pass tries six onset-detection presets against seven cluster
    counts. For each combination the hits are clustered, the loudest hit of
    each cluster is pasted at every onset, and the render is scored against
    the input with ``_reconstruction_score``. A fine pass then tries the
    cluster counts within ±3 of the best coarse target, using the winning
    onset preset.

    Args:
        stem_audio: mono audio array to analyse.
        sr: sample rate of ``stem_audio``.
        mode: forwarded unchanged to ``detect_onsets``.
        log_fn: optional callable that receives every log line, in addition
            to the line being printed.

    Returns:
        ``(best_params, best_score, log)``. ``best_params`` holds
        ``onset_delta``, ``energy_threshold_db``, ``min_gap``, ``n_clusters``,
        ``target_min`` and ``target_max``. ``log`` is the list of emitted
        messages. When no configuration yields a usable clustering the result
        is ``({}, -1.0, log)``.
    """
    log = []

    def _log(msg):
        log.append(msg)
        if log_fn:
            log_fn(msg)
        print(msg)

    _log(f"[Auto-tune] Starting parameter search on {len(stem_audio)/sr:.1f}s audio...")

    # Coarse grid, ordered from most to least sensitive onset detection.
    onset_configs = [
        {"onset_delta": 0.08, "energy_threshold_db": -40, "min_gap": 0.02},
        {"onset_delta": 0.10, "energy_threshold_db": -35, "min_gap": 0.025},
        {"onset_delta": 0.12, "energy_threshold_db": -35, "min_gap": 0.03},
        {"onset_delta": 0.15, "energy_threshold_db": -30, "min_gap": 0.04},
        {"onset_delta": 0.20, "energy_threshold_db": -28, "min_gap": 0.05},
        {"onset_delta": 0.25, "energy_threshold_db": -25, "min_gap": 0.06},
    ]
    cluster_targets = [3, 5, 8, 10, 15, 20, 30]

    best_score = -1.0
    best_params = {}
    best_hits = None  # classified hits of the winning onset preset

    for oc_idx, oc in enumerate(onset_configs):
        _log(f"\n Config {oc_idx+1}/{len(onset_configs)}: delta={oc['onset_delta']}, "
             f"energy={oc['energy_threshold_db']}dB, gap={oc['min_gap']}s")

        hits = detect_onsets(stem_audio, sr, mode=mode, **oc)
        if len(hits) < 2:
            _log(f" → Only {len(hits)} hits, skipping")
            continue

        hits = classify_hits(hits)

        # The same hits are re-clustered for every target; per the module
        # header the distance matrix is cached, so this should be cheap
        # (not verifiable from this function alone).
        for n_target in cluster_targets:
            if n_target >= len(hits):
                continue

            outcome = _evaluate_cluster_target(hits, n_target, stem_audio, sr)
            if outcome is None:
                continue
            n_clusters, score = outcome

            if score > best_score:
                best_score = score
                best_params = {**oc, 'n_clusters': n_clusters, 'target_min': n_target,
                               'target_max': n_target}
                best_hits = hits
                _log(f" ★ target={n_target} → {n_clusters} clusters, "
                     f"score={score:.1f} (NEW BEST)")
            else:
                _log(f" target={n_target} → {n_clusters} clusters, score={score:.1f}")

    # Fine-tune: try targets within ±3 of the best one, using the best onset preset.
    if best_params:
        bt = best_params['target_min']
        _log(f"\n Fine-tuning around best (delta={best_params['onset_delta']}, "
             f"target≈{bt})...")

        fine_oc = {k: best_params[k] for k in ('onset_delta', 'energy_threshold_db', 'min_gap')}
        for ft in range(max(2, bt - 3), bt + 4):
            # Coarse-grid targets were already scored for this preset and
            # cannot beat the current best, so only new targets are tried.
            if ft in cluster_targets or ft >= len(best_hits):
                continue
            outcome = _evaluate_cluster_target(best_hits, ft, stem_audio, sr)
            if outcome is None:
                continue
            n_clusters, score = outcome
            if score > best_score:
                best_score = score
                best_params = {**fine_oc, 'n_clusters': n_clusters,
                               'target_min': ft, 'target_max': ft}
                _log(f" ★ target={ft} → {n_clusters} clusters, "
                     f"score={score:.1f} (NEW BEST)")

    _log(f"\n[Auto-tune] Best: score={best_score:.1f}, "
         f"delta={best_params.get('onset_delta')}, "
         f"energy={best_params.get('energy_threshold_db')}dB, "
         f"clusters={best_params.get('n_clusters')}")

    return best_params, best_score, log