rikhoffbauer2 committed on
Commit
d1fa59c
·
verified ·
1 Parent(s): 539cac7

v6: Fix clustering for real music — auto-scale NCC window, n_clusters fallback, better defaults

Browse files
Files changed (1) hide show
  1. sample_extractor.py +226 -337
sample_extractor.py CHANGED
@@ -1,23 +1,15 @@
1
  #!/usr/bin/env python3
2
  """
3
- Sample Extractor v4 — NCC-based clustering, full parameter control.
4
-
5
- Key fix: Uses normalized cross-correlation (NCC) to detect identical samples
6
- instead of MFCC-based KMeans. NCC is amplitude-invariant — same kick at
7
- different velocities → NCC ≈ 1.0. This correctly groups repeated occurrences
8
- of the same sample into one cluster.
9
-
10
- Stages:
11
- 1. STEM SEPARATION — Demucs (configurable model) isolates target stem
12
- 2. ONSET DETECTION — Adaptive multi-method detection
13
- 3. CLASSIFICATION — Spectral profile labeling (post-overlap-separation aware)
14
- 4. NCC CLUSTERING — Waveform identity matching via cross-correlation
15
- 5. QUALITY SCORING — Completeness + cleanness + onset sharpness
16
- 6. SYNTHESIS — Peak-aligned weighted average
17
- 7. MIDI + RENDER — Timeline reconstruction as .mid and .wav
18
  """
19
 
20
- import argparse, json, os, sys, warnings
21
  from collections import defaultdict
22
  from dataclasses import dataclass, field
23
  from pathlib import Path
@@ -49,86 +41,70 @@ class Cluster:
49
  def count(self) -> int: return len(self.hits)
50
 
51
 
52
- # Available Demucs models
53
- DEMUCS_MODELS = [
54
- "htdemucs", # Default hybrid transformer, 4-stem, fast
55
- "htdemucs_ft", # Fine-tuned, 4-stem, best quality, slower (bag of 4)
56
- "htdemucs_6s", # 6-stem: adds guitar + piano
57
- "mdx", # MDX competition winner, waveform U-Net
58
- "mdx_extra", # Hybrid spectral, highest quality overall
59
- "mdx_extra_q", # Quantized mdx_extra (needs diffq)
60
- ]
61
-
62
  DEMUCS_STEMS = {
63
- "htdemucs": ["drums", "bass", "other", "vocals"],
64
- "htdemucs_ft": ["drums", "bass", "other", "vocals"],
65
- "htdemucs_6s": ["drums", "bass", "other", "vocals", "guitar", "piano"],
66
- "mdx": ["drums", "bass", "other", "vocals"],
67
- "mdx_extra": ["drums", "bass", "other", "vocals"],
68
- "mdx_extra_q": ["drums", "bass", "other", "vocals"],
69
  }
70
 
71
 
 
 
 
 
 
 
 
 
 
 
 
 
72
  # ─── Stage 1: Stem separation ────────────────────────────────────────────────
73
 
74
  def extract_stem(audio_path: str, stem: str = "drums", device: str = "cpu",
75
  model_name: str = "htdemucs_ft", shifts: int = 1,
76
  overlap: float = 0.25) -> tuple:
77
- """Extract a stem using Demucs. Cached by (file_hash, stem, model, shifts, overlap)."""
78
  if stem == "all":
79
  y, sr = librosa.load(audio_path, sr=44100, mono=True)
80
  return y.astype(np.float32), sr
81
 
82
- # Cache key from file content hash + params
83
  with open(audio_path, 'rb') as f:
84
- file_hash = hashlib.md5(f.read(200000)).hexdigest() # hash first 200KB
85
- cache_key = ("stem", file_hash, stem, model_name, shifts, overlap)
86
- cached = cache_get(cache_key)
87
  if cached is not None:
88
- print(f"[Stage 1] Using cached {stem} stem ({model_name})")
89
- return cached
90
 
91
  from demucs.pretrained import get_model
92
  from demucs.apply import apply_model
93
-
94
- print(f"[Stage 1] Extracting '{stem}' with {model_name} (shifts={shifts}, overlap={overlap})...")
95
- model = get_model(model_name)
96
- model.eval().to(device)
97
  sr = model.samplerate
98
-
99
  if stem not in model.sources:
100
- raise ValueError(f"Stem '{stem}' not in model '{model_name}'. Available: {model.sources}")
101
-
102
  audio_np, _ = librosa.load(audio_path, sr=sr, mono=False)
103
- if audio_np.ndim == 1:
104
- audio_np = np.stack([audio_np, audio_np])
105
- elif audio_np.shape[0] > 2:
106
- audio_np = audio_np[:2]
107
- elif audio_np.shape[0] == 1:
108
- audio_np = np.concatenate([audio_np, audio_np], axis=0)
109
-
110
  wav = torch.from_numpy(audio_np).float().unsqueeze(0).to(device)
111
  with torch.no_grad():
112
- sources = apply_model(model, wav, device=device, shifts=shifts,
113
- split=True, overlap=overlap)
114
-
115
- idx = model.sources.index(stem)
116
- result = sources[0, idx].mean(dim=0).cpu().numpy()
117
  print(f" ✓ {stem}: {len(result)/sr:.1f}s")
118
- out = (result.astype(np.float32), sr)
119
- return cache_set(cache_key, out)
120
 
121
 
122
  # ─── Stage 2: Onset detection ────────────────────────────────────────────────
123
 
124
  def detect_onsets(y: np.ndarray, sr: int, pre_pad: float = 0.005,
125
  min_dur: float = 0.02, max_dur: float = 1.5,
126
- min_gap: float = 0.015, energy_threshold_db: float = -45.0,
127
  mode: str = "auto", backtrack: bool = True,
128
- onset_delta: float = 0.07) -> list:
129
- """Detect onsets. mode: auto|percussive|harmonic|broadband"""
130
- print(f"[Stage 2] Detecting onsets (mode={mode}, delta={onset_delta})...")
131
-
132
  if mode == "percussive":
133
  onset_env = librosa.onset.onset_strength(y=y, sr=sr, aggregate=np.median, fmax=8000)
134
  elif mode == "harmonic":
