rikhoffbauer2 commited on
Commit
f183819
Β·
verified Β·
1 Parent(s): f9c63b3

Add drum_extractor.py

Browse files
Files changed (1) hide show
  1. drum_extractor.py +900 -1
drum_extractor.py CHANGED
@@ -1 +1,900 @@
1
- {{DRUM_EXTRACTOR}}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Drum Sample Extractor Pipeline
4
+ ===============================
5
+ Extracts individual drum samples from an audio file through:
6
+
7
+ 1. STEM SEPARATION β€” HTDemucs (v4 fine-tuned) isolates the drum track
8
+ 2. ONSET DETECTION β€” librosa detects individual hit boundaries
9
+ 3. INTRA-DRUM SEP β€” Spectral band splitting + optional AudioSep for overlapping sounds
10
+ 4. CLUSTERING β€” CLAP embeddings + auto-K KMeans groups identical hits
11
+ 5. SELECTION β€” Best representative per cluster (centroid-nearest + highest energy)
12
+ 6. SYNTHESIS (opt) β€” Weighted average of cluster members for an "ideal" sample
13
+
14
+ Usage:
15
+ python drum_extractor.py input.mp3 --output-dir ./samples
16
+ python drum_extractor.py input.wav --output-dir ./samples --no-gpu
17
+ python drum_extractor.py input.mp3 --output-dir ./samples --use-audiosep
18
+ """
19
+
20
+ import argparse
21
+ import json
22
+ import os
23
+ import sys
24
+ import warnings
25
+ from collections import defaultdict
26
+ from dataclasses import dataclass, field
27
+ from pathlib import Path
28
+ from typing import Optional
29
+
30
+ import librosa
31
+ import numpy as np
32
+ import soundfile as sf
33
+ import torch
34
+
35
+ warnings.filterwarnings("ignore", category=FutureWarning)
36
+ warnings.filterwarnings("ignore", category=UserWarning)
37
+
38
+
39
+ # ─────────────────────────────────────────────────────────────────────────────
40
+ # Data structures
41
+ # ─────────────────────────────────────────────────────────────────────────────
42
+
43
+ @dataclass
44
+ class DrumHit:
45
+ """A single detected drum hit."""
46
+ audio: np.ndarray # mono waveform
47
+ sr: int # sample rate
48
+ onset_time: float # onset time in seconds (in the drum stem)
49
+ duration: float # duration in seconds
50
+ index: int # sequential index
51
+ rms_energy: float = 0.0
52
+ spectral_centroid: float = 0.0
53
+ rough_label: str = "" # spectral rough label: kick/snare/hihat/other
54
+ embedding: Optional[np.ndarray] = None
55
+ cluster_id: int = -1
56
+
57
+ def save(self, path: str):
58
+ sf.write(path, self.audio, self.sr, subtype='PCM_24')
59
+
60
+
61
+ @dataclass
62
+ class DrumCluster:
63
+ """A cluster of similar drum hits."""
64
+ cluster_id: int
65
+ label: str # e.g. "kick_0", "snare_1"
66
+ hits: list = field(default_factory=list)
67
+ best_hit_idx: int = 0 # index into self.hits
68
+ synthesized: Optional[np.ndarray] = None
69
+
70
+ @property
71
+ def best_hit(self) -> DrumHit:
72
+ return self.hits[self.best_hit_idx]
73
+
74
+ @property
75
+ def count(self) -> int:
76
+ return len(self.hits)
77
+
78
+
79
+ # ─────────────────────────────────────────────────────────────────────────────
80
+ # Stage 1: Drum stem extraction via Demucs
81
+ # ─────────────────────────────────────────────────────────────────────────────
82
+
83
+ def extract_drums_demucs(audio_path: str, device: str = "cpu") -> tuple[np.ndarray, int]:
84
+ """Extract drum stem using HTDemucs v4 (fine-tuned)."""
85
+ from demucs.pretrained import get_model
86
+ from demucs.apply import apply_model
87
+
88
+ print("=" * 60)
89
+ print("STAGE 1: Extracting drum stem with HTDemucs")
90
+ print("=" * 60)
91
+
92
+ # Try htdemucs_ft first (better drums), fall back to htdemucs
93
+ for model_name in ["htdemucs_ft", "htdemucs"]:
94
+ try:
95
+ model = get_model(model_name)
96
+ print(f" Loaded model: {model_name}")
97
+ break
98
+ except Exception as e:
99
+ print(f" Could not load {model_name}: {e}")
100
+ else:
101
+ raise RuntimeError("Could not load any Demucs model")
102
+
103
+ model.eval()
104
+ model.to(device)
105
+ target_sr = model.samplerate # 44100
106
+
107
+ # Load audio using librosa (works without FFmpeg system libs)
108
+ # librosa returns (samples, sr) as mono by default; load as-is for channel control
109
+ import librosa as _lr
110
+ audio_np, sr = _lr.load(audio_path, sr=target_sr, mono=False)
111
+ # audio_np: (channels, samples) or (samples,) if mono
112
+ if audio_np.ndim == 1:
113
+ audio_np = np.stack([audio_np, audio_np]) # mono β†’ stereo
114
+ elif audio_np.shape[0] == 1:
115
+ audio_np = np.concatenate([audio_np, audio_np], axis=0)
116
+ elif audio_np.shape[0] > 2:
117
+ audio_np = audio_np[:2]
118
+ wav = torch.from_numpy(audio_np).float() # [2, T]
119
+
120
+ wav = wav.unsqueeze(0).to(device) # [1, 2, T]
121
+ print(f" Audio: {wav.shape[-1] / target_sr:.1f}s, {target_sr}Hz")
122
+
123
+ # Separate
124
+ with torch.no_grad():
125
+ sources = apply_model(model, wav, device=device, shifts=1,
126
+ split=True, overlap=0.25, progress=True)
127
+
128
+ # sources: [1, n_sources, 2, T]
129
+ stem_names = model.sources # e.g. ['drums', 'bass', 'other', 'vocals']
130
+ drums_idx = stem_names.index('drums')
131
+ drums_wav = sources[0, drums_idx] # [2, T]
132
+
133
+ # Convert to mono numpy
134
+ drums_mono = drums_wav.mean(dim=0).cpu().numpy()
135
+ print(f" βœ“ Extracted drums: {len(drums_mono) / target_sr:.1f}s")
136
+
137
+ return drums_mono, target_sr
138
+
139
+
140
+ # ─────────────────────────────────────────────────────────────────────────────
141
+ # Stage 2: Onset detection & hit segmentation
142
+ # ─────────────────────────────────────────────────────────────────────────────
143
+
144
+ def detect_onsets(y: np.ndarray, sr: int,
145
+ pre_pad: float = 0.005,
146
+ min_hit_dur: float = 0.03,
147
+ max_hit_dur: float = 0.8,
148
+ min_gap: float = 0.02,
149
+ energy_threshold_db: float = -40.0) -> list[DrumHit]:
150
+ """Detect drum hit onsets and segment into individual hits."""
151
+ print("\n" + "=" * 60)
152
+ print("STAGE 2: Detecting drum hit onsets")
153
+ print("=" * 60)
154
+
155
+ # Multi-band onset detection for better drum coverage
156
+ # Low band (kick): 20-250 Hz
157
+ # Mid band (snare/toms): 250-4000 Hz
158
+ # High band (cymbals): 4000+ Hz
159
+ onset_env_low = librosa.onset.onset_strength(
160
+ y=y, sr=sr, fmin=20, fmax=250, aggregate=np.median
161
+ )
162
+ onset_env_mid = librosa.onset.onset_strength(
163
+ y=y, sr=sr, fmin=250, fmax=4000, aggregate=np.median
164
+ )
165
+ onset_env_high = librosa.onset.onset_strength(
166
+ y=y, sr=sr, fmin=4000, fmax=sr // 2, aggregate=np.median
167
+ )
168
+
169
+ # Combine: normalize each band, then take max across bands
170
+ def norm(x):
171
+ mx = x.max()
172
+ return x / mx if mx > 0 else x
173
+
174
+ onset_env = np.maximum(norm(onset_env_low),
175
+ np.maximum(norm(onset_env_mid), norm(onset_env_high)))
176
+
177
+ # Detect onsets
178
+ wait_frames = max(1, int(min_gap * sr / 512)) # hop_length=512 default
179
+ onsets_frames = librosa.onset.onset_detect(
180
+ onset_envelope=onset_env,
181
+ sr=sr,
182
+ wait=wait_frames,
183
+ pre_avg=3,
184
+ post_avg=3,
185
+ pre_max=3,
186
+ post_max=5,
187
+ backtrack=True,
188
+ units='frames'
189
+ )
190
+ onset_times = librosa.frames_to_time(onsets_frames, sr=sr)
191
+
192
+ print(f" Raw onsets detected: {len(onset_times)}")
193
+
194
+ # Segment into hits
195
+ hits = []
196
+ energy_threshold = 10 ** (energy_threshold_db / 20)
197
+
198
+ for i, t in enumerate(onset_times):
199
+ start_sample = max(0, int((t - pre_pad) * sr))
200
+
201
+ # End = next onset or max_hit_dur, whichever is shorter
202
+ if i + 1 < len(onset_times):
203
+ next_onset_sample = int(onset_times[i + 1] * sr)
204
+ end_sample = min(next_onset_sample, start_sample + int(max_hit_dur * sr))
205
+ else:
206
+ end_sample = min(len(y), start_sample + int(max_hit_dur * sr))
207
+
208
+ segment = y[start_sample:end_sample]
209
+
210
+ # Skip too-short or too-quiet hits
211
+ if len(segment) < int(min_hit_dur * sr):
212
+ continue
213
+ rms = np.sqrt(np.mean(segment ** 2))
214
+ if rms < energy_threshold:
215
+ continue
216
+
217
+ # Apply a quick fade-out to avoid clicks
218
+ fade_len = min(int(0.005 * sr), len(segment) // 4)
219
+ if fade_len > 0:
220
+ segment = segment.copy()
221
+ segment[-fade_len:] *= np.linspace(1, 0, fade_len)
222
+
223
+ # Compute features
224
+ spectral_centroid = float(librosa.feature.spectral_centroid(
225
+ y=segment, sr=sr
226
+ ).mean())
227
+
228
+ hit = DrumHit(
229
+ audio=segment,
230
+ sr=sr,
231
+ onset_time=t,
232
+ duration=len(segment) / sr,
233
+ index=len(hits),
234
+ rms_energy=float(rms),
235
+ spectral_centroid=spectral_centroid,
236
+ )
237
+ hits.append(hit)
238
+
239
+ print(f" βœ“ Valid hits after filtering: {len(hits)}")
240
+ return hits
241
+
242
+
243
+ # ─────────────────────────────────────────────────────────────────────────────
244
+ # Stage 3: Rough spectral classification + optional intra-drum separation
245
+ # ─────────────────────────────────────────────────────────────────────────────
246
+
247
+ def rough_spectral_label(hit: DrumHit) -> str:
248
+ """Assign a rough drum type label based on spectral characteristics."""
249
+ y, sr = hit.audio, hit.sr
250
+
251
+ # Spectral centroid (mean frequency)
252
+ centroid = hit.spectral_centroid
253
+
254
+ # Energy distribution across bands
255
+ D = np.abs(librosa.stft(y, n_fft=2048))
256
+ freqs = librosa.fft_frequencies(sr=sr, n_fft=2048)
257
+
258
+ low_energy = np.sum(D[(freqs >= 20) & (freqs < 200)] ** 2)
259
+ mid_energy = np.sum(D[(freqs >= 200) & (freqs < 4000)] ** 2)
260
+ high_energy = np.sum(D[(freqs >= 4000)] ** 2)
261
+ total = low_energy + mid_energy + high_energy + 1e-10
262
+
263
+ low_ratio = low_energy / total
264
+ mid_ratio = mid_energy / total
265
+ high_ratio = high_energy / total
266
+
267
+ # Zero crossing rate (percussive = high)
268
+ zcr = float(librosa.feature.zero_crossing_rate(y=y).mean())
269
+
270
+ # Decision tree
271
+ if low_ratio > 0.5 and centroid < 800:
272
+ return "kick"
273
+ elif high_ratio > 0.35 and centroid > 4000:
274
+ if hit.duration < 0.15:
275
+ return "hihat_closed"
276
+ else:
277
+ return "hihat_open"
278
+ elif high_ratio > 0.25 and centroid > 3000:
279
+ return "cymbal"
280
+ elif mid_ratio > 0.4 and zcr > 0.1 and centroid > 1000:
281
+ return "snare"
282
+ elif low_ratio > 0.3 and mid_ratio > 0.3:
283
+ return "tom"
284
+ elif centroid > 2500:
285
+ return "perc_high"
286
+ else:
287
+ return "perc_low"
288
+
289
+
290
+ def spectral_separate_hit(hit: DrumHit) -> dict[str, np.ndarray]:
291
+ """
292
+ Decompose a single hit into spectral bands.
293
+ Returns dict of {band_name: audio_array}.
294
+ Useful for hits where multiple drums overlap.
295
+ """
296
+ y, sr = hit.audio, hit.sr
297
+ D = librosa.stft(y, n_fft=2048)
298
+ freqs = librosa.fft_frequencies(sr=sr, n_fft=2048)
299
+
300
+ bands = {
301
+ "low": (20, 250), # kick range
302
+ "mid": (250, 4000), # snare/tom range
303
+ "high": (4000, sr // 2) # hihat/cymbal range
304
+ }
305
+
306
+ results = {}
307
+ for name, (fmin, fmax) in bands.items():
308
+ mask = (freqs >= fmin) & (freqs <= fmax)
309
+ D_band = np.zeros_like(D)
310
+ D_band[mask] = D[mask]
311
+ audio_band = librosa.istft(D_band, length=len(y))
312
+
313
+ # Only include if there's meaningful energy
314
+ if np.sqrt(np.mean(audio_band ** 2)) > 0.001:
315
+ results[name] = audio_band
316
+
317
+ return results
318
+
319
+
320
+ def classify_and_separate_hits(hits: list[DrumHit],
321
+ separate_overlaps: bool = True) -> list[DrumHit]:
322
+ """Classify hits and optionally split overlapping sounds into sub-hits."""
323
+ print("\n" + "=" * 60)
324
+ print("STAGE 3: Spectral classification & separation")
325
+ print("=" * 60)
326
+
327
+ all_hits = []
328
+ overlap_count = 0
329
+
330
+ for hit in hits:
331
+ label = rough_spectral_label(hit)
332
+ hit.rough_label = label
333
+
334
+ if separate_overlaps:
335
+ # Check if multiple bands have significant energy (= overlap)
336
+ bands = spectral_separate_hit(hit)
337
+ if len(bands) >= 2:
338
+ # Check if the sub-bands are meaningfully different
339
+ energies = {k: np.sqrt(np.mean(v ** 2)) for k, v in bands.items()}
340
+ max_e = max(energies.values())
341
+ significant = {k: v for k, v in bands.items()
342
+ if energies[k] > 0.15 * max_e}
343
+
344
+ if len(significant) >= 2:
345
+ overlap_count += 1
346
+ # Create sub-hits for each significant band
347
+ band_labels = {"low": "kick", "mid": "snare", "high": "hihat"}
348
+ for band_name, band_audio in significant.items():
349
+ sub_hit = DrumHit(
350
+ audio=band_audio,
351
+ sr=hit.sr,
352
+ onset_time=hit.onset_time,
353
+ duration=hit.duration,
354
+ index=len(all_hits),
355
+ rms_energy=float(np.sqrt(np.mean(band_audio ** 2))),
356
+ spectral_centroid=float(librosa.feature.spectral_centroid(
357
+ y=band_audio, sr=hit.sr
358
+ ).mean()),
359
+ rough_label=band_labels.get(band_name, "other"),
360
+ )
361
+ all_hits.append(sub_hit)
362
+ continue # skip adding the original
363
+
364
+ hit.index = len(all_hits)
365
+ all_hits.append(hit)
366
+
367
+ label_counts = defaultdict(int)
368
+ for h in all_hits:
369
+ label_counts[h.rough_label] += 1
370
+
371
+ print(f" Overlapping hits decomposed: {overlap_count}")
372
+ print(f" Total hits after separation: {len(all_hits)}")
373
+ print(f" Label distribution:")
374
+ for label, count in sorted(label_counts.items(), key=lambda x: -x[1]):
375
+ print(f" {label}: {count}")
376
+
377
+ return all_hits
378
+
379
+
380
+ # ─────────────────────────────────────────────────────────────────────────────
381
+ # Stage 4: Embedding & Clustering
382
+ # ──────────────────────��──────────────────────────────────────────────────────
383
+
384
+ def compute_librosa_embeddings(hits: list[DrumHit]) -> np.ndarray:
385
+ """Compute rich librosa feature embeddings for all hits."""
386
+ embeddings = []
387
+ for hit in hits:
388
+ y, sr = hit.audio, hit.sr
389
+
390
+ # Pad very short audio
391
+ min_len = int(0.05 * sr)
392
+ if len(y) < min_len:
393
+ y = np.pad(y, (0, min_len - len(y)))
394
+
395
+ # MFCCs (timbre)
396
+ mfcc = librosa.feature.mfcc(y=y, sr=sr, n_mfcc=20)
397
+ mfcc_mean = mfcc.mean(axis=1)
398
+ mfcc_std = mfcc.std(axis=1)
399
+
400
+ # Spectral features
401
+ centroid = librosa.feature.spectral_centroid(y=y, sr=sr)
402
+ bandwidth = librosa.feature.spectral_bandwidth(y=y, sr=sr)
403
+ rolloff = librosa.feature.spectral_rolloff(y=y, sr=sr)
404
+ contrast = librosa.feature.spectral_contrast(y=y, sr=sr, n_bands=4)
405
+ flatness = librosa.feature.spectral_flatness(y=y)
406
+
407
+ # Temporal features
408
+ zcr = librosa.feature.zero_crossing_rate(y=y)
409
+ rms = librosa.feature.rms(y=y)
410
+
411
+ # Onset strength envelope shape
412
+ onset_env = librosa.onset.onset_strength(y=y, sr=sr)
413
+ if len(onset_env) > 1:
414
+ onset_env_norm = onset_env / (onset_env.max() + 1e-10)
415
+ # Attack/decay shape: first 4 moments
416
+ attack_feats = [
417
+ onset_env_norm.mean(),
418
+ onset_env_norm.std(),
419
+ float(np.argmax(onset_env_norm)) / len(onset_env_norm), # peak position
420
+ onset_env_norm[-1] if len(onset_env_norm) > 0 else 0, # tail energy
421
+ ]
422
+ else:
423
+ attack_feats = [0, 0, 0, 0]
424
+
425
+ # Assemble feature vector
426
+ feat = np.concatenate([
427
+ mfcc_mean, # 20
428
+ mfcc_std, # 20
429
+ [centroid.mean(), centroid.std()], # 2
430
+ [bandwidth.mean(), bandwidth.std()], # 2
431
+ [rolloff.mean()], # 1
432
+ contrast.mean(axis=1), # 5
433
+ [flatness.mean()], # 1
434
+ [zcr.mean()], # 1
435
+ [rms.mean()], # 1
436
+ attack_feats, # 4
437
+ [hit.duration], # 1
438
+ ])
439
+ embeddings.append(feat)
440
+
441
+ embeddings = np.array(embeddings, dtype=np.float32)
442
+
443
+ # Normalize features (z-score per dimension)
444
+ mean = embeddings.mean(axis=0)
445
+ std = embeddings.std(axis=0) + 1e-8
446
+ embeddings = (embeddings - mean) / std
447
+
448
+ return embeddings
449
+
450
+
451
+ def compute_clap_embeddings(hits: list[DrumHit], device: str = "cpu") -> np.ndarray:
452
+ """Compute CLAP audio embeddings (semantic, 512-dim)."""
453
+ from transformers import ClapModel, ClapProcessor
454
+
455
+ print(" Loading CLAP model (laion/larger_clap_general)...")
456
+ model = ClapModel.from_pretrained("laion/larger_clap_general").to(device)
457
+ processor = ClapProcessor.from_pretrained("laion/larger_clap_general")
458
+ model.eval()
459
+
460
+ clap_sr = 48000
461
+ embeddings = []
462
+
463
+ for i, hit in enumerate(hits):
464
+ # Resample to 48kHz for CLAP
465
+ y_48k = librosa.resample(hit.audio, orig_sr=hit.sr, target_sr=clap_sr)
466
+
467
+ # Pad short audio to at least 0.5s
468
+ min_samples = int(0.5 * clap_sr)
469
+ if len(y_48k) < min_samples:
470
+ y_48k = np.pad(y_48k, (0, min_samples - len(y_48k)))
471
+
472
+ inputs = processor(audios=y_48k, sampling_rate=clap_sr, return_tensors="pt")
473
+ inputs = {k: v.to(device) for k, v in inputs.items()}
474
+
475
+ with torch.no_grad():
476
+ audio_embed = model.get_audio_features(**inputs)
477
+ embeddings.append(audio_embed.squeeze().cpu().numpy())
478
+
479
+ if (i + 1) % 50 == 0:
480
+ print(f" Embedded {i + 1}/{len(hits)}")
481
+
482
+ return np.array(embeddings, dtype=np.float32)
483
+
484
+
485
+ def cluster_hits(hits: list[DrumHit],
486
+ embeddings: np.ndarray,
487
+ min_clusters: int = 2,
488
+ max_clusters: int = 30) -> list[DrumCluster]:
489
+ """Cluster hits by embedding similarity, auto-selecting K."""
490
+ from sklearn.cluster import KMeans
491
+ from sklearn.metrics import silhouette_score
492
+
493
+ print("\n" + "=" * 60)
494
+ print("STAGE 4: Clustering similar drum hits")
495
+ print("=" * 60)
496
+
497
+ n = len(hits)
498
+ max_clusters = min(max_clusters, n - 1)
499
+ if max_clusters < min_clusters:
500
+ max_clusters = min_clusters
501
+
502
+ # First cluster by rough label, then sub-cluster within each group
503
+ label_groups = defaultdict(list)
504
+ for i, hit in enumerate(hits):
505
+ label_groups[hit.rough_label].append(i)
506
+
507
+ all_clusters = []
508
+
509
+ for label, indices in label_groups.items():
510
+ if len(indices) < 2:
511
+ # Single-hit group β†’ its own cluster
512
+ cluster = DrumCluster(
513
+ cluster_id=len(all_clusters),
514
+ label=f"{label}_0",
515
+ hits=[hits[i] for i in indices]
516
+ )
517
+ all_clusters.append(cluster)
518
+ continue
519
+
520
+ # Sub-cluster within this label group
521
+ group_embeddings = embeddings[indices]
522
+
523
+ # Auto-select k via silhouette score
524
+ max_k = min(max(2, len(indices) // 3), 15)
525
+ best_k, best_score = 1, -1
526
+
527
+ for k in range(2, max_k + 1):
528
+ try:
529
+ km = KMeans(n_clusters=k, random_state=42, n_init=10, max_iter=300)
530
+ sub_labels = km.fit_predict(group_embeddings)
531
+ score = silhouette_score(group_embeddings, sub_labels)
532
+ if score > best_score:
533
+ best_k, best_score = k, score
534
+ except ValueError:
535
+ continue
536
+
537
+ # Fit with best k
538
+ if best_k >= 2:
539
+ km = KMeans(n_clusters=best_k, random_state=42, n_init=10)
540
+ sub_labels = km.fit_predict(group_embeddings)
541
+ else:
542
+ sub_labels = np.zeros(len(indices), dtype=int)
543
+
544
+ # Build clusters
545
+ for sub_id in range(max(sub_labels) + 1):
546
+ member_mask = sub_labels == sub_id
547
+ member_indices = [indices[j] for j in range(len(indices)) if member_mask[j]]
548
+
549
+ cluster = DrumCluster(
550
+ cluster_id=len(all_clusters),
551
+ label=f"{label}_{sub_id}",
552
+ hits=[hits[i] for i in member_indices],
553
+ )
554
+ all_clusters.append(cluster)
555
+
556
+ print(f" {label}: {len(indices)} hits β†’ {best_k} sub-clusters "
557
+ f"(silhouette={best_score:.3f})")
558
+
559
+ print(f"\n βœ“ Total clusters: {len(all_clusters)}")
560
+ for c in all_clusters:
561
+ print(f" {c.label}: {c.count} hits")
562
+
563
+ return all_clusters
564
+
565
+
566
+ # ─────────────────────────────────────────────────────────────────────────────
567
+ # Stage 5: Best representative selection
568
+ # ─────────────────────────────────────────────────────────────────────────────
569
+
570
+ def select_best_representatives(clusters: list[DrumCluster],
571
+ embeddings_dict: dict = None):
572
+ """Select the best representative hit from each cluster."""
573
+ print("\n" + "=" * 60)
574
+ print("STAGE 5: Selecting best representatives")
575
+ print("=" * 60)
576
+
577
+ for cluster in clusters:
578
+ if cluster.count == 1:
579
+ cluster.best_hit_idx = 0
580
+ continue
581
+
582
+ # Strategy: combine centroid-distance + energy + short duration preference
583
+ # We want a clean, loud, representative hit
584
+
585
+ # Compute per-hit feature vectors for within-cluster comparison
586
+ hit_features = []
587
+ for hit in cluster.hits:
588
+ feat = np.concatenate([
589
+ librosa.feature.mfcc(y=hit.audio, sr=hit.sr, n_mfcc=13).mean(axis=1),
590
+ [hit.rms_energy, hit.spectral_centroid, hit.duration]
591
+ ])
592
+ hit_features.append(feat)
593
+ hit_features = np.array(hit_features)
594
+
595
+ # Normalize
596
+ mean = hit_features.mean(axis=0)
597
+ std = hit_features.std(axis=0) + 1e-8
598
+ hit_features_norm = (hit_features - mean) / std
599
+
600
+ # Centroid distance (representativeness)
601
+ centroid = hit_features_norm.mean(axis=0)
602
+ centroid_dists = np.linalg.norm(hit_features_norm - centroid, axis=1)
603
+ centroid_scores = 1.0 - (centroid_dists / (centroid_dists.max() + 1e-8))
604
+
605
+ # Energy score (prefer louder = cleaner)
606
+ energies = np.array([h.rms_energy for h in cluster.hits])
607
+ energy_scores = energies / (energies.max() + 1e-8)
608
+
609
+ # Combined score
610
+ scores = 0.6 * centroid_scores + 0.4 * energy_scores
611
+ cluster.best_hit_idx = int(np.argmax(scores))
612
+
613
+ print(f" {cluster.label}: selected hit {cluster.best_hit_idx} "
614
+ f"(score={scores[cluster.best_hit_idx]:.3f}, "
615
+ f"energy={cluster.hits[cluster.best_hit_idx].rms_energy:.4f})")
616
+
617
+
618
+ # ─────────────────────────────────────────────────────────────────────────────
619
+ # Stage 6 (optional): Synthesize optimal sample from cluster
620
+ # ─────────────────────────────────────────────────────────────────────────────
621
+
622
+ def synthesize_from_cluster(cluster: DrumCluster) -> np.ndarray:
623
+ """
624
+ Synthesize an 'optimal' sample by weighted-averaging cluster members.
625
+
626
+ Strategy: align samples to their peak, normalize lengths, then do a
627
+ weighted average in the time domain (weighted by similarity to centroid).
628
+ This reduces noise/bleed while preserving the core transient.
629
+ """
630
+ if cluster.count == 1:
631
+ return cluster.hits[0].audio.copy()
632
+
633
+ sr = cluster.hits[0].sr
634
+
635
+ # Find max length and peak positions
636
+ max_len = max(len(h.audio) for h in cluster.hits)
637
+ target_len = int(np.median([len(h.audio) for h in cluster.hits]))
638
+
639
+ # Align all hits to their peak (transient alignment)
640
+ aligned = []
641
+ weights = []
642
+ peak_pos_target = None
643
+
644
+ for i, hit in enumerate(cluster.hits):
645
+ audio = hit.audio.copy()
646
+ peak_pos = np.argmax(np.abs(audio))
647
+
648
+ if peak_pos_target is None:
649
+ peak_pos_target = peak_pos
650
+
651
+ # Shift to align peaks, then force exact target_len
652
+ shift = peak_pos_target - peak_pos
653
+ if shift > 0:
654
+ audio = np.pad(audio, (shift, 0))
655
+ elif shift < 0:
656
+ audio = audio[-shift:]
657
+
658
+ # Force exact length
659
+ if len(audio) >= target_len:
660
+ audio = audio[:target_len]
661
+ else:
662
+ audio = np.pad(audio, (0, target_len - len(audio)))
663
+
664
+ # Normalize amplitude
665
+ peak = np.abs(audio).max()
666
+ if peak > 0:
667
+ audio = audio / peak
668
+
669
+ aligned.append(audio)
670
+
671
+ # Weight by similarity to best hit (closer = higher weight)
672
+ if i == cluster.best_hit_idx:
673
+ weights.append(2.0) # double weight for the best sample
674
+ else:
675
+ weights.append(1.0)
676
+
677
+ # Weighted average
678
+ aligned = np.array(aligned)
679
+ weights = np.array(weights)
680
+ weights = weights / weights.sum()
681
+
682
+ synthesized = np.average(aligned, axis=0, weights=weights)
683
+
684
+ # Normalize output
685
+ peak = np.abs(synthesized).max()
686
+ if peak > 0:
687
+ synthesized = synthesized * (0.95 / peak)
688
+
689
+ return synthesized
690
+
691
+
692
+ # ─────────────────────────────────────────────────────────────────────────────
693
+ # Main pipeline
694
+ # ─────────────────────────────────────────────────────────────────────────────
695
+
696
+ def run_pipeline(
697
+ audio_path: str,
698
+ output_dir: str = "./drum_samples",
699
+ use_gpu: bool = True,
700
+ use_clap: bool = False, # CLAP embeddings (slower, semantic)
701
+ use_audiosep: bool = False, # AudioSep for overlap separation
702
+ separate_overlaps: bool = True,
703
+ synthesize: bool = True,
704
+ min_hit_dur: float = 0.03,
705
+ max_hit_dur: float = 0.8,
706
+ energy_threshold_db: float = -40.0,
707
+ save_intermediates: bool = True,
708
+ ):
709
+ """Run the full drum sample extraction pipeline."""
710
+ device = "cuda" if (use_gpu and torch.cuda.is_available()) else "cpu"
711
+ print(f"Device: {device}")
712
+ print(f"Input: {audio_path}")
713
+ print(f"Output: {output_dir}")
714
+
715
+ output_dir = Path(output_dir)
716
+ output_dir.mkdir(parents=True, exist_ok=True)
717
+
718
+ # ── Stage 1: Extract drums ──
719
+ drums_audio, drums_sr = extract_drums_demucs(audio_path, device=device)
720
+
721
+ if save_intermediates:
722
+ drums_path = output_dir / "drums_stem.wav"
723
+ sf.write(str(drums_path), drums_audio, drums_sr, subtype='PCM_24')
724
+ print(f" Saved drum stem: {drums_path}")
725
+
726
+ # ── Stage 2: Detect onsets & segment ──
727
+ hits = detect_onsets(
728
+ drums_audio, drums_sr,
729
+ min_hit_dur=min_hit_dur,
730
+ max_hit_dur=max_hit_dur,
731
+ energy_threshold_db=energy_threshold_db,
732
+ )
733
+
734
+ if len(hits) == 0:
735
+ print("\n⚠ No drum hits detected! Try lowering energy_threshold_db.")
736
+ return
737
+
738
+ # ── Stage 3: Classify & optionally separate overlaps ──
739
+ hits = classify_and_separate_hits(hits, separate_overlaps=separate_overlaps)
740
+
741
+ if save_intermediates:
742
+ hits_dir = output_dir / "all_hits"
743
+ hits_dir.mkdir(exist_ok=True)
744
+ for hit in hits:
745
+ hit_path = hits_dir / f"hit_{hit.index:04d}_{hit.rough_label}_{hit.onset_time:.3f}s.wav"
746
+ hit.save(str(hit_path))
747
+
748
+ # ── Stage 4: Embed & cluster ──
749
+ print("\n" + "=" * 60)
750
+ print("STAGE 4a: Computing embeddings")
751
+ print("=" * 60)
752
+
753
+ if use_clap:
754
+ embeddings = compute_clap_embeddings(hits, device=device)
755
+ print(f" βœ“ CLAP embeddings: {embeddings.shape}")
756
+ else:
757
+ embeddings = compute_librosa_embeddings(hits)
758
+ print(f" βœ“ Librosa embeddings: {embeddings.shape}")
759
+
760
+ for i, hit in enumerate(hits):
761
+ hit.embedding = embeddings[i]
762
+
763
+ clusters = cluster_hits(hits, embeddings)
764
+
765
+ # ── Stage 5: Select best representatives ──
766
+ select_best_representatives(clusters)
767
+
768
+ # ── Stage 6: Optional synthesis ──
769
+ if synthesize:
770
+ print("\n" + "=" * 60)
771
+ print("STAGE 6: Synthesizing optimal samples")
772
+ print("=" * 60)
773
+ for cluster in clusters:
774
+ if cluster.count >= 2:
775
+ cluster.synthesized = synthesize_from_cluster(cluster)
776
+ print(f" {cluster.label}: synthesized from {cluster.count} hits")
777
+
778
+ # ── Export ──
779
+ print("\n" + "=" * 60)
780
+ print("EXPORT: Saving results")
781
+ print("=" * 60)
782
+
783
+ samples_dir = output_dir / "samples"
784
+ samples_dir.mkdir(exist_ok=True)
785
+
786
+ if synthesize:
787
+ synth_dir = output_dir / "synthesized"
788
+ synth_dir.mkdir(exist_ok=True)
789
+
790
+ manifest = []
791
+ for cluster in clusters:
792
+ best = cluster.best_hit
793
+
794
+ # Save best representative
795
+ sample_name = f"{cluster.label}__best.wav"
796
+ sample_path = samples_dir / sample_name
797
+ best.save(str(sample_path))
798
+
799
+ entry = {
800
+ "cluster_id": cluster.cluster_id,
801
+ "label": cluster.label,
802
+ "count": cluster.count,
803
+ "best_sample": str(sample_path),
804
+ "best_onset_time": best.onset_time,
805
+ "best_duration": best.duration,
806
+ "best_rms_energy": best.rms_energy,
807
+ "best_spectral_centroid": best.spectral_centroid,
808
+ }
809
+
810
+ # Save synthesized version
811
+ if synthesize and cluster.synthesized is not None:
812
+ synth_name = f"{cluster.label}__synthesized.wav"
813
+ synth_path = synth_dir / synth_name
814
+ sf.write(str(synth_path), cluster.synthesized, best.sr, subtype='PCM_24')
815
+ entry["synthesized_sample"] = str(synth_path)
816
+
817
+ manifest.append(entry)
818
+ print(f" βœ“ {cluster.label}: {cluster.count} hits β†’ {sample_path.name}")
819
+
820
+ # Save manifest
821
+ manifest_path = output_dir / "manifest.json"
822
+ with open(manifest_path, "w") as f:
823
+ json.dump(manifest, f, indent=2)
824
+ print(f"\n Manifest saved: {manifest_path}")
825
+
826
+ # Summary
827
+ print("\n" + "=" * 60)
828
+ print("SUMMARY")
829
+ print("=" * 60)
830
+ print(f" Input: {audio_path}")
831
+ print(f" Drum stem: {output_dir / 'drums_stem.wav'}")
832
+ print(f" Total hits: {len(hits)}")
833
+ print(f" Clusters: {len(clusters)}")
834
+ print(f" Samples saved: {samples_dir}")
835
+ if synthesize:
836
+ print(f" Synthesized: {synth_dir}")
837
+ print(f" Manifest: {manifest_path}")
838
+
839
+ return clusters
840
+
841
+
842
+ # ─────────────────────────────────────────────────────────────────────────────
843
+ # CLI
844
+ # ─────────────────────────────────────────────────────────────────────────────
845
+
846
+ def main():
847
+ parser = argparse.ArgumentParser(
848
+ description="Extract individual drum samples from an audio file",
849
+ formatter_class=argparse.RawDescriptionHelpFormatter,
850
+ epilog="""
851
+ Examples:
852
+ %(prog)s song.mp3 -o ./my_samples
853
+ %(prog)s drums.wav -o ./samples --no-gpu
854
+ %(prog)s song.wav -o ./samples --clap # Use CLAP for semantic clustering
855
+ %(prog)s song.wav -o ./samples --no-separate # Don't decompose overlaps
856
+ %(prog)s song.wav -o ./samples --no-synthesize # Skip synthesis step
857
+ """
858
+ )
859
+ parser.add_argument("input", help="Input audio file (mp3, wav, flac, etc.)")
860
+ parser.add_argument("-o", "--output-dir", default="./drum_samples",
861
+ help="Output directory (default: ./drum_samples)")
862
+ parser.add_argument("--no-gpu", action="store_true",
863
+ help="Force CPU-only processing")
864
+ parser.add_argument("--clap", action="store_true",
865
+ help="Use CLAP embeddings for clustering (slower, more semantic)")
866
+ parser.add_argument("--no-separate", action="store_true",
867
+ help="Don't separate overlapping drum sounds")
868
+ parser.add_argument("--no-synthesize", action="store_true",
869
+ help="Don't synthesize optimal samples from clusters")
870
+ parser.add_argument("--no-intermediates", action="store_true",
871
+ help="Don't save intermediate files (drum stem, individual hits)")
872
+ parser.add_argument("--min-hit-dur", type=float, default=0.03,
873
+ help="Minimum hit duration in seconds (default: 0.03)")
874
+ parser.add_argument("--max-hit-dur", type=float, default=0.8,
875
+ help="Maximum hit duration in seconds (default: 0.8)")
876
+ parser.add_argument("--energy-threshold", type=float, default=-40.0,
877
+ help="Energy threshold in dB for hit filtering (default: -40)")
878
+
879
+ args = parser.parse_args()
880
+
881
+ if not os.path.exists(args.input):
882
+ print(f"Error: Input file not found: {args.input}")
883
+ sys.exit(1)
884
+
885
+ run_pipeline(
886
+ audio_path=args.input,
887
+ output_dir=args.output_dir,
888
+ use_gpu=not args.no_gpu,
889
+ use_clap=args.clap,
890
+ separate_overlaps=not args.no_separate,
891
+ synthesize=not args.no_synthesize,
892
+ min_hit_dur=args.min_hit_dur,
893
+ max_hit_dur=args.max_hit_dur,
894
+ energy_threshold_db=args.energy_threshold,
895
+ save_intermediates=not args.no_intermediates,
896
+ )
897
+
898
+
899
+ if __name__ == "__main__":
900
+ main()