rikhoffbauer2 commited on
Commit
fcf261a
Β·
verified Β·
1 Parent(s): 2a334ed

v2: Update sample_extractor.py

Browse files
Files changed (1) hide show
  1. sample_extractor.py +609 -0
sample_extractor.py ADDED
@@ -0,0 +1,609 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Sample Extractor β€” Generalized audio sample extraction pipeline.
4
+
5
+ Extracts any distinct sound (drum hits, vocal stabs, guitar plucks, SFX, etc.)
6
+ from audio, clusters identical occurrences, picks the best representative,
7
+ and reconstructs the song as MIDI.
8
+
9
+ Stages:
10
+ 1. STEM SEPARATION β€” HTDemucs isolates target stem (optional)
11
+ 2. ONSET DETECTION β€” Adaptive multi-method detection for any sound type
12
+ 3. SPECTRAL CLASSIFICATION β€” Label sounds by frequency profile
13
+ 4. OVERLAP SEPARATION β€” Decompose simultaneous sounds via spectral bands
14
+ 5. EMBEDDING & CLUSTERING β€” Group identical sounds, auto-K
15
+ 6. QUALITY SCORING β€” Completeness + cleanness + onset sharpness
16
+ 7. SYNTHESIS β€” Peak-aligned weighted average of cluster members
17
+ 8. MIDI RECONSTRUCTION β€” Map clusters back to timeline as .mid
18
+ """
19
+
20
+ import argparse, json, os, sys, warnings
21
+ from collections import defaultdict
22
+ from dataclasses import dataclass, field
23
+ from pathlib import Path
24
+ from typing import Optional
25
+ import librosa, numpy as np, soundfile as sf, torch
26
+
27
+ warnings.filterwarnings("ignore", category=FutureWarning)
28
+ warnings.filterwarnings("ignore", category=UserWarning)
29
+
30
+ # ─── Data structures ─────────────────────────────────────────────────────────
31
+
32
+ @dataclass
33
+ class Hit:
34
+ """A single detected audio event."""
35
+ audio: np.ndarray
36
+ sr: int
37
+ onset_time: float
38
+ duration: float
39
+ index: int
40
+ rms_energy: float = 0.0
41
+ spectral_centroid: float = 0.0
42
+ label: str = ""
43
+ embedding: Optional[np.ndarray] = None
44
+ cluster_id: int = -1
45
+
46
+ def save(self, path: str):
47
+ sf.write(path, self.audio, self.sr, subtype='PCM_24')
48
+
49
+
50
+ @dataclass
51
+ class Cluster:
52
+ """A group of similar sounds."""
53
+ cluster_id: int
54
+ label: str
55
+ hits: list = field(default_factory=list)
56
+ best_hit_idx: int = 0
57
+ synthesized: Optional[np.ndarray] = None
58
+ midi_note: int = 60 # assigned during MIDI export
59
+
60
+ @property
61
+ def best_hit(self) -> Hit:
62
+ return self.hits[self.best_hit_idx]
63
+
64
+ @property
65
+ def count(self) -> int:
66
+ return len(self.hits)
67
+
68
+
69
+ # ─── Stage 1: Stem separation ────────────────────────────────────────────────
70
+
71
+ def extract_stem(audio_path: str, stem: str = "drums", device: str = "cpu") -> tuple:
72
+ """Extract a stem using HTDemucs. stem: drums|bass|vocals|other|all"""
73
+ if stem == "all":
74
+ y, sr = librosa.load(audio_path, sr=44100, mono=True)
75
+ return y.astype(np.float32), sr
76
+
77
+ from demucs.pretrained import get_model
78
+ from demucs.apply import apply_model
79
+
80
+ print(f"[Stage 1] Extracting '{stem}' stem with HTDemucs...")
81
+ for name in ["htdemucs_ft", "htdemucs"]:
82
+ try:
83
+ model = get_model(name)
84
+ break
85
+ except Exception:
86
+ continue
87
+ else:
88
+ raise RuntimeError("Could not load Demucs model")
89
+
90
+ model.eval().to(device)
91
+ sr = model.samplerate
92
+
93
+ audio_np, _ = librosa.load(audio_path, sr=sr, mono=False)
94
+ if audio_np.ndim == 1:
95
+ audio_np = np.stack([audio_np, audio_np])
96
+ elif audio_np.shape[0] > 2:
97
+ audio_np = audio_np[:2]
98
+ elif audio_np.shape[0] == 1:
99
+ audio_np = np.concatenate([audio_np, audio_np], axis=0)
100
+
101
+ wav = torch.from_numpy(audio_np).float().unsqueeze(0).to(device)
102
+ with torch.no_grad():
103
+ sources = apply_model(model, wav, device=device, shifts=1, split=True, overlap=0.25)
104
+
105
+ idx = model.sources.index(stem)
106
+ result = sources[0, idx].mean(dim=0).cpu().numpy()
107
+ print(f" βœ“ Extracted {stem}: {len(result)/sr:.1f}s")
108
+ return result.astype(np.float32), sr
109
+
110
+
111
+ # ─── Stage 2: Onset detection (generalized) ──────────────────────────────────
112
+
113
+ def detect_onsets(y: np.ndarray, sr: int, pre_pad: float = 0.005,
114
+ min_dur: float = 0.02, max_dur: float = 1.5,
115
+ min_gap: float = 0.015, energy_threshold_db: float = -45.0,
116
+ mode: str = "auto") -> list:
117
+ """
118
+ Detect audio event onsets. mode: auto|percussive|harmonic|broadband
119
+ 'auto' uses HPSS dual-channel detection (best general-purpose).
120
+ """
121
+ print(f"[Stage 2] Detecting onsets (mode={mode})...")
122
+
123
+ if mode == "percussive":
124
+ onset_env = librosa.onset.onset_strength(y=y, sr=sr, aggregate=np.median, fmax=8000)
125
+ elif mode == "harmonic":
126
+ y_harm, _ = librosa.effects.hpss(y)
127
+ onset_env = librosa.onset.onset_strength(y=y_harm, sr=sr, fmax=8000, lag=2, max_size=3)
128
+ elif mode == "broadband":
129
+ onset_env = librosa.onset.onset_strength(y=y, sr=sr)
130
+ else: # auto: multi-band max
131
+ y_harm, y_perc = librosa.effects.hpss(y)
132
+ env_low = librosa.onset.onset_strength(y=y, sr=sr, fmin=20, fmax=250, aggregate=np.median)
133
+ env_mid = librosa.onset.onset_strength(y=y, sr=sr, fmin=250, fmax=4000, aggregate=np.median)
134
+ env_high = librosa.onset.onset_strength(y=y, sr=sr, fmin=4000, fmax=min(sr//2, 20000), aggregate=np.median)
135
+ env_harm = librosa.onset.onset_strength(y=y_harm, sr=sr, lag=2)
136
+ def _n(x):
137
+ m = x.max(); return x/m if m > 0 else x
138
+ onset_env = np.maximum(np.maximum(_n(env_low), _n(env_mid)),
139
+ np.maximum(_n(env_high), _n(env_harm)))
140
+
141
+ wait = max(1, int(min_gap * sr / 512))
142
+ frames = librosa.onset.onset_detect(onset_envelope=onset_env, sr=sr, wait=wait,
143
+ pre_avg=3, post_avg=3, pre_max=3, post_max=5,
144
+ backtrack=True, units='frames')
145
+ times = librosa.frames_to_time(frames, sr=sr)
146
+ print(f" Raw onsets: {len(times)}")
147
+
148
+ threshold = 10 ** (energy_threshold_db / 20)
149
+ hits = []
150
+ for i, t in enumerate(times):
151
+ s = max(0, int((t - pre_pad) * sr))
152
+ if i + 1 < len(times):
153
+ e = min(int(times[i+1] * sr), s + int(max_dur * sr))
154
+ else:
155
+ e = min(len(y), s + int(max_dur * sr))
156
+ seg = y[s:e]
157
+ if len(seg) < int(min_dur * sr):
158
+ continue
159
+ rms = np.sqrt(np.mean(seg**2))
160
+ if rms < threshold:
161
+ continue
162
+ # Fade out
163
+ fl = min(int(0.005 * sr), len(seg) // 4)
164
+ if fl > 0:
165
+ seg = seg.copy()
166
+ seg[-fl:] *= np.linspace(1, 0, fl)
167
+ sc = float(librosa.feature.spectral_centroid(y=seg, sr=sr).mean())
168
+ hits.append(Hit(audio=seg, sr=sr, onset_time=t, duration=len(seg)/sr,
169
+ index=len(hits), rms_energy=float(rms), spectral_centroid=sc))
170
+
171
+ print(f" βœ“ Valid hits: {len(hits)}")
172
+ return hits
173
+
174
+
175
+ # ─── Stage 3: Classification (generalized) ───────────────────────────────────
176
+
177
+ LABEL_RULES = [
178
+ # (name, condition_fn)
179
+ ("kick", lambda lr, mr, hr, c, zcr, d: lr > 0.5 and c < 800),
180
+ ("hihat_closed", lambda lr, mr, hr, c, zcr, d: hr > 0.35 and c > 4000 and d < 0.15),
181
+ ("hihat_open", lambda lr, mr, hr, c, zcr, d: hr > 0.35 and c > 4000 and d >= 0.15),
182
+ ("cymbal", lambda lr, mr, hr, c, zcr, d: hr > 0.25 and c > 3000),
183
+ ("snare", lambda lr, mr, hr, c, zcr, d: mr > 0.4 and zcr > 0.1 and c > 1000),
184
+ ("tom", lambda lr, mr, hr, c, zcr, d: lr > 0.3 and mr > 0.3 and c < 1500),
185
+ ("bass", lambda lr, mr, hr, c, zcr, d: lr > 0.6 and c < 400 and d > 0.2),
186
+ ("vocal", lambda lr, mr, hr, c, zcr, d: mr > 0.5 and c > 500 and c < 3000 and zcr < 0.15),
187
+ ("bright", lambda lr, mr, hr, c, zcr, d: c > 2500),
188
+ ("mid", lambda lr, mr, hr, c, zcr, d: c > 800),
189
+ ]
190
+
191
+ def classify_hit(hit: Hit) -> str:
192
+ y, sr = hit.audio, hit.sr
193
+ D = np.abs(librosa.stft(y, n_fft=2048))
194
+ freqs = librosa.fft_frequencies(sr=sr, n_fft=2048)
195
+ le = np.sum(D[(freqs >= 20) & (freqs < 200)]**2)
196
+ me = np.sum(D[(freqs >= 200) & (freqs < 4000)]**2)
197
+ he = np.sum(D[(freqs >= 4000)]**2)
198
+ total = le + me + he + 1e-10
199
+ lr, mr, hr = le/total, me/total, he/total
200
+ zcr = float(librosa.feature.zero_crossing_rate(y=y).mean())
201
+ for name, fn in LABEL_RULES:
202
+ if fn(lr, mr, hr, hit.spectral_centroid, zcr, hit.duration):
203
+ return name
204
+ return "other"
205
+
206
+
207
+ def spectral_decompose(hit: Hit, threshold: float = 0.15) -> dict:
208
+ """Split a hit into spectral sub-bands if multiple bands are significant."""
209
+ y, sr = hit.audio, hit.sr
210
+ D = librosa.stft(y, n_fft=2048)
211
+ freqs = librosa.fft_frequencies(sr=sr, n_fft=2048)
212
+ bands = {"low": (20, 250), "mid": (250, 4000), "high": (4000, sr//2)}
213
+ results = {}
214
+ for name, (lo, hi) in bands.items():
215
+ mask = (freqs >= lo) & (freqs <= hi)
216
+ Db = np.zeros_like(D); Db[mask] = D[mask]
217
+ ab = librosa.istft(Db, length=len(y))
218
+ if np.sqrt(np.mean(ab**2)) > 0.001:
219
+ results[name] = ab
220
+ return results
221
+
222
+
223
+ def classify_and_separate(hits: list, separate_overlaps: bool = True,
224
+ overlap_threshold: float = 0.15) -> list:
225
+ """Classify hits and optionally decompose overlapping sounds."""
226
+ print(f"[Stage 3] Classifying & separating...")
227
+ all_hits, overlap_count = [], 0
228
+ band_labels = {"low": "bass_hit", "mid": "mid_hit", "high": "bright_hit"}
229
+
230
+ for hit in hits:
231
+ hit.label = classify_hit(hit)
232
+ if separate_overlaps:
233
+ bands = spectral_decompose(hit, overlap_threshold)
234
+ if len(bands) >= 2:
235
+ energies = {k: np.sqrt(np.mean(v**2)) for k, v in bands.items()}
236
+ mx = max(energies.values())
237
+ sig = {k: v for k, v in bands.items() if energies[k] > overlap_threshold * mx}
238
+ if len(sig) >= 2:
239
+ overlap_count += 1
240
+ for bn, ba in sig.items():
241
+ sc = float(librosa.feature.spectral_centroid(y=ba, sr=hit.sr).mean())
242
+ sub = Hit(audio=ba, sr=hit.sr, onset_time=hit.onset_time,
243
+ duration=hit.duration, index=len(all_hits),
244
+ rms_energy=float(np.sqrt(np.mean(ba**2))),
245
+ spectral_centroid=sc, label=band_labels.get(bn, "other"))
246
+ # Re-classify the sub-hit with full rules
247
+ sub.label = classify_hit(sub)
248
+ all_hits.append(sub)
249
+ continue
250
+ hit.index = len(all_hits)
251
+ all_hits.append(hit)
252
+
253
+ counts = defaultdict(int)
254
+ for h in all_hits:
255
+ counts[h.label] += 1
256
+ print(f" Overlaps decomposed: {overlap_count}")
257
+ print(f" Total hits: {len(all_hits)}")
258
+ for l, c in sorted(counts.items(), key=lambda x: -x[1]):
259
+ print(f" {l}: {c}")
260
+ return all_hits
261
+
262
+
263
+ # ─── Stage 4: Embedding & Clustering ─────────────────────────────────────────
264
+
265
+ def compute_embeddings(hits: list) -> np.ndarray:
266
+ """58-dim librosa feature embeddings."""
267
+ embs = []
268
+ for h in hits:
269
+ y, sr = h.audio, h.sr
270
+ ml = int(0.05 * sr)
271
+ if len(y) < ml:
272
+ y = np.pad(y, (0, ml - len(y)))
273
+ mfcc = librosa.feature.mfcc(y=y, sr=sr, n_mfcc=20)
274
+ c = librosa.feature.spectral_centroid(y=y, sr=sr)
275
+ bw = librosa.feature.spectral_bandwidth(y=y, sr=sr)
276
+ ro = librosa.feature.spectral_rolloff(y=y, sr=sr)
277
+ ct = librosa.feature.spectral_contrast(y=y, sr=sr, n_bands=4)
278
+ fl = librosa.feature.spectral_flatness(y=y)
279
+ zcr = librosa.feature.zero_crossing_rate(y=y)
280
+ rms = librosa.feature.rms(y=y)
281
+ oe = librosa.onset.onset_strength(y=y, sr=sr)
282
+ if len(oe) > 1:
283
+ oen = oe / (oe.max() + 1e-10)
284
+ af = [oen.mean(), oen.std(), float(np.argmax(oen))/len(oen), oen[-1]]
285
+ else:
286
+ af = [0,0,0,0]
287
+ f = np.concatenate([mfcc.mean(1), mfcc.std(1), [c.mean(), c.std()],
288
+ [bw.mean(), bw.std()], [ro.mean()], ct.mean(1),
289
+ [fl.mean()], [zcr.mean()], [rms.mean()], af, [h.duration]])
290
+ embs.append(f)
291
+ embs = np.array(embs, dtype=np.float32)
292
+ mu, std = embs.mean(0), embs.std(0) + 1e-8
293
+ return (embs - mu) / std
294
+
295
+
296
+ def cluster_hits(hits: list, embeddings: np.ndarray) -> list:
297
+ """Cluster by label group, then sub-cluster via silhouette-optimized KMeans."""
298
+ from sklearn.cluster import KMeans
299
+ from sklearn.metrics import silhouette_score
300
+ print(f"[Stage 4] Clustering...")
301
+
302
+ groups = defaultdict(list)
303
+ for i, h in enumerate(hits):
304
+ groups[h.label].append(i)
305
+
306
+ clusters = []
307
+ for label, indices in groups.items():
308
+ if len(indices) < 2:
309
+ clusters.append(Cluster(cluster_id=len(clusters), label=f"{label}_0",
310
+ hits=[hits[i] for i in indices]))
311
+ continue
312
+ ge = embeddings[indices]
313
+ mk = min(max(2, len(indices)//3), 15)
314
+ bk, bs = 1, -1
315
+ for k in range(2, mk+1):
316
+ try:
317
+ km = KMeans(n_clusters=k, random_state=42, n_init=10, max_iter=300)
318
+ sl = km.fit_predict(ge)
319
+ s = silhouette_score(ge, sl)
320
+ if s > bs: bk, bs = k, s
321
+ except ValueError:
322
+ continue
323
+ if bk >= 2:
324
+ sl = KMeans(n_clusters=bk, random_state=42, n_init=10).fit_predict(ge)
325
+ else:
326
+ sl = np.zeros(len(indices), dtype=int)
327
+ for sid in range(max(sl)+1):
328
+ mask = sl == sid
329
+ mi = [indices[j] for j in range(len(indices)) if mask[j]]
330
+ clusters.append(Cluster(cluster_id=len(clusters), label=f"{label}_{sid}",
331
+ hits=[hits[i] for i in mi]))
332
+ print(f" {label}: {len(indices)} β†’ {bk} sub-clusters (sil={bs:.3f})")
333
+
334
+ print(f" βœ“ Total clusters: {len(clusters)}")
335
+ return clusters
336
+
337
+
338
+ # ─── Stage 5: Quality scoring & selection ─────────────────────────────────────
339
+
340
+ def sample_quality_score(y: np.ndarray, sr: int, label: str = "other") -> dict:
341
+ """Score a sample for production quality. Returns dict with total [0,100]."""
342
+ # Completeness
343
+ rms_env = librosa.feature.rms(y=y, frame_length=512, hop_length=128)[0]
344
+ if len(rms_env) >= 10:
345
+ pk = np.argmax(rms_env); post = rms_env[pk:]
346
+ tail_r = np.mean(post[-max(3, len(post)//5):]) / (rms_env[pk] + 1e-8)
347
+ c1 = max(0, 1.0 - tail_r * 5)
348
+ else:
349
+ c1 = 0.5
350
+ import scipy.stats
351
+ if len(rms_env) >= 10:
352
+ pk = np.argmax(rms_env); post = rms_env[pk:]
353
+ if len(post) >= 5:
354
+ slope, _, r, _, _ = scipy.stats.linregress(np.arange(len(post)), np.log(post+1e-8))
355
+ c2 = max(0, r**2) if slope < 0 else r**2 * 0.3
356
+ else:
357
+ c2 = 0.0
358
+ else:
359
+ c2 = 0.0
360
+ completeness = c1 * 0.6 + c2 * 0.4
361
+
362
+ # Cleanness: robust SNR + pre-onset energy
363
+ snr = 10*np.log10(np.percentile(y**2, 99) / (np.percentile(y**2, 10) + 1e-12))
364
+ n_snr = np.clip((snr - 10) / 40, 0, 1)
365
+ onsets = librosa.onset.onset_detect(y=y, sr=sr, units='samples', backtrack=True)
366
+ if len(onsets) > 0:
367
+ os_s = int(onsets[0])
368
+ pre = y[max(0, os_s-int(sr*.02)):os_s]
369
+ sig = y[os_s:os_s+int(sr*.1)]
370
+ if len(pre) > 10 and len(sig) > 10:
371
+ pdb = 10*np.log10(np.mean(pre**2+1e-12)/np.mean(sig**2+1e-12))
372
+ n_pre = np.clip((-pdb - 5) / 30, 0, 1)
373
+ else:
374
+ n_pre = 0.5
375
+ else:
376
+ n_pre = 0.5
377
+ cleanness = n_snr * 0.5 + n_pre * 0.5
378
+
379
+ # Onset quality
380
+ oe = librosa.onset.onset_strength(y=y, sr=sr)
381
+ sharpness = float(np.max(oe) / (np.mean(oe) + 1e-8)) if len(oe) > 1 else 1.0
382
+ onset_q = float(np.clip((sharpness - 1.0) / 5.0, 0, 1))
383
+
384
+ total = (completeness * 0.30 + cleanness * 0.40 + onset_q * 0.20 + 0.5 * 0.10) * 100
385
+ return {'total': float(total), 'completeness': float(completeness),
386
+ 'cleanness': float(cleanness), 'onset_quality': float(onset_q)}
387
+
388
+
389
+ def select_best(clusters: list):
390
+ """Select best representative per cluster using quality scoring."""
391
+ print(f"[Stage 5] Selecting best representatives...")
392
+ for c in clusters:
393
+ if c.count <= 1:
394
+ c.best_hit_idx = 0; continue
395
+ scores = [sample_quality_score(h.audio, h.sr, c.label.rsplit('_',1)[0])['total']
396
+ for h in c.hits]
397
+ c.best_hit_idx = int(np.argmax(scores))
398
+
399
+
400
+ # ─── Stage 6: Synthesis ──────────────────────────────────────────────────────
401
+
402
+ def synthesize_from_cluster(cluster: Cluster) -> Optional[np.ndarray]:
403
+ """Peak-aligned weighted average synthesis."""
404
+ if cluster.count < 2:
405
+ return None
406
+ tl = int(np.median([len(h.audio) for h in cluster.hits]))
407
+ aligned, weights = [], []
408
+ pp_target = None
409
+ for i, h in enumerate(cluster.hits):
410
+ a = h.audio.copy()
411
+ pp = np.argmax(np.abs(a))
412
+ if pp_target is None: pp_target = pp
413
+ shift = pp_target - pp
414
+ if shift > 0: a = np.pad(a, (shift, 0))
415
+ elif shift < 0: a = a[-shift:]
416
+ a = a[:tl] if len(a) >= tl else np.pad(a, (0, tl - len(a)))
417
+ pk = np.abs(a).max()
418
+ if pk > 0: a = a / pk
419
+ aligned.append(a)
420
+ weights.append(2.0 if i == cluster.best_hit_idx else 1.0)
421
+ aligned = np.array(aligned)
422
+ w = np.array(weights); w /= w.sum()
423
+ synth = np.average(aligned, axis=0, weights=w)
424
+ pk = np.abs(synth).max()
425
+ return (synth * 0.95 / pk).astype(np.float32) if pk > 0 else synth.astype(np.float32)
426
+
427
+
428
+ # ─── Stage 7: MIDI reconstruction ────────────────────────────────────────────
429
+
430
+ def build_midi(clusters: list, bpm: float = 120.0) -> 'pretty_midi.PrettyMIDI':
431
+ """Build MIDI file mapping each cluster to a unique note."""
432
+ import pretty_midi
433
+
434
+ pm = pretty_midi.PrettyMIDI(initial_tempo=bpm)
435
+
436
+ # Assign MIDI notes: one per cluster, starting at C2 (36)
437
+ base_note = 36
438
+ for i, c in enumerate(clusters):
439
+ c.midi_note = min(base_note + i, 127)
440
+
441
+ # Create one instrument for all (using Standard Drums channel for now)
442
+ inst = pretty_midi.Instrument(program=0, is_drum=True, name='Extracted Samples')
443
+ pm.instruments.append(inst)
444
+
445
+ for c in clusters:
446
+ for h in c.hits:
447
+ vel = max(1, min(127, int(h.rms_energy / 0.3 * 127)))
448
+ note = pretty_midi.Note(velocity=vel, pitch=c.midi_note,
449
+ start=h.onset_time,
450
+ end=h.onset_time + max(h.duration, 0.05))
451
+ inst.notes.append(note)
452
+
453
+ # Sort notes by start time
454
+ inst.notes.sort(key=lambda n: n.start)
455
+ return pm
456
+
457
+
458
+ def export_midi(clusters: list, output_path: str, bpm: float = 120.0):
459
+ """Export MIDI file."""
460
+ pm = build_midi(clusters, bpm)
461
+ pm.write(output_path)
462
+ print(f" βœ“ MIDI saved: {output_path} ({len(pm.instruments[0].notes)} notes)")
463
+ return pm
464
+
465
+
466
+ def build_sample_map(clusters: list) -> dict:
467
+ """Build a mapping from MIDI note β†’ cluster for DAW import."""
468
+ return {
469
+ c.midi_note: {
470
+ 'label': c.label,
471
+ 'count': c.count,
472
+ 'duration_ms': int(c.best_hit.duration * 1000),
473
+ }
474
+ for c in clusters
475
+ }
476
+
477
+
478
+ # ─── Main pipeline ───────────────────────────────────────────────────────────
479
+
480
+ def run_pipeline(
481
+ audio_path: str,
482
+ output_dir: str = "./extracted_samples",
483
+ stem: str = "drums", # drums|bass|vocals|other|all
484
+ device: str = "auto",
485
+ onset_mode: str = "auto", # auto|percussive|harmonic|broadband
486
+ separate_overlaps: bool = True,
487
+ overlap_threshold: float = 0.15,
488
+ synthesize: bool = True,
489
+ export_midi_file: bool = True,
490
+ bpm: float = 120.0,
491
+ min_dur: float = 0.02,
492
+ max_dur: float = 1.5,
493
+ energy_threshold_db: float = -45.0,
494
+ pre_pad: float = 0.005,
495
+ min_gap: float = 0.015,
496
+ save_intermediates: bool = True,
497
+ ) -> tuple:
498
+ """Run the full extraction pipeline. Returns (clusters, hits, midi_pm)."""
499
+ if device == "auto":
500
+ device = "cuda" if torch.cuda.is_available() else "cpu"
501
+
502
+ out = Path(output_dir); out.mkdir(parents=True, exist_ok=True)
503
+
504
+ # Stage 1
505
+ audio, sr = extract_stem(audio_path, stem=stem, device=device)
506
+ if save_intermediates:
507
+ sf.write(str(out / f"{stem}_stem.wav"), audio, sr, subtype='PCM_24')
508
+
509
+ # Stage 2
510
+ hits = detect_onsets(audio, sr, pre_pad=pre_pad, min_dur=min_dur,
511
+ max_dur=max_dur, min_gap=min_gap,
512
+ energy_threshold_db=energy_threshold_db, mode=onset_mode)
513
+ if not hits:
514
+ print("⚠ No hits detected!")
515
+ return [], [], None
516
+
517
+ # Stage 3
518
+ hits = classify_and_separate(hits, separate_overlaps=separate_overlaps,
519
+ overlap_threshold=overlap_threshold)
520
+
521
+ if save_intermediates:
522
+ hd = out / "all_hits"; hd.mkdir(exist_ok=True)
523
+ for h in hits:
524
+ h.save(str(hd / f"hit_{h.index:04d}_{h.label}_{h.onset_time:.3f}s.wav"))
525
+
526
+ # Stage 4
527
+ print(f"[Stage 4a] Computing embeddings...")
528
+ embs = compute_embeddings(hits)
529
+ print(f" βœ“ Embeddings: {embs.shape}")
530
+ for i, h in enumerate(hits): h.embedding = embs[i]
531
+ clusters = cluster_hits(hits, embs)
532
+
533
+ # Stage 5
534
+ select_best(clusters)
535
+
536
+ # Stage 6
537
+ if synthesize:
538
+ print(f"[Stage 6] Synthesizing...")
539
+ for c in clusters:
540
+ if c.count >= 2:
541
+ c.synthesized = synthesize_from_cluster(c)
542
+
543
+ # Stage 7: MIDI
544
+ midi_pm = None
545
+ if export_midi_file:
546
+ print(f"[Stage 7] Building MIDI reconstruction...")
547
+ midi_pm = export_midi(clusters, str(out / "reconstruction.mid"), bpm=bpm)
548
+ # Save sample map
549
+ smap = build_sample_map(clusters)
550
+ with open(str(out / "sample_map.json"), 'w') as f:
551
+ json.dump(smap, f, indent=2)
552
+ print(f" Sample map: {out / 'sample_map.json'}")
553
+
554
+ # Export
555
+ print(f"[Export] Saving samples...")
556
+ sd = out / "samples"; sd.mkdir(exist_ok=True)
557
+ if synthesize:
558
+ synd = out / "synthesized"; synd.mkdir(exist_ok=True)
559
+
560
+ manifest = []
561
+ for c in clusters:
562
+ best = c.best_hit
563
+ sp = sd / f"{c.label}__best.wav"; best.save(str(sp))
564
+ entry = {'cluster_id': c.cluster_id, 'label': c.label, 'count': c.count,
565
+ 'midi_note': c.midi_note, 'best_onset': best.onset_time,
566
+ 'best_duration': best.duration, 'best_energy': best.rms_energy}
567
+ if synthesize and c.synthesized is not None:
568
+ synp = synd / f"{c.label}__synthesized.wav"
569
+ sf.write(str(synp), c.synthesized, best.sr, subtype='PCM_24')
570
+ entry['synthesized'] = str(synp)
571
+ manifest.append(entry)
572
+ print(f" βœ“ {c.label}: {c.count} hits β†’ MIDI note {c.midi_note}")
573
+
574
+ with open(str(out / "manifest.json"), 'w') as f:
575
+ json.dump(manifest, f, indent=2)
576
+
577
+ print(f"\n{'='*50}")
578
+ print(f" Clusters: {len(clusters)}")
579
+ print(f" Total hits: {sum(c.count for c in clusters)}")
580
+ print(f" Output: {output_dir}")
581
+ return clusters, hits, midi_pm
582
+
583
+
584
+ # ─── CLI ──────────────────────────────────────────────────────────────────────
585
+
586
+ def main():
587
+ p = argparse.ArgumentParser(description="Extract audio samples from any audio file")
588
+ p.add_argument("input", help="Input audio file")
589
+ p.add_argument("-o", "--output-dir", default="./extracted_samples")
590
+ p.add_argument("--stem", default="drums", choices=["drums","bass","vocals","other","all"])
591
+ p.add_argument("--onset-mode", default="auto", choices=["auto","percussive","harmonic","broadband"])
592
+ p.add_argument("--no-gpu", action="store_true")
593
+ p.add_argument("--no-separate", action="store_true")
594
+ p.add_argument("--no-midi", action="store_true")
595
+ p.add_argument("--bpm", type=float, default=120.0)
596
+ p.add_argument("--min-dur", type=float, default=0.02)
597
+ p.add_argument("--max-dur", type=float, default=1.5)
598
+ p.add_argument("--energy-threshold", type=float, default=-45.0)
599
+ args = p.parse_args()
600
+
601
+ run_pipeline(audio_path=args.input, output_dir=args.output_dir,
602
+ stem=args.stem, device="cpu" if args.no_gpu else "auto",
603
+ onset_mode=args.onset_mode, separate_overlaps=not args.no_separate,
604
+ export_midi_file=not args.no_midi, bpm=args.bpm,
605
+ min_dur=args.min_dur, max_dur=args.max_dur,
606
+ energy_threshold_db=args.energy_threshold)
607
+
608
+ if __name__ == "__main__":
609
+ main()