@@ -136,21 +112,19 @@ def detect_onsets(y: np.ndarray, sr: int, pre_pad: float = 0.005,
136
  onset_env = librosa.onset.onset_strength(y=y_harm, sr=sr, fmax=8000, lag=2, max_size=3)
137
  elif mode == "broadband":
138
  onset_env = librosa.onset.onset_strength(y=y, sr=sr)
139
- else: # auto
140
  y_harm, y_perc = librosa.effects.hpss(y)
141
  envs = [
142
  librosa.onset.onset_strength(y=y, sr=sr, fmin=20, fmax=250, aggregate=np.median),
143
  librosa.onset.onset_strength(y=y, sr=sr, fmin=250, fmax=4000, aggregate=np.median),
144
- librosa.onset.onset_strength(y=y, sr=sr, fmin=4000, fmax=min(sr//2, 20000), aggregate=np.median),
145
  librosa.onset.onset_strength(y=y_harm, sr=sr, lag=2),
146
  ]
147
- def _n(x):
148
- m = x.max(); return x/m if m > 0 else x
149
  onset_env = np.maximum.reduce([_n(e) for e in envs])
150
 
151
  wait = max(1, int(min_gap * sr / 512))
152
- frames = librosa.onset.onset_detect(
153
- onset_envelope=onset_env, sr=sr, wait=wait,
154
  pre_avg=3, post_avg=3, pre_max=3, post_max=5,
155
  delta=onset_delta, backtrack=backtrack, units='frames')
156
  times = librosa.frames_to_time(frames, sr=sr)
@@ -160,17 +134,13 @@ def detect_onsets(y: np.ndarray, sr: int, pre_pad: float = 0.005,
160
  hits = []
161
  for i, t in enumerate(times):
162
  s = max(0, int((t - pre_pad) * sr))
163
- if i + 1 < len(times):
164
- e = min(int(times[i+1] * sr), s + int(max_dur * sr))
165
- else:
166
- e = min(len(y), s + int(max_dur * sr))
167
  seg = y[s:e]
168
- if len(seg) < int(min_dur * sr): continue
169
  rms = np.sqrt(np.mean(seg**2))
170
  if rms < threshold: continue
171
- fl = min(int(0.005 * sr), len(seg) // 4)
172
- if fl > 0:
173
- seg = seg.copy(); seg[-fl:] *= np.linspace(1, 0, fl)
174
  sc = float(librosa.feature.spectral_centroid(y=seg, sr=sr).mean())
175
  hits.append(Hit(audio=seg, sr=sr, onset_time=t, duration=len(seg)/sr,
176
  index=len(hits), rms_energy=float(rms), spectral_centroid=sc))
@@ -181,294 +151,221 @@ def detect_onsets(y: np.ndarray, sr: int, pre_pad: float = 0.005,
181
  # ─── Stage 3: Classification ─────────────────────────────────────────────────
182
 
183
  LABEL_RULES = [
184
- ("kick", lambda lr, mr, hr, c, zcr, d: lr > 0.5 and c < 800),
185
- ("hihat_closed", lambda lr, mr, hr, c, zcr, d: hr > 0.35 and c > 4000 and d < 0.15),
186
- ("hihat_open", lambda lr, mr, hr, c, zcr, d: hr > 0.35 and c > 4000 and d >= 0.15),
187
- ("cymbal", lambda lr, mr, hr, c, zcr, d: hr > 0.25 and c > 3000),
188
- ("snare", lambda lr, mr, hr, c, zcr, d: mr > 0.4 and zcr > 0.1 and c > 1000),
189
- ("tom", lambda lr, mr, hr, c, zcr, d: lr > 0.3 and mr > 0.3 and c < 1500),
190
- ("bass", lambda lr, mr, hr, c, zcr, d: lr > 0.6 and c < 400 and d > 0.2),
191
- ("vocal", lambda lr, mr, hr, c, zcr, d: mr > 0.5 and c > 500 and c < 3000 and zcr < 0.15),
192
- ("bright", lambda lr, mr, hr, c, zcr, d: c > 2500),
193
- ("mid", lambda lr, mr, hr, c, zcr, d: c > 800),
194
  ]
195
 
196
  def classify_hit(hit: Hit) -> str:
197
  y, sr = hit.audio, hit.sr
198
  D = np.abs(librosa.stft(y, n_fft=2048))
199
  freqs = librosa.fft_frequencies(sr=sr, n_fft=2048)
200
- le = np.sum(D[(freqs >= 20) & (freqs < 200)]**2)
201
- me = np.sum(D[(freqs >= 200) & (freqs < 4000)]**2)
202
- he = np.sum(D[(freqs >= 4000)]**2)
203
- total = le + me + he + 1e-10
204
- lr, mr, hr = le/total, me/total, he/total
205
  zcr = float(librosa.feature.zero_crossing_rate(y=y).mean())
206
  for name, fn in LABEL_RULES:
207
- if fn(lr, mr, hr, hit.spectral_centroid, zcr, hit.duration):
208
- return name
209
  return "other"
210
 
211
  def classify_hits(hits: list) -> list:
212
- """Classify all hits. No overlap separation — clustering handles grouping."""
213
  print(f"[Stage 3] Classifying {len(hits)} hits...")
214
- for h in hits:
215
- h.label = classify_hit(h)
216
  counts = defaultdict(int)
217
  for h in hits: counts[h.label] += 1
218
- for l, c in sorted(counts.items(), key=lambda x: -x[1]):
219
- print(f" {l}: {c}")
220
  return hits
221
 
222
 
223
- # ─── Caching ──────────────────────────────────────────────────────────────────
224
-
225
- import hashlib, functools
226
-
227
- _cache = {} # key → value; cleared per-session or manually
228
-
229
- def _audio_hash(audio: np.ndarray) -> str:
230
- """Fast hash of audio array for cache keys."""
231
- return hashlib.md5(audio[:4000].tobytes()).hexdigest()
232
-
233
- def cache_get(key):
234
- return _cache.get(key)
235
-
236
- def cache_set(key, value):
237
- _cache[key] = value
238
- return value
239
-
240
- def cache_clear():
241
- _cache.clear()
242
-
243
-
244
  # ─── Stage 4: NCC-based clustering ───────────────────────────────────────────
245
 
246
  def ncc_max(a: np.ndarray, b: np.ndarray) -> float:
247
- """Normalized cross-correlation peak. Amplitude-invariant.
248
- Returns 1.0 for identical waveforms at any amplitude."""
249
- a = a - a.mean()
250
- b = b - b.mean()
251
- norm = np.sqrt(np.dot(a, a) * np.dot(b, b))
 
252
  if norm < 1e-10: return 0.0
253
- n = max(len(a), len(b))
254
- a_pad = np.pad(a, (0, max(0, n - len(a))))
255
- b_pad = np.pad(b, (0, max(0, n - len(b))))
256
- cc = fftconvolve(a_pad, b_pad[::-1], mode='full')
257
  return float(np.max(np.abs(cc))) / norm
258
 
259
 
260
  def build_ncc_distance_matrix(hits: list, max_compare_samples: int = 8820) -> np.ndarray:
261
- """Build N×N distance matrix using NCC. d=0 identical, d=1 unrelated.
262
- Cached — recomputing is the most expensive step."""
263
- # Cache key from hit audio hashes
264
  key = ("ncc_dist", tuple(_audio_hash(h.audio) for h in hits), max_compare_samples)
265
  cached = cache_get(key)
266
  if cached is not None:
267
- print(f" Using cached NCC distance matrix")
268
- return cached
269
-
270
  N = len(hits)
271
  D = np.zeros((N, N), dtype=np.float32)
272
  for i in range(N):
273
  ai = hits[i].audio[:max_compare_samples]
274
  for j in range(i+1, N):
275
  bj = hits[j].audio[:max_compare_samples]
276
- ncc = ncc_max(ai, bj)
277
- D[i, j] = D[j, i] = max(0.0, 1.0 - ncc)
278
-
279
  return cache_set(key, D)
280
 
281
 
282
- def _agglom_at_threshold(D: np.ndarray, dist_threshold: float, linkage: str = 'average') -> np.ndarray:
283
- """Run agglomerative clustering at a specific threshold. Returns labels."""
284
- from sklearn.cluster import AgglomerativeClustering
285
- agg = AgglomerativeClustering(
286
- n_clusters=None,
287
- distance_threshold=max(0.001, dist_threshold),
288
- metric='precomputed',
289
- linkage=linkage,
290
- )
291
- return agg.fit_predict(D)
292
-
293
-
294
- def _labels_to_clusters(labels: np.ndarray, hits: list) -> list:
295
- """Convert sklearn labels to Cluster objects with majority-vote naming."""
296
  cluster_map = defaultdict(list)
297
- for i, lab in enumerate(labels):
298
- cluster_map[lab].append(i)
299
-
300
  clusters = []
301
  for _, indices in sorted(cluster_map.items()):
302
- label_votes = defaultdict(int)
303
- for idx in indices:
304
- label_votes[hits[idx].label] += 1
305
- majority_label = max(label_votes, key=label_votes.get)
306
- existing = sum(1 for c in clusters if c.label.rsplit('_', 1)[0] == majority_label)
307
- clusters.append(Cluster(
308
- cluster_id=len(clusters),
309
- label=f"{majority_label}_{existing}",
310
- hits=[hits[i] for i in indices],
311
- ))
312
-
313
  clusters.sort(key=lambda c: c.count, reverse=True)
314
- for i, c in enumerate(clusters):
315
- c.cluster_id = i
316
  return clusters
317
 
318
 
319
  def cluster_hits(hits: list, ncc_threshold: float = 0.80,
320
- max_compare_ms: float = 200.0,
321
  target_min: int = 0, target_max: int = 0,
322
  linkage: str = 'average') -> list:
323
- """Cluster hits by waveform identity using NCC + agglomerative clustering.
324
-
325
- If target_min/target_max are set (both > 0), ignores ncc_threshold and
326
- binary-searches the distance threshold to produce a cluster count within
327
- [target_min, target_max]. This is the most intuitive way to control output.
328
 
329
- linkage: 'average' (recommended — tolerant of outlier pairs),
330
- 'complete' (strict — any bad pair splits the cluster),
331
- 'single' (loose — chains distant points together).
332
  """
333
- if not hits:
334
- return []
335
-
336
- N = len(hits)
337
- sr = hits[0].sr
338
- max_samples = int(max_compare_ms / 1000.0 * sr)
339
 
340
- print(f"[Stage 4] NCC clustering ({N} hits, linkage={linkage})...")
 
 
341
 
342
- if N == 1:
343
- return [Cluster(cluster_id=0, label=f"{hits[0].label}_0", hits=[hits[0]])]
 
 
 
 
344
 
345
- # Build (or retrieve cached) distance matrix
346
  print(f" Computing {N*(N-1)//2} pairwise NCC distances...")
347
  D = build_ncc_distance_matrix(hits, max_compare_samples=max_samples)
348
 
349
- use_target_range = target_min > 0 and target_max > 0 and target_max >= target_min
350
  target_min = max(1, min(target_min, N))
351
  target_max = max(target_min, min(target_max, N))
352
 
353
- if use_target_range:
354
- # Binary search for the distance threshold that gives target cluster count
355
- print(f" Target range: {target_min}–{target_max} clusters, searching threshold...")
356
- lo, hi = 0.001, 1.0
357
- best_labels = None
358
- best_n = -1
359
 
360
- for _ in range(30): # max 30 binary search steps
 
 
 
361
  mid = (lo + hi) / 2
362
- labels = _agglom_at_threshold(D, mid, linkage)
 
 
363
  n = len(set(labels))
364
-
365
  if target_min <= n <= target_max:
366
- best_labels = labels
367
- best_n = n
368
- break
369
- elif n > target_max:
370
- # Too many clusters — need higher threshold (merge more)
371
- lo = mid
372
- else:
373
- # Too few clusters — need lower threshold (split more)
374
- hi = mid
375
-
376
- # Keep best attempt in range
377
- if best_labels is None or abs(n - (target_min + target_max) / 2) < abs(best_n - (target_min + target_max) / 2):
378
- best_labels = labels
379
- best_n = n
 
 
 
380
 
381
  labels = best_labels
382
- print(f" → threshold={mid:.4f}, {best_n} clusters")
383
  else:
384
- # Use fixed NCC threshold
385
  dist_threshold = max(0.001, 1.0 - ncc_threshold)
386
- print(f" Fixed threshold: NCC≥{ncc_threshold} (dist≤{dist_threshold:.3f})")
387
- labels = _agglom_at_threshold(D, dist_threshold, linkage)
388
-
389
- n_clusters = len(set(labels))
390
- print(f" ✓ {n_clusters} clusters")
391
 
 
392
  clusters = _labels_to_clusters(labels, hits)
393
- for c in clusters:
394
- print(f" {c.label}: {c.count} hits")
395
-
396
  return clusters
397
 
398
 
399
- # ─── Stage 5: Quality scoring & selection ─────────────────────────────────────
400
 
401
- def sample_quality_score(y: np.ndarray, sr: int, label: str = "other") -> dict:
402
- """Score a sample for production quality. Returns dict with total [0,100]."""
403
  import scipy.stats
404
  rms_env = librosa.feature.rms(y=y, frame_length=512, hop_length=128)[0]
405
- # Completeness
406
  if len(rms_env) >= 10:
407
  pk = np.argmax(rms_env); post = rms_env[pk:]
408
- tail_r = np.mean(post[-max(3, len(post)//5):]) / (rms_env[pk] + 1e-8)
409
- c1 = max(0, 1.0 - tail_r * 5)
410
- if len(post) >= 5:
411
- slope, _, r, _, _ = scipy.stats.linregress(np.arange(len(post)), np.log(post+1e-8))
412
- c2 = max(0, r**2) if slope < 0 else r**2 * 0.3
413
- else: c2 = 0.0
414
- else: c1, c2 = 0.5, 0.0
415
- completeness = c1 * 0.6 + c2 * 0.4
416
-
417
- # Cleanness
418
- snr = 10*np.log10(np.percentile(y**2, 99) / (np.percentile(y**2, 10) + 1e-12))
419
- n_snr = np.clip((snr - 10) / 40, 0, 1)
420
  onsets = librosa.onset.onset_detect(y=y, sr=sr, units='samples', backtrack=True)
421
- if len(onsets) > 0:
422
- os_s = int(onsets[0])
423
- pre = y[max(0, os_s-int(sr*.02)):os_s]
424
- sig = y[os_s:os_s+int(sr*.1)]
425
- n_pre = np.clip((-10*np.log10(np.mean(pre**2+1e-12)/np.mean(sig**2+1e-12)) - 5)/30, 0, 1) \
426
- if len(pre) > 10 and len(sig) > 10 else 0.5
427
- else: n_pre = 0.5
428
- cleanness = n_snr * 0.5 + n_pre * 0.5
429
-
430
- # Onset quality
431
  oe = librosa.onset.onset_strength(y=y, sr=sr)
432
- sharpness = float(np.max(oe)/(np.mean(oe)+1e-8)) if len(oe) > 1 else 1.0
433
- onset_q = float(np.clip((sharpness - 1.0) / 5.0, 0, 1))
434
-
435
- total = (completeness * 0.30 + cleanness * 0.40 + onset_q * 0.20 + 0.5 * 0.10) * 100
436
- return {'total': float(total), 'completeness': float(completeness),
437
- 'cleanness': float(cleanness), 'onset_quality': float(onset_q)}
438
 
439
- def select_best(clusters: list):
440
  print(f"[Stage 5] Selecting best representatives...")
441
  for c in clusters:
442
- if c.count <= 1: c.best_hit_idx = 0; continue
443
- scores = [sample_quality_score(h.audio, h.sr, c.label.rsplit('_',1)[0])['total']
444
- for h in c.hits]
445
  c.best_hit_idx = int(np.argmax(scores))
446
 
447
 
448
  # ─── Stage 6: Synthesis ──────────────────────────────────────────────────────
449
 
450
- def synthesize_from_cluster(cluster: Cluster) -> Optional[np.ndarray]:
451
- if cluster.count < 2: return None
452
  tl = int(np.median([len(h.audio) for h in cluster.hits]))
453
  aligned, weights = [], []
454
  pp_target = None
455
  for i, h in enumerate(cluster.hits):
456
- a = h.audio.copy()
457
- pp = np.argmax(np.abs(a))
458
  if pp_target is None: pp_target = pp
459
- shift = pp_target - pp
460
- if shift > 0: a = np.pad(a, (shift, 0))
461
- elif shift < 0: a = a[-shift:]
462
- a = a[:tl] if len(a) >= tl else np.pad(a, (0, tl - len(a)))
463
  pk = np.abs(a).max()
464
- if pk > 0: a = a / pk
465
- aligned.append(a)
466
- weights.append(2.0 if i == cluster.best_hit_idx else 1.0)
467
- aligned = np.array(aligned)
468
- w = np.array(weights); w /= w.sum()
469
- synth = np.average(aligned, axis=0, weights=w)
470
- pk = np.abs(synth).max()
471
- return (synth * 0.95 / pk).astype(np.float32) if pk > 0 else synth.astype(np.float32)
472
 
473
 
474
  # ─── Stage 7: MIDI + rendering ───────────────────────────────────────────────
@@ -476,88 +373,80 @@ def synthesize_from_cluster(cluster: Cluster) -> Optional[np.ndarray]:
476
  def build_midi(clusters, bpm=120.0):
477
  import pretty_midi
478
  pm = pretty_midi.PrettyMIDI(initial_tempo=bpm)
479
- for i, c in enumerate(clusters): c.midi_note = min(36 + i, 127)
480
  inst = pretty_midi.Instrument(program=0, is_drum=True, name='Extracted Samples')
481
  pm.instruments.append(inst)
482
  for c in clusters:
483
  for h in c.hits:
484
- vel = max(1, min(127, int(h.rms_energy / 0.3 * 127)))
485
- inst.notes.append(pretty_midi.Note(velocity=vel, pitch=c.midi_note,
486
- start=h.onset_time,
487
- end=h.onset_time + max(h.duration, 0.05)))
488
- inst.notes.sort(key=lambda n: n.start)
489
- return pm
490
 
491
  def export_midi(clusters, output_path, bpm=120.0):
492
- pm = build_midi(clusters, bpm)
493
- pm.write(output_path)
494
- print(f" ✓ MIDI: {output_path} ({len(pm.instruments[0].notes)} notes)")
495
- return pm
496
 
497
  def detect_bpm(y, sr):
498
- cache_key = ("bpm", _audio_hash(y), sr)
499
- cached = cache_get(cache_key)
500
- if cached is not None:
501
- print(f" Using cached BPM: {cached}")
502
- return cached
503
- onset_env = librosa.onset.onset_strength(y=y, sr=sr, aggregate=np.median)
504
- bpm = float(librosa.feature.tempo(onset_envelope=onset_env, sr=sr).item())
505
- _, beats = librosa.beat.beat_track(onset_envelope=onset_env, sr=sr, units='time')
506
- if len(beats) > 2:
507
- ibi_bpm = 60.0 / float(np.median(np.diff(beats)))
508
- for c in [bpm, ibi_bpm]:
509
- if 70 <= c <= 200: bpm = c; break
510
  else:
511
- if bpm < 70: bpm *= 2
512
- elif bpm > 200: bpm /= 2
513
- return cache_set(cache_key, round(bpm, 1))
514
 
515
  def render_midi_with_samples(clusters, sr=44100):
516
- max_end = max((h.onset_time + h.duration for c in clusters for h in c.hits), default=1.0)
517
- buf = np.zeros(int((max_end + 1.0) * sr), dtype=np.float64)
518
  for c in clusters:
519
- sample = c.best_hit.audio.astype(np.float64)
520
- ref_e = c.best_hit.rms_energy if c.best_hit.rms_energy > 0 else 0.1
521
  for h in c.hits:
522
- vs = min(2.0, h.rms_energy / (ref_e + 1e-8)) ** 0.5
523
- s = int(h.onset_time * sr); e = s + len(sample)
524
- if e > len(buf): buf = np.concatenate([buf, np.zeros(e - len(buf))])
525
- buf[s:e] += sample * vs
526
- pk = np.abs(buf).max()
527
- return (buf / pk * 0.9).astype(np.float32) if pk > 1e-8 else buf.astype(np.float32)
528
 
529
  def build_sample_map(clusters):
530
- return {c.midi_note: {'label': c.label, 'count': c.count,
531
- 'duration_ms': int(c.best_hit.duration * 1000)} for c in clusters}
532
 
533
  def build_archive(clusters, bpm, sr, midi_path=None, rendered_audio=None):
534
  import zipfile, tempfile, io
535
- zip_path = tempfile.mktemp(suffix='.zip')
536
- index = {'bpm': round(bpm, 1), 'sample_rate': sr,
537
- 'total_clusters': len(clusters),
538
- 'total_hits': sum(c.count for c in clusters), 'samples': {}}
539
- with zipfile.ZipFile(zip_path, 'w', compression=zipfile.ZIP_STORED) as zf:
540
  for c in clusters:
541
- best = c.best_hit; fname = f"samples/{c.label}.wav"
542
- buf = io.BytesIO(); sf.write(buf, best.audio, sr, format='WAV', subtype='PCM_24')
543
- zf.writestr(fname, buf.getvalue())
544
- onset_times = sorted([h.onset_time for h in c.hits])
545
- index['samples'][c.label] = {
546
- 'file': fname, 'classification': c.label.rsplit('_', 1)[0],
547
- 'midi_note': c.midi_note, 'occurrences': c.count,
548
- 'onset_times_sec': [round(t, 4) for t in onset_times],
549
- 'duration_sec': round(best.duration, 4),
550
- 'rms_energy': round(best.rms_energy, 6),
551
- 'spectral_centroid_hz': round(best.spectral_centroid, 1),
552
  }
553
  if c.synthesized is not None:
554
- sf2 = f"samples/{c.label}__synthesized.wav"; b2 = io.BytesIO()
555
- sf.write(b2, c.synthesized, sr, format='WAV', subtype='PCM_24')
556
- zf.writestr(sf2, b2.getvalue())
557
- index['samples'][c.label]['synthesized_file'] = sf2
558
- zf.writestr('index.json', json.dumps(index, indent=2))
559
- if midi_path and os.path.exists(midi_path): zf.write(midi_path, 'reconstruction.mid')
560
  if rendered_audio is not None:
561
- rb = io.BytesIO(); sf.write(rb, rendered_audio, sr, format='WAV', subtype='PCM_16')
562
- zf.writestr('rendered_reconstruction.wav', rb.getvalue())
563
  return zip_path
 
1
  #!/usr/bin/env python3
2
  """
3
+ Sample Extractor v6 — Tested on real hardstyle tracks.
4
+
5
+ Fixes from v5:
6
+ - NCC compare window auto-scales to median hit length (no more zero-pad inflation)
7
+ - Target range uses n_clusters directly when binary search hits a cliff
8
+ - Better defaults for real music (delta=0.12, energy=-35, min_gap=0.03)
9
+ - Caching for stem separation, BPM, NCC distance matrix
 
 
 
 
 
 
 
 
10
  """
11
 
12
+ import argparse, json, os, sys, warnings, hashlib
13
  from collections import defaultdict
14
  from dataclasses import dataclass, field
15
  from pathlib import Path
 
41
  def count(self) -> int: return len(self.hits)
42
 
43
 
44
+ DEMUCS_MODELS = ["htdemucs", "htdemucs_ft", "htdemucs_6s", "mdx", "mdx_extra", "mdx_extra_q"]
 
 
 
 
 
 
 
 
 
45
  DEMUCS_STEMS = {
46
+ "htdemucs": ["drums","bass","other","vocals"], "htdemucs_ft": ["drums","bass","other","vocals"],
47
+ "htdemucs_6s": ["drums","bass","other","vocals","guitar","piano"],
48
+ "mdx": ["drums","bass","other","vocals"], "mdx_extra": ["drums","bass","other","vocals"],
49
+ "mdx_extra_q": ["drums","bass","other","vocals"],
 
 
50
  }
51
 
52
 
53
+ # ─── Caching ──────────────────────────────────────────────────────────────────
54
+
55
+ _cache = {}
56
+
57
+ def _audio_hash(audio: np.ndarray) -> str:
58
+ return hashlib.md5(audio[:4000].tobytes()).hexdigest()
59
+
60
+ def cache_get(key): return _cache.get(key)
61
+ def cache_set(key, value): _cache[key] = value; return value
62
+ def cache_clear(): _cache.clear()
63
+
64
+
65
  # ─── Stage 1: Stem separation ────────────────────────────────────────────────
66
 
67
  def extract_stem(audio_path: str, stem: str = "drums", device: str = "cpu",
68
  model_name: str = "htdemucs_ft", shifts: int = 1,
69
  overlap: float = 0.25) -> tuple:
 
70
  if stem == "all":
71
  y, sr = librosa.load(audio_path, sr=44100, mono=True)
72
  return y.astype(np.float32), sr
73
 
 
74
  with open(audio_path, 'rb') as f:
75
+ file_hash = hashlib.md5(f.read(200000)).hexdigest()
76
+ ck = ("stem", file_hash, stem, model_name, shifts, overlap)
77
+ cached = cache_get(ck)
78
  if cached is not None:
79
+ print(f"[Stage 1] Using cached {stem} stem"); return cached
 
80
 
81
  from demucs.pretrained import get_model
82
  from demucs.apply import apply_model
83
+ print(f"[Stage 1] Extracting '{stem}' with {model_name}...")
84
+ model = get_model(model_name); model.eval().to(device)
 
 
85
  sr = model.samplerate
 
86
  if stem not in model.sources:
87
+ raise ValueError(f"'{stem}' not in {model.sources}")
 
88
  audio_np, _ = librosa.load(audio_path, sr=sr, mono=False)
89
+ if audio_np.ndim == 1: audio_np = np.stack([audio_np, audio_np])
90
+ elif audio_np.shape[0] > 2: audio_np = audio_np[:2]
91
+ elif audio_np.shape[0] == 1: audio_np = np.concatenate([audio_np, audio_np], axis=0)
 
 
 
 
92
  wav = torch.from_numpy(audio_np).float().unsqueeze(0).to(device)
93
  with torch.no_grad():
94
+ sources = apply_model(model, wav, device=device, shifts=shifts, split=True, overlap=overlap)
95
+ result = sources[0, model.sources.index(stem)].mean(dim=0).cpu().numpy()
 
 
 
96
  print(f" ✓ {stem}: {len(result)/sr:.1f}s")
97
+ return cache_set(ck, (result.astype(np.float32), sr))
 
98
 
99
 
100
  # ─── Stage 2: Onset detection ────────────────────────────────────────────────
101
 
102
  def detect_onsets(y: np.ndarray, sr: int, pre_pad: float = 0.005,
103
  min_dur: float = 0.02, max_dur: float = 1.5,
104
+ min_gap: float = 0.03, energy_threshold_db: float = -35.0,
105
  mode: str = "auto", backtrack: bool = True,
106
+ onset_delta: float = 0.12) -> list:
107
+ print(f"[Stage 2] Detecting onsets (mode={mode}, delta={onset_delta}, energy≥{energy_threshold_db}dB)...")
 
 
108
  if mode == "percussive":
109
  onset_env = librosa.onset.onset_strength(y=y, sr=sr, aggregate=np.median, fmax=8000)
110
  elif mode == "harmonic":
 
112
  onset_env = librosa.onset.onset_strength(y=y_harm, sr=sr, fmax=8000, lag=2, max_size=3)
113
  elif mode == "broadband":
114
  onset_env = librosa.onset.onset_strength(y=y, sr=sr)
115
+ else:
116
  y_harm, y_perc = librosa.effects.hpss(y)
117
  envs = [
118
  librosa.onset.onset_strength(y=y, sr=sr, fmin=20, fmax=250, aggregate=np.median),
119
  librosa.onset.onset_strength(y=y, sr=sr, fmin=250, fmax=4000, aggregate=np.median),
120
+ librosa.onset.onset_strength(y=y, sr=sr, fmin=4000, fmax=min(sr//2,20000), aggregate=np.median),
121
  librosa.onset.onset_strength(y=y_harm, sr=sr, lag=2),
122
  ]
123
+ def _n(x): m=x.max(); return x/m if m>0 else x
 
124
  onset_env = np.maximum.reduce([_n(e) for e in envs])
125
 
126
  wait = max(1, int(min_gap * sr / 512))
127
+ frames = librosa.onset.onset_detect(onset_envelope=onset_env, sr=sr, wait=wait,
 
128
  pre_avg=3, post_avg=3, pre_max=3, post_max=5,
129
  delta=onset_delta, backtrack=backtrack, units='frames')
130
  times = librosa.frames_to_time(frames, sr=sr)
 
134
  hits = []
135
  for i, t in enumerate(times):
136
  s = max(0, int((t - pre_pad) * sr))
137
+ e = min(int(times[i+1]*sr) if i+1<len(times) else len(y), s+int(max_dur*sr))
 
 
 
138
  seg = y[s:e]
139
+ if len(seg) < int(min_dur*sr): continue
140
  rms = np.sqrt(np.mean(seg**2))
141
  if rms < threshold: continue
142
+ fl = min(int(0.005*sr), len(seg)//4)
143
+ if fl > 0: seg = seg.copy(); seg[-fl:] *= np.linspace(1, 0, fl)
 
144
  sc = float(librosa.feature.spectral_centroid(y=seg, sr=sr).mean())
145
  hits.append(Hit(audio=seg, sr=sr, onset_time=t, duration=len(seg)/sr,
146
  index=len(hits), rms_energy=float(rms), spectral_centroid=sc))
 
151
  # ─── Stage 3: Classification ─────────────────────────────────────────────────
152
 
153
  LABEL_RULES = [
154
+ ("kick", lambda lr,mr,hr,c,zcr,d: lr>0.5 and c<800),
155
+ ("hihat_closed", lambda lr,mr,hr,c,zcr,d: hr>0.35 and c>4000 and d<0.15),
156
+ ("hihat_open", lambda lr,mr,hr,c,zcr,d: hr>0.35 and c>4000 and d>=0.15),
157
+ ("cymbal", lambda lr,mr,hr,c,zcr,d: hr>0.25 and c>3000),
158
+ ("snare", lambda lr,mr,hr,c,zcr,d: mr>0.4 and zcr>0.1 and c>1000),
159
+ ("tom", lambda lr,mr,hr,c,zcr,d: lr>0.3 and mr>0.3 and c<1500),
160
+ ("bass", lambda lr,mr,hr,c,zcr,d: lr>0.6 and c<400 and d>0.2),
161
+ ("vocal", lambda lr,mr,hr,c,zcr,d: mr>0.5 and 500<c<3000 and zcr<0.15),
162
+ ("bright", lambda lr,mr,hr,c,zcr,d: c>2500),
163
+ ("mid", lambda lr,mr,hr,c,zcr,d: c>800),
164
  ]
165
 
166
  def classify_hit(hit: Hit) -> str:
167
  y, sr = hit.audio, hit.sr
168
  D = np.abs(librosa.stft(y, n_fft=2048))
169
  freqs = librosa.fft_frequencies(sr=sr, n_fft=2048)
170
+ le = np.sum(D[(freqs>=20)&(freqs<200)]**2)
171
+ me = np.sum(D[(freqs>=200)&(freqs<4000)]**2)
172
+ he = np.sum(D[(freqs>=4000)]**2)
173
+ total = le+me+he+1e-10; lr,mr,hr = le/total,me/total,he/total
 
174
  zcr = float(librosa.feature.zero_crossing_rate(y=y).mean())
175
  for name, fn in LABEL_RULES:
176
+ if fn(lr,mr,hr,hit.spectral_centroid,zcr,hit.duration): return name
 
177
  return "other"
178
 
179
  def classify_hits(hits: list) -> list:
 
180
  print(f"[Stage 3] Classifying {len(hits)} hits...")
181
+ for h in hits: h.label = classify_hit(h)
 
182
  counts = defaultdict(int)
183
  for h in hits: counts[h.label] += 1
184
+ for l, c in sorted(counts.items(), key=lambda x: -x[1]): print(f" {l}: {c}")
 
185
  return hits
186
 
187
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
188
  # ─── Stage 4: NCC-based clustering ───────────────────────────────────────────
189
 
190
  def ncc_max(a: np.ndarray, b: np.ndarray) -> float:
191
+ """NCC peak. Amplitude-invariant. Compares the shorter length only."""
192
+ # Use the shorter clip's length — no zero-padding inflation
193
+ n = min(len(a), len(b))
194
+ a, b = a[:n].copy(), b[:n].copy()
195
+ a -= a.mean(); b -= b.mean()
196
+ norm = np.sqrt(np.dot(a,a) * np.dot(b,b))
197
  if norm < 1e-10: return 0.0
198
+ cc = fftconvolve(a, b[::-1], mode='full')
 
 
 
199
  return float(np.max(np.abs(cc))) / norm
200
 
201
 
202
  def build_ncc_distance_matrix(hits: list, max_compare_samples: int = 8820) -> np.ndarray:
203
+ """Cached NCC distance matrix. Auto-scales compare window to hit lengths."""
 
 
204
  key = ("ncc_dist", tuple(_audio_hash(h.audio) for h in hits), max_compare_samples)
205
  cached = cache_get(key)
206
  if cached is not None:
207
+ print(f" Using cached NCC distance matrix"); return cached
 
 
208
  N = len(hits)
209
  D = np.zeros((N, N), dtype=np.float32)
210
  for i in range(N):
211
  ai = hits[i].audio[:max_compare_samples]
212
  for j in range(i+1, N):
213
  bj = hits[j].audio[:max_compare_samples]
214
+ D[i,j] = D[j,i] = max(0.0, 1.0 - ncc_max(ai, bj))
 
 
215
  return cache_set(key, D)
216
 
217
 
def _labels_to_clusters(labels, hits):
    """Turn a per-hit cluster-label sequence into a list of ``Cluster`` objects.

    Hits sharing a label are grouped. Each group is named
    ``"<majority>_<k>"`` where ``<majority>`` is the most common hit label in
    the group (first-seen wins ties) and ``<k>`` counts earlier groups, in
    ascending cluster-label order, that got the same majority name. The
    result is ordered by hit count, largest first, and ``cluster_id`` is
    renumbered to match that order.

    Args:
        labels: One cluster label per hit, aligned with ``hits``.
        hits: The hit objects that were clustered.

    Returns:
        List of ``Cluster`` objects sorted by descending size.
    """
    members = defaultdict(list)
    for hit_index, group in enumerate(labels):
        members[group].append(hit_index)

    clusters = []
    name_uses = {}  # majority label -> how many clusters already carry it
    for group in sorted(members):
        hit_indices = members[group]
        votes = {}
        for hit_index in hit_indices:
            voted = hits[hit_index].label
            votes[voted] = votes.get(voted, 0) + 1
        winner = max(votes, key=votes.get)
        suffix = name_uses.get(winner, 0)
        name_uses[winner] = suffix + 1
        clusters.append(Cluster(cluster_id=len(clusters), label=f"{winner}_{suffix}",
                                hits=[hits[i] for i in hit_indices]))

    clusters.sort(key=lambda cl: cl.count, reverse=True)
    for position, cl in enumerate(clusters):
        cl.cluster_id = position
    return clusters
232
 
233
 
def _agglomerate(D, linkage, distance_threshold=None, n_clusters=None):
    """Run agglomerative clustering on a precomputed distance matrix; return labels."""
    from sklearn.cluster import AgglomerativeClustering

    model = AgglomerativeClustering(n_clusters=n_clusters,
                                    distance_threshold=distance_threshold,
                                    metric='precomputed', linkage=linkage)
    return model.fit_predict(D)


def _search_threshold_for_target(D, target_min, target_max, linkage):
    """Bisect the distance threshold until the cluster count lands in range.

    Returns ``(labels, n_clusters, threshold)`` for the first threshold whose
    cluster count is inside ``[target_min, target_max]``; if none is found in
    30 steps, returns the attempt whose count was closest to the middle of
    the range.
    """
    centre = (target_min + target_max) / 2
    lo, hi = 0.001, 1.0
    best_labels, best_n, best_dist = None, -1, 0.5
    for _ in range(30):
        mid = (lo + hi) / 2
        labels = _agglomerate(D, linkage, distance_threshold=max(0.001, mid))
        n = len(set(labels))
        if target_min <= n <= target_max:
            return labels, n, mid
        if n > target_max:
            lo = mid  # too many clusters: loosen the threshold so more merge
        else:
            hi = mid  # too few clusters: tighten the threshold
        if best_labels is None or abs(n - centre) < abs(best_n - centre):
            best_labels, best_n, best_dist = labels, n, mid
    return best_labels, best_n, best_dist


def cluster_hits(hits: list, ncc_threshold: float = 0.80,
                 max_compare_ms: float = 0,
                 target_min: int = 0, target_max: int = 0,
                 linkage: str = 'average') -> list:
    """Group hits by waveform similarity using NCC distances.

    Args:
        hits: Hit objects exposing ``audio``, ``sr`` and ``label``.
        ncc_threshold: Minimum NCC for merging when no target range is given
            (used as distance threshold ``1 - ncc_threshold``).
        max_compare_ms: Length of the compared window in milliseconds.
            ``0`` (or negative) means auto: the median hit length, but at
            least 30 ms. A positive value is floored at one sample.
        target_min: Lower bound of the desired cluster count.
        target_max: Upper bound of the desired cluster count. The target
            range is only used when both bounds are > 0 and
            ``target_max >= target_min``.
        linkage: Linkage criterion handed to scikit-learn's
            ``AgglomerativeClustering``.

    Returns:
        List of ``Cluster`` objects as produced by ``_labels_to_clusters``;
        an empty list for empty input, a single one-hit cluster for one hit.
    """
    if not hits:
        return []
    N = len(hits)
    sr = hits[0].sr  # all hits are assumed to share the first hit's rate
    if N == 1:
        return [Cluster(cluster_id=0, label=f"{hits[0].label}_0", hits=[hits[0]])]

    # Auto-scale the compare window to the median hit length.
    if max_compare_ms <= 0:
        median_len = int(np.median([len(h.audio) for h in hits]))
        max_samples = max(int(0.03 * sr), median_len)  # at least 30ms
    else:
        # Floor at one sample so a tiny window never compares empty clips.
        max_samples = max(1, int(max_compare_ms / 1000.0 * sr))

    print(f"[Stage 4] NCC clustering ({N} hits, compare={max_samples/sr*1000:.0f}ms, linkage={linkage})...")
    print(f" Computing {N*(N-1)//2} pairwise NCC distances...")
    D = build_ncc_distance_matrix(hits, max_compare_samples=max_samples)

    use_target = target_min > 0 and target_max > 0 and target_max >= target_min
    target_min = max(1, min(target_min, N))
    target_max = max(target_min, min(target_max, N))

    if use_target:
        print(f" Target: {target_min}–{target_max} clusters")

        # Strategy 1: binary search on the distance threshold.
        best_labels, best_n, best_dist = _search_threshold_for_target(
            D, target_min, target_max, linkage)
        used_fallback = False

        # Strategy 2: if the search missed the range, request a count directly.
        if best_n < target_min or best_n > target_max:
            target_mid = min((target_min + target_max) // 2, N - 1)
            print(f" Binary search got {best_n}, falling back to n_clusters={target_mid}")
            try:
                fallback_labels = _agglomerate(D, linkage, n_clusters=target_mid)
            except Exception as e:
                # Best effort: keep the closest binary-search result.
                print(f" n_clusters fallback failed: {e}")
            else:
                best_labels, best_n = fallback_labels, target_mid
                used_fallback = True

        labels = best_labels
        if used_fallback:
            # No distance threshold produced this result, so don't report one.
            print(f" → {best_n} clusters (n_clusters fallback)")
        else:
            print(f" → {best_n} clusters (dist_threshold={best_dist:.4f})")
    else:
        dist_threshold = max(0.001, 1.0 - ncc_threshold)
        print(f" Fixed: NCC≥{ncc_threshold} (dist≤{dist_threshold:.3f})")
        labels = _agglomerate(D, linkage, distance_threshold=dist_threshold)

    print(f" ✓ {len(set(labels))} clusters")
    clusters = _labels_to_clusters(labels, hits)
    for c in clusters:
        print(f" {c.label}: {c.count} hits")
    return clusters
308
 
309
 
310
+ # ─── Stage 5: Quality scoring ────────────────────────────────────────────────
311
 
def sample_quality_score(y, sr, label="other"):
    """Score how usable an extracted one-shot sample is, on a 0-100 scale.

    Three sub-scores are computed, each in [0, 1]:
      * completeness - how far the RMS envelope has fallen by the end of the
        clip and how well the post-peak envelope fits a log-linear decay;
      * cleanness - a percentile-based SNR estimate combined with how quiet
        the 20 ms before the first detected onset are relative to the
        100 ms after it;
      * onset_quality - peak-to-mean ratio of the onset-strength envelope.

    They are weighted 30% / 40% / 20%; a constant 0.5 fills the remaining 10%.

    Args:
        y: Audio samples.
        sr: Sample rate in Hz.
        label: Accepted for call compatibility; not used by the scoring.

    Returns:
        Dict with float entries ``total``, ``completeness``, ``cleanness``
        and ``onset_quality``.
    """
    import scipy.stats

    # Completeness: tail level and decay shape of the RMS envelope.
    envelope = librosa.feature.rms(y=y, frame_length=512, hop_length=128)[0]
    if len(envelope) < 10:
        # Too few frames to judge the decay.
        tail_score, decay_score = 0.5, 0.0
    else:
        peak_frame = np.argmax(envelope)
        after_peak = envelope[peak_frame:]
        tail_frames = max(3, len(after_peak) // 5)
        tail_ratio = np.mean(after_peak[-tail_frames:]) / (envelope[peak_frame] + 1e-8)
        tail_score = max(0, 1.0 - tail_ratio * 5)
        if len(after_peak) < 5:
            decay_score = 0.0
        else:
            fit = scipy.stats.linregress(np.arange(len(after_peak)), np.log(after_peak + 1e-8))
            r_squared = fit.rvalue ** 2
            if fit.slope < 0:
                decay_score = max(0, r_squared)
            else:
                # A rising envelope is penalised.
                decay_score = r_squared * 0.3
    completeness = tail_score * 0.6 + decay_score * 0.4

    # Cleanness: loud-vs-quiet power ratio plus pre-onset silence.
    power = y ** 2
    snr = 10 * np.log10(np.percentile(power, 99) / (np.percentile(power, 10) + 1e-12))
    snr_score = np.clip((snr - 10) / 40, 0, 1)

    onsets = librosa.onset.onset_detect(y=y, sr=sr, units='samples', backtrack=True)
    pre_score = 0.5
    if len(onsets) > 0:
        first = int(onsets[0])
        before = y[max(0, first - int(sr * .02)):first]
        after = y[first:first + int(sr * .1)]
        if len(before) > 10 and len(after) > 10:
            ratio = np.mean(before ** 2 + 1e-12) / np.mean(after ** 2 + 1e-12)
            pre_score = np.clip((-10 * np.log10(ratio) - 5) / 30, 0, 1)
    cleanness = snr_score * 0.5 + pre_score * 0.5

    # Onset quality: how much the strongest onset stands out.
    strength = librosa.onset.onset_strength(y=y, sr=sr)
    if len(strength) > 1:
        sharpness = float(np.max(strength) / (np.mean(strength) + 1e-8))
    else:
        sharpness = 1.0
    onset_q = float(np.clip((sharpness - 1.0) / 5.0, 0, 1))

    total = (completeness * 0.30 + cleanness * 0.40 + onset_q * 0.20 + 0.5 * 0.10) * 100
    return {
        'total': float(total),
        'completeness': float(completeness),
        'cleanness': float(cleanness),
        'onset_quality': float(onset_q),
    }
 
340
 
def select_best(clusters):
    """Pick, for every cluster, the hit with the highest quality score.

    Sets ``best_hit_idx`` on each cluster in place. Clusters with at most one
    hit get index 0 without scoring; otherwise every hit is scored with
    ``sample_quality_score`` and the index of the highest ``total`` wins.

    Args:
        clusters: Cluster objects to update.
    """
    print(f"[Stage 5] Selecting best representatives...")
    for cluster in clusters:
        if cluster.count > 1:
            # The cluster label is "<class>_<n>"; pass only the class part.
            kind = cluster.label.rsplit('_', 1)[0]
            totals = [sample_quality_score(hit.audio, hit.sr, kind)['total']
                      for hit in cluster.hits]
            cluster.best_hit_idx = int(np.argmax(totals))
        else:
            cluster.best_hit_idx = 0
347
 
348
 
349
  # ─── Stage 6: Synthesis ──────────────────────────────────────────────────────
350
 
def synthesize_from_cluster(cluster):
    """Blend a cluster's hits into one peak-aligned, weighted-average sample.

    Every non-empty hit is shifted so its absolute peak lines up with the
    peak position of the first usable hit, trimmed or zero-padded to the
    median hit length, and peak-normalised. The hits are then averaged with
    the cluster's best hit counted twice, and the result is scaled to a peak
    of 0.95.

    Args:
        cluster: Object exposing ``count``, ``hits`` (each with an ``audio``
            array) and ``best_hit_idx``.

    Returns:
        A float32 array, or ``None`` when the cluster has fewer than two hits
        or fewer than two hits with any audio in them.
    """
    if cluster.count < 2:
        return None

    # Empty clips have no peak to align on; skip them but keep the original
    # indices so the best-hit weighting still refers to the right hit.
    usable = [(i, h.audio) for i, h in enumerate(cluster.hits) if len(h.audio) > 0]
    if len(usable) < 2:
        return None

    target_len = int(np.median([len(audio) for _, audio in usable]))
    ref_peak = None
    aligned, weights = [], []
    for i, audio in usable:
        clip = np.asarray(audio)
        peak_pos = np.argmax(np.abs(clip))
        if ref_peak is None:
            ref_peak = peak_pos
        shift = ref_peak - peak_pos
        if shift > 0:
            clip = np.pad(clip, (shift, 0))  # delay: prepend silence
        elif shift < 0:
            clip = clip[-shift:]  # advance: drop leading samples
        if len(clip) >= target_len:
            clip = clip[:target_len]
        else:
            clip = np.pad(clip, (0, target_len - len(clip)))
        level = np.abs(clip).max()
        if level > 0:
            clip = clip / level
        aligned.append(clip)
        weights.append(2.0 if i == cluster.best_hit_idx else 1.0)

    w = np.array(weights)
    w /= w.sum()
    synth = np.average(np.array(aligned), axis=0, weights=w)
    level = np.abs(synth).max()
    if level > 0:
        return (synth * 0.95 / level).astype(np.float32)
    return synth.astype(np.float32)
 
 
 
369
 
370
 
371
  # ─── Stage 7: MIDI + rendering ───────────────────────────────────────────────
 
def build_midi(clusters, bpm=120.0):
    """Build a single-track drum MIDI object that replays every hit.

    Clusters are assigned consecutive MIDI notes starting at 36 (capped at
    127) and the note is stored on each cluster as ``midi_note``. Every hit
    becomes one note whose velocity is its RMS energy scaled so that 0.3 maps
    to 127 (clamped to 1-127) and whose length is the hit duration, but at
    least 50 ms.

    Args:
        clusters: Cluster objects; each hit needs ``rms_energy``,
            ``onset_time`` and ``duration``.
        bpm: Initial tempo written to the MIDI object.

    Returns:
        The ``pretty_midi.PrettyMIDI`` object, notes sorted by start time.
    """
    import pretty_midi

    midi = pretty_midi.PrettyMIDI(initial_tempo=bpm)
    track = pretty_midi.Instrument(program=0, is_drum=True, name='Extracted Samples')
    midi.instruments.append(track)

    for offset, cluster in enumerate(clusters):
        cluster.midi_note = min(36 + offset, 127)
        for hit in cluster.hits:
            velocity = int(hit.rms_energy / 0.3 * 127)
            velocity = max(1, min(127, velocity))
            note_end = hit.onset_time + max(hit.duration, 0.05)
            track.notes.append(pretty_midi.Note(velocity=velocity, pitch=cluster.midi_note,
                                                start=hit.onset_time, end=note_end))

    track.notes.sort(key=lambda note: note.start)
    return midi
 
 
385
 
def export_midi(clusters, output_path, bpm=120.0):
    """Build the reconstruction MIDI, write it to disk and report the result.

    Args:
        clusters: Cluster objects passed through to ``build_midi``.
        output_path: Destination path of the MIDI file.
        bpm: Initial tempo passed through to ``build_midi``.

    Returns:
        The MIDI object returned by ``build_midi``.
    """
    midi = build_midi(clusters, bpm)
    midi.write(output_path)
    note_total = len(midi.instruments[0].notes)
    print(f" ✓ MIDI: {output_path} ({note_total} notes)")
    return midi
 
 
389
 
def detect_bpm(y, sr):
    """Estimate the tempo of ``y`` in BPM, preferring values in 70-200.

    Two candidates are considered, in order: librosa's global tempo estimate
    and, when more than two beats were tracked, the tempo implied by the
    median inter-beat interval. The first candidate inside 70-200 BPM wins.
    If neither is in range, the global estimate is doubled (below 70) or
    halved (above 200) once. This correction is applied regardless of how
    many beats were tracked, so no path returns an uncorrected out-of-range
    tempo when a single octave shift would help.

    Args:
        y: Audio samples.
        sr: Sample rate in Hz.

    Returns:
        The cached value when one exists for this audio and rate; otherwise
        whatever ``cache_set`` returns for the estimate rounded to one
        decimal.
    """
    cache_key = ("bpm", _audio_hash(y), sr)
    cached = cache_get(cache_key)
    if cached is not None:
        print(f" Cached BPM: {cached}")
        return cached

    onset_env = librosa.onset.onset_strength(y=y, sr=sr, aggregate=np.median)
    bpm = float(librosa.feature.tempo(onset_envelope=onset_env, sr=sr).item())
    _, beats = librosa.beat.beat_track(onset_envelope=onset_env, sr=sr, units='time')

    candidates = [bpm]
    if len(beats) > 2:
        interval = float(np.median(np.diff(beats)))
        if interval > 0:  # guard the division below
            candidates.append(60.0 / interval)

    in_range = next((c for c in candidates if 70 <= c <= 200), None)
    if in_range is not None:
        bpm = in_range
    elif bpm < 70:
        bpm *= 2  # likely a half-tempo estimate
    elif bpm > 200:
        bpm /= 2  # likely a double-tempo estimate
    return cache_set(cache_key, round(bpm, 1))
404
 
def render_midi_with_samples(clusters, sr=44100):
    """Render the timeline by placing each cluster's best sample at every hit.

    For every hit the cluster's best-hit audio is mixed in at the hit's onset
    with a gain of ``sqrt(min(2, hit_rms / best_rms))``. The mix is finally
    normalised to a peak of 0.9 unless it is (near-)silent.

    Args:
        clusters: Cluster objects exposing ``hits`` and ``best_hit``; hits
            need ``onset_time``, ``duration``, ``rms_energy`` and ``audio``.
        sr: Sample rate used to convert onset times to sample positions.

    Returns:
        The rendered mix as a float32 array. It spans at least the latest
        hit end plus one second, and two seconds when there are no hits.
    """
    max_end = max((h.onset_time + h.duration for c in clusters for h in c.hits), default=1.0)
    buf = np.zeros(max(0, int((max_end + 1.0) * sr)), dtype=np.float64)

    for c in clusters:
        best = c.best_hit
        sample = best.audio.astype(np.float64)
        ref_e = best.rms_energy if best.rms_energy > 0 else 0.1
        for h in c.hits:
            # Clamp at 0 so a negative ratio cannot produce a complex gain.
            ratio = max(0.0, h.rms_energy / (ref_e + 1e-8))
            gain = min(2.0, ratio) ** 0.5
            start = int(h.onset_time * sr)
            segment = sample
            if start < 0:
                # Onset before time zero: drop the part that falls off the
                # front instead of indexing the buffer from its end.
                segment = sample[-start:]
                start = 0
            end = start + len(segment)
            if end > len(buf):
                buf = np.concatenate([buf, np.zeros(end - len(buf))])
            buf[start:end] += segment * gain

    peak = np.abs(buf).max() if len(buf) else 0.0
    if peak > 1e-8:
        return (buf / peak * 0.9).astype(np.float32)
    return buf.astype(np.float32)
418
 
def build_sample_map(clusters):
    """Map each cluster's MIDI note to a short description of its sample.

    Args:
        clusters: Cluster objects exposing ``midi_note``, ``label``,
            ``count`` and ``best_hit`` (with a ``duration`` in seconds).

    Returns:
        Dict keyed by MIDI note; each value holds ``label``, ``count`` and
        ``duration_ms`` (best-hit duration in whole milliseconds, truncated).
        Clusters sharing a note overwrite earlier entries.
    """
    sample_map = {}
    for cluster in clusters:
        representative = cluster.best_hit
        sample_map[cluster.midi_note] = {
            'label': cluster.label,
            'count': cluster.count,
            'duration_ms': int(representative.duration * 1000),
        }
    return sample_map
422
 
def build_archive(clusters, bpm, sr, midi_path=None, rendered_audio=None):
    """Write all extracted samples plus an index into a temporary zip file.

    The archive contains:
      * ``samples/<label>.wav`` - each cluster's best hit (24-bit PCM);
      * ``samples/<label>__synthesized.wav`` - when the cluster has a
        synthesized sample;
      * ``index.json`` - tempo, sample rate, totals and per-sample metadata;
      * ``reconstruction.mid`` - when ``midi_path`` points to an existing file;
      * ``rendered_reconstruction.wav`` - when ``rendered_audio`` is given
        (16-bit PCM).

    Entries are stored uncompressed.

    Args:
        clusters: Cluster objects to export.
        bpm: Tempo recorded in the index (rounded to one decimal).
        sr: Sample rate used for every WAV file and recorded in the index.
        midi_path: Optional path of a MIDI file to include.
        rendered_audio: Optional audio array of the rendered reconstruction.

    Returns:
        Path of the created zip file. The caller owns the file and is
        responsible for deleting it.
    """
    import io
    import tempfile
    import zipfile

    def _wav_bytes(audio, subtype):
        """Encode ``audio`` as an in-memory WAV file and return its bytes."""
        stream = io.BytesIO()
        sf.write(stream, audio, sr, format='WAV', subtype=subtype)
        return stream.getvalue()

    # mkstemp creates the file atomically; mktemp only returns a name that
    # another process could claim before it is opened.
    fd, zip_path = tempfile.mkstemp(suffix='.zip')
    os.close(fd)

    index = {
        'bpm': round(bpm, 1),
        'sample_rate': sr,
        'total_clusters': len(clusters),
        'total_hits': sum(c.count for c in clusters),
        'samples': {},
    }
    try:
        with zipfile.ZipFile(zip_path, 'w', compression=zipfile.ZIP_STORED) as zf:
            for c in clusters:
                best = c.best_hit
                fname = f"samples/{c.label}.wav"
                zf.writestr(fname, _wav_bytes(best.audio, 'PCM_24'))
                entry = {
                    'file': fname,
                    'classification': c.label.rsplit('_', 1)[0],
                    'midi_note': c.midi_note,
                    'occurrences': c.count,
                    'onset_times_sec': [round(t, 4) for t in sorted(h.onset_time for h in c.hits)],
                    'duration_sec': round(best.duration, 4),
                    'rms_energy': round(best.rms_energy, 6),
                    'spectral_centroid_hz': round(best.spectral_centroid, 1),
                }
                if c.synthesized is not None:
                    synth_name = f"samples/{c.label}__synthesized.wav"
                    zf.writestr(synth_name, _wav_bytes(c.synthesized, 'PCM_24'))
                    entry['synthesized_file'] = synth_name
                index['samples'][c.label] = entry

            zf.writestr('index.json', json.dumps(index, indent=2))
            if midi_path and os.path.exists(midi_path):
                zf.write(midi_path, 'reconstruction.mid')
            if rendered_audio is not None:
                zf.writestr('rendered_reconstruction.wav', _wav_bytes(rendered_audio, 'PCM_16'))
    except Exception:
        # Don't leave a half-written archive behind.
        if os.path.exists(zip_path):
            os.remove(zip_path)
        raise
    return zip_path