| |
| """ |
| Drum Sample Extractor Pipeline |
| =============================== |
| Extracts individual drum samples from an audio file through: |
| |
1. STEM SEPARATION  -- HTDemucs (v4 fine-tuned, falling back to base htdemucs) isolates the drum track
2. ONSET DETECTION  -- librosa detects individual hit boundaries from per-band onset envelopes
3. INTRA-DRUM SEP   -- Spectral band splitting (low/mid/high) of hits whose energy spans several bands
4. CLUSTERING       -- librosa feature (default) or CLAP embeddings + per-label auto-K KMeans group similar hits
5. SELECTION        -- Best representative per cluster (weighted: closeness to centroid + relative energy)
6. SYNTHESIS (opt)  -- Weighted average of cluster members for an "ideal" sample
| |
| Usage: |
| python drum_extractor.py input.mp3 --output-dir ./samples |
| python drum_extractor.py input.wav --output-dir ./samples --no-gpu |
| python drum_extractor.py input.mp3 --output-dir ./samples --clap |
| |
| Dependencies: |
| pip install demucs librosa soundfile scikit-learn numpy torch transformers |
| """ |
|
|
| import argparse |
| import json |
| import os |
| import sys |
| import warnings |
| from collections import defaultdict |
| from dataclasses import dataclass, field |
| from pathlib import Path |
| from typing import Optional |
|
|
| import librosa |
| import numpy as np |
| import soundfile as sf |
| import torch |
|
|
# Process-wide: ignore every FutureWarning and UserWarning, including those
# raised inside third-party libraries, since no module filter is given.
for _ignored_category in (FutureWarning, UserWarning):
    warnings.filterwarnings("ignore", category=_ignored_category)
del _ignored_category
|
|
|
|
| |
| |
| |
|
|
@dataclass
class DrumHit:
    """One segmented drum hit.

    ``audio, sr, onset_time, duration, index`` are required (positional
    construction in this order is supported); the remaining fields have
    defaults and are filled in by later pipeline stages.
    """

    # Sample buffer holding the hit.
    audio: np.ndarray
    # Sample rate of ``audio`` in Hz.
    sr: int
    # Onset position in seconds within the analysed signal.
    onset_time: float
    # Length of the hit in seconds.
    duration: float
    # Position of the hit in the list that currently holds it.
    index: int
    # RMS amplitude of ``audio`` (linear scale).
    rms_energy: float = 0.0
    # Mean spectral centroid in Hz.
    spectral_centroid: float = 0.0
    # Heuristic drum category such as "kick"; empty until classified.
    rough_label: str = ""
    # Feature vector used for clustering; None until computed.
    embedding: Optional[np.ndarray] = None
    # Cluster assignment; -1 means unassigned.
    cluster_id: int = -1

    def save(self, path: str):
        """Write the hit to ``path`` as 24-bit PCM at its own sample rate."""
        sf.write(path, self.audio, samplerate=self.sr, subtype="PCM_24")
|
|
|
|
@dataclass
class DrumCluster:
    """A group of hits judged alike, plus its chosen and synthesized samples."""

    # Identifier assigned by the clustering stage.
    cluster_id: int
    # Display name; cluster_hits builds it as "<rough_label>_<sub_id>".
    label: str
    # Member hits (a fresh list per instance).
    hits: list = field(default_factory=list)
    # Position in ``hits`` of the representative hit.
    best_hit_idx: int = 0
    # Optional blended sample built from all members.
    synthesized: Optional[np.ndarray] = None

    @property
    def best_hit(self) -> DrumHit:
        """Representative member; IndexError if ``best_hit_idx`` is out of range."""
        members, chosen = self.hits, self.best_hit_idx
        return members[chosen]

    @property
    def count(self) -> int:
        """Number of member hits."""
        members = self.hits
        return len(members)
|
|
|
|
| |
| |
| |
|
|
def extract_drums_demucs(audio_path: str, device: str = "cpu") -> tuple[np.ndarray, int]:
    """Isolate the drum stem of ``audio_path`` with HTDemucs.

    ``htdemucs_ft`` is tried first, then plain ``htdemucs``. The input is
    resampled to the model rate and forced to two channels. It is then
    normalised the way Demucs' own separation script does it (zero mean,
    unit std of the mono reference), and that normalisation is undone on the
    separated drum stem before it is down-mixed to mono.

    Args:
        audio_path: Path of any file ``librosa.load`` can read.
        device: Torch device string the model and audio are moved to.

    Returns:
        ``(drums_mono, samplerate)``: a 1-D float array at the model's rate.

    Raises:
        RuntimeError: If neither Demucs model can be loaded.
    """
    from demucs.apply import apply_model
    from demucs.pretrained import get_model

    print("=" * 60)
    print("STAGE 1: Extracting drum stem with HTDemucs")
    print("=" * 60)

    model = None
    for model_name in ("htdemucs_ft", "htdemucs"):
        try:
            model = get_model(model_name)
        except Exception as exc:  # deliberate best-effort fallback to the next model
            print(f" Could not load {model_name}: {exc}")
            continue
        print(f" Loaded model: {model_name}")
        break
    if model is None:
        raise RuntimeError("Could not load any Demucs model")

    model.eval()
    model.to(device)
    target_sr = model.samplerate

    # The model is fed exactly two channels: duplicate mono, drop extra channels.
    audio_np, _ = librosa.load(audio_path, sr=target_sr, mono=False)
    if audio_np.ndim == 1:
        audio_np = np.stack([audio_np, audio_np])
    elif audio_np.shape[0] == 1:
        audio_np = np.concatenate([audio_np, audio_np], axis=0)
    elif audio_np.shape[0] > 2:
        audio_np = audio_np[:2]
    wav = torch.from_numpy(audio_np).float()
    print(f" Audio: {wav.shape[-1] / target_sr:.1f}s, {target_sr}Hz")

    # Normalise by the mono reference like demucs' separate script; fall back
    # to unit scale for silent/degenerate input to avoid dividing by zero.
    ref = wav.mean(dim=0)
    ref_mean = ref.mean()
    ref_std = ref.std()
    if not torch.isfinite(ref_std) or ref_std < 1e-8:
        ref_std = torch.tensor(1.0)
    mix = ((wav - ref_mean) / ref_std).unsqueeze(0).to(device)

    with torch.no_grad():
        sources = apply_model(model, mix, device=device, shifts=1,
                              split=True, overlap=0.25, progress=True)

    drums_idx = list(model.sources).index("drums")
    drums_wav = sources[0, drums_idx].cpu() * ref_std + ref_mean

    drums_mono = drums_wav.mean(dim=0).numpy()
    print(f" -> Extracted drums: {len(drums_mono) / target_sr:.1f}s")
    return drums_mono, target_sr
|
|
|
|
| |
| |
| |
|
|
def detect_onsets(y: np.ndarray, sr: int,
                  pre_pad: float = 0.005,
                  min_hit_dur: float = 0.03,
                  max_hit_dur: float = 0.8,
                  min_gap: float = 0.02,
                  energy_threshold_db: float = -40.0,
                  hop_length: int = 512) -> list[DrumHit]:
    """Find drum onsets in ``y`` and cut the signal into per-hit segments.

    Onset strength is computed for 20-250 Hz, 250-4000 Hz and 4000 Hz-Nyquist.
    Each envelope is peak-normalised, and their element-wise maximum drives
    ``librosa.onset.onset_detect`` with backtracking.

    Each segment starts ``pre_pad`` seconds before its onset. It ends at the
    next onset or after ``max_hit_dur`` seconds, whichever comes first; the
    last segment is also bounded by the end of ``y``. Segments shorter than
    ``min_hit_dur``, or with RMS below ``energy_threshold_db``, are dropped.
    Survivors get a linear fade-out of up to 5 ms.

    Args:
        y: Audio signal (the pipeline passes the mono drum stem).
        sr: Sample rate of ``y`` in Hz.
        pre_pad: Seconds of pre-roll kept before each onset.
        min_hit_dur: Minimum kept segment length in seconds.
        max_hit_dur: Maximum segment length in seconds.
        min_gap: Minimum spacing between detected onsets in seconds.
        energy_threshold_db: RMS floor in dB relative to amplitude 1.0.
        hop_length: Frame hop in samples, shared by onset analysis, peak
            picking and frame-to-time conversion (512 = librosa default).

    Returns:
        Hits in time order with ``index`` numbered 0..n-1; empty for empty input.
    """
    print("\n" + "=" * 60)
    print("STAGE 2: Detecting drum hit onsets")
    print("=" * 60)

    if len(y) == 0:
        # librosa cannot analyse an empty signal; report no hits instead.
        print(" Raw onsets detected: 0")
        print(" -> Valid hits after filtering: 0")
        return []

    def norm(x):
        mx = x.max()
        return x / mx if mx > 0 else x

    # One onset envelope per frequency band, each peak-normalised, combined
    # by element-wise maximum.
    band_edges = ((20, 250), (250, 4000), (4000, sr // 2))
    env_low, env_mid, env_high = (
        norm(librosa.onset.onset_strength(
            y=y, sr=sr, hop_length=hop_length,
            fmin=fmin, fmax=fmax, aggregate=np.median))
        for fmin, fmax in band_edges
    )
    onset_env = np.maximum(env_low, np.maximum(env_mid, env_high))

    # min_gap expressed in analysis frames of the same hop used everywhere else.
    wait_frames = max(1, int(min_gap * sr / hop_length))
    onsets_frames = librosa.onset.onset_detect(
        onset_envelope=onset_env,
        sr=sr,
        hop_length=hop_length,
        wait=wait_frames,
        pre_avg=3,
        post_avg=3,
        pre_max=3,
        post_max=5,
        backtrack=True,
        units='frames',
    )
    onset_times = librosa.frames_to_time(onsets_frames, sr=sr, hop_length=hop_length)
    print(f" Raw onsets detected: {len(onset_times)}")

    energy_threshold = 10 ** (energy_threshold_db / 20)
    max_len = int(max_hit_dur * sr)
    min_len = int(min_hit_dur * sr)
    fade_max = int(0.005 * sr)
    hits = []

    for i, t in enumerate(onset_times):
        start_sample = max(0, int((t - pre_pad) * sr))
        if i + 1 < len(onset_times):
            end_sample = min(int(onset_times[i + 1] * sr), start_sample + max_len)
        else:
            end_sample = min(len(y), start_sample + max_len)

        segment = y[start_sample:end_sample]
        if len(segment) < min_len:
            continue
        rms = np.sqrt(np.mean(segment ** 2))
        if rms < energy_threshold:
            continue

        # Short linear fade-out so the cut point does not click.
        fade_len = min(fade_max, len(segment) // 4)
        if fade_len > 0:
            segment = segment.copy()  # never modify the caller's signal
            segment[-fade_len:] *= np.linspace(1, 0, fade_len)

        spectral_centroid = float(librosa.feature.spectral_centroid(
            y=segment, sr=sr
        ).mean())

        hits.append(DrumHit(
            audio=segment,
            sr=sr,
            onset_time=t,
            duration=len(segment) / sr,
            index=len(hits),
            rms_energy=float(rms),
            spectral_centroid=spectral_centroid,
        ))

    print(f" -> Valid hits after filtering: {len(hits)}")
    return hits
|
|
|
|
| |
| |
| |
|
|
def rough_spectral_label(hit: DrumHit) -> str:
    """Guess a coarse drum category for ``hit`` from its spectrum.

    Band energies are summed squared STFT magnitudes (n_fft=2048) in
    20-200 Hz, 200-4000 Hz and >= 4000 Hz. The resulting ratios, the stored
    spectral centroid, the mean zero-crossing rate and the hit duration feed
    a first-match rule list. Unmatched hits become "perc_high" when the
    centroid exceeds 2500 Hz, else "perc_low".
    """
    audio, rate = hit.audio, hit.sr
    centroid = hit.spectral_centroid

    power = np.abs(librosa.stft(audio, n_fft=2048)) ** 2
    bins_hz = librosa.fft_frequencies(sr=rate, n_fft=2048)

    low = np.sum(power[(bins_hz >= 20) & (bins_hz < 200)])
    mid = np.sum(power[(bins_hz >= 200) & (bins_hz < 4000)])
    high = np.sum(power[bins_hz >= 4000])
    total = low + mid + high + 1e-10  # epsilon keeps silent hits finite

    low_ratio, mid_ratio, high_ratio = low / total, mid / total, high / total
    zcr = float(librosa.feature.zero_crossing_rate(y=audio).mean())

    if low_ratio > 0.5 and centroid < 800:
        return "kick"
    if high_ratio > 0.35 and centroid > 4000:
        return "hihat_closed" if hit.duration < 0.15 else "hihat_open"
    if high_ratio > 0.25 and centroid > 3000:
        return "cymbal"
    if mid_ratio > 0.4 and zcr > 0.1 and centroid > 1000:
        return "snare"
    if low_ratio > 0.3 and mid_ratio > 0.3:
        return "tom"
    return "perc_high" if centroid > 2500 else "perc_low"
|
|
|
|
def spectral_separate_hit(hit: DrumHit) -> dict[str, np.ndarray]:
    """Split ``hit`` into low/mid/high band signals by masking its STFT.

    Bands are 20-250 Hz ("low"), 250-4000 Hz ("mid") and 4000 Hz up to
    Nyquist ("high"). Intervals are half-open, so an FFT bin lying exactly on
    a shared edge belongs to one band only. The top band is closed, so the
    Nyquist bin is kept even for odd sample rates. Bins below 20 Hz are
    discarded.

    Returns:
        Mapping of band name to a time-domain signal of ``len(hit.audio)``
        samples, containing only bands whose RMS exceeds 0.001.
    """
    y, sr = hit.audio, hit.sr
    D = librosa.stft(y, n_fft=2048)
    freqs = librosa.fft_frequencies(sr=sr, n_fft=2048)

    bands = {
        "low": (20, 250),
        "mid": (250, 4000),
        "high": (4000, sr / 2),
    }
    top_band = "high"

    results = {}
    for name, (fmin, fmax) in bands.items():
        upper = freqs <= fmax if name == top_band else freqs < fmax
        mask = (freqs >= fmin) & upper
        D_band = np.zeros_like(D)
        D_band[mask] = D[mask]
        audio_band = librosa.istft(D_band, length=len(y))

        # Skip (near-)silent bands so they do not count as separate sounds.
        if np.sqrt(np.mean(audio_band ** 2)) > 0.001:
            results[name] = audio_band

    return results
|
|
|
|
def classify_and_separate_hits(hits: list[DrumHit],
                               separate_overlaps: bool = True) -> list[DrumHit]:
    """Label every hit and, optionally, split multi-band hits into sub-hits.

    Every input hit gets ``rough_label`` from ``rough_spectral_label``. With
    ``separate_overlaps``, a hit with at least two band signals above 15% of
    the loudest band's RMS is replaced by one new DrumHit per such band. The
    new hits are labelled "kick"/"snare"/"hihat" for low/mid/high. Hits kept
    whole get their ``index`` renumbered to their position in the output.

    Returns:
        New list of hits (original and/or sub-hits) in processing order.
    """
    print("\n" + "=" * 60)
    print("STAGE 3: Spectral classification & separation")
    print("=" * 60)

    band_to_label = {"low": "kick", "mid": "snare", "high": "hihat"}
    result = []
    n_decomposed = 0

    for hit in hits:
        hit.rough_label = rough_spectral_label(hit)

        strong_bands = {}
        if separate_overlaps:
            band_audio = spectral_separate_hit(hit)
            if len(band_audio) >= 2:
                band_rms = {name: np.sqrt(np.mean(sig ** 2))
                            for name, sig in band_audio.items()}
                loudest = max(band_rms.values())
                strong_bands = {name: sig for name, sig in band_audio.items()
                                if band_rms[name] > 0.15 * loudest}

        if len(strong_bands) < 2:
            hit.index = len(result)
            result.append(hit)
            continue

        n_decomposed += 1
        for name, sig in strong_bands.items():
            result.append(DrumHit(
                audio=sig,
                sr=hit.sr,
                onset_time=hit.onset_time,
                duration=hit.duration,
                index=len(result),
                rms_energy=float(np.sqrt(np.mean(sig ** 2))),
                spectral_centroid=float(librosa.feature.spectral_centroid(
                    y=sig, sr=hit.sr
                ).mean()),
                rough_label=band_to_label.get(name, "other"),
            ))

    tally = {}
    for item in result:
        tally[item.rough_label] = tally.get(item.rough_label, 0) + 1

    print(f" Overlapping hits decomposed: {n_decomposed}")
    print(f" Total hits after separation: {len(result)}")
    print(" Label distribution:")
    for label, count in sorted(tally.items(), key=lambda pair: pair[1], reverse=True):
        print(f" {label}: {count}")

    return result
|
|
|
|
| |
| |
| |
|
|
def compute_librosa_embeddings(hits: list[DrumHit]) -> np.ndarray:
    """Build a z-scored 58-dimensional hand-crafted feature vector per hit.

    Per hit, in order:
    - 20 MFCC means and 20 MFCC stds;
    - spectral centroid (mean, std), bandwidth (mean, std) and roll-off mean;
    - 5 spectral-contrast band means;
    - flatness, zero-crossing-rate and RMS means;
    - 4 onset-envelope statistics (mean, std, relative argmax, last value);
    - the hit duration.

    Hits shorter than 50 ms are zero-padded first. Each column is then
    standardised across all hits (epsilon 1e-8 on the std).

    Returns:
        float32 array with one row per hit.
    """
    def _features(hit):
        audio, rate = hit.audio, hit.sr
        shortest = int(0.05 * rate)
        if len(audio) < shortest:
            audio = np.pad(audio, (0, shortest - len(audio)))

        mfcc = librosa.feature.mfcc(y=audio, sr=rate, n_mfcc=20)
        centroid = librosa.feature.spectral_centroid(y=audio, sr=rate)
        bandwidth = librosa.feature.spectral_bandwidth(y=audio, sr=rate)
        rolloff = librosa.feature.spectral_rolloff(y=audio, sr=rate)
        contrast = librosa.feature.spectral_contrast(y=audio, sr=rate, n_bands=4)
        flatness = librosa.feature.spectral_flatness(y=audio)
        zcr = librosa.feature.zero_crossing_rate(y=audio)
        rms = librosa.feature.rms(y=audio)

        envelope = librosa.onset.onset_strength(y=audio, sr=rate)
        if len(envelope) > 1:
            env = envelope / (envelope.max() + 1e-10)
            attack = [env.mean(), env.std(), float(np.argmax(env)) / len(env), env[-1]]
        else:
            attack = [0, 0, 0, 0]

        return np.concatenate([
            mfcc.mean(axis=1),
            mfcc.std(axis=1),
            [centroid.mean(), centroid.std()],
            [bandwidth.mean(), bandwidth.std()],
            [rolloff.mean()],
            contrast.mean(axis=1),
            [flatness.mean()],
            [zcr.mean()],
            [rms.mean()],
            attack,
            [hit.duration],
        ])

    matrix = np.array([_features(h) for h in hits], dtype=np.float32)
    column_mean = matrix.mean(axis=0)
    column_std = matrix.std(axis=0) + 1e-8
    return (matrix - column_mean) / column_std
|
|
|
|
def compute_clap_embeddings(hits: list[DrumHit], device: str = "cpu") -> np.ndarray:
    """Embed each hit with the CLAP audio encoder ``laion/larger_clap_general``.

    Each hit is resampled to 48 kHz and zero-padded to at least 0.5 s. It is
    then run through ``ClapProcessor`` and ``ClapModel.get_audio_features``
    one at a time, without gradients. Progress is printed every 50 hits.

    Returns:
        float32 array with one row per hit. The original notes 512 dims;
        presumably the model's projection size, verify for other checkpoints.
    """
    from transformers import ClapModel, ClapProcessor

    checkpoint = "laion/larger_clap_general"
    print(" Loading CLAP model (laion/larger_clap_general)...")
    model = ClapModel.from_pretrained(checkpoint).to(device)
    processor = ClapProcessor.from_pretrained(checkpoint)
    model.eval()

    clap_sr = 48000
    min_samples = int(0.5 * clap_sr)
    total = len(hits)
    vectors = []

    for done, hit in enumerate(hits, start=1):
        audio = librosa.resample(hit.audio, orig_sr=hit.sr, target_sr=clap_sr)
        shortfall = min_samples - len(audio)
        if shortfall > 0:
            audio = np.pad(audio, (0, shortfall))

        batch = processor(audios=audio, sampling_rate=clap_sr, return_tensors="pt")
        batch = {name: tensor.to(device) for name, tensor in batch.items()}

        with torch.no_grad():
            features = model.get_audio_features(**batch)
        vectors.append(features.squeeze().cpu().numpy())

        if done % 50 == 0:
            print(f" Embedded {done}/{total}")

    return np.array(vectors, dtype=np.float32)
|
|
|
|
def cluster_hits(hits: list[DrumHit],
                 embeddings: np.ndarray,
                 min_clusters: int = 2,
                 max_clusters: int = 30) -> list[DrumCluster]:
    """Sub-cluster hits within each rough label using KMeans.

    Hits are grouped by ``rough_label``. For each group with at least two
    members, K is searched over ``2..max_k`` and the K with the highest
    silhouette score wins. ``max_k`` is ``len(group) // 3`` (at least 2),
    capped at 15 and at ``max_clusters``. A group for which no K yields a
    valid silhouette stays a single cluster.

    Args:
        hits: Hits to cluster; ``hits[i]`` corresponds to ``embeddings[i]``.
        embeddings: 2-D array with one row per hit.
        min_clusters: Lower bound applied to ``max_clusters``.
        max_clusters: Upper bound on K within one label group (also limited
            to ``len(hits) - 1``, but never below ``min_clusters``).

    Returns:
        Clusters numbered consecutively and labelled ``"<rough_label>_<n>"``.
        Every returned cluster contains at least one hit.
    """
    from sklearn.cluster import KMeans
    from sklearn.metrics import silhouette_score

    print("\n" + "=" * 60)
    print("STAGE 4: Clustering similar drum hits")
    print("=" * 60)

    max_clusters = max(min(max_clusters, len(hits) - 1), min_clusters)

    label_groups = defaultdict(list)
    for i, hit in enumerate(hits):
        label_groups[hit.rough_label].append(i)

    all_clusters = []

    for label, indices in label_groups.items():
        if len(indices) < 2:
            all_clusters.append(DrumCluster(
                cluster_id=len(all_clusters),
                label=f"{label}_0",
                hits=[hits[i] for i in indices],
            ))
            continue

        group_embeddings = embeddings[indices]
        # Roughly three hits per sub-cluster, hard cap 15, and the caller's cap.
        max_k = min(max(2, len(indices) // 3), 15, max_clusters)
        best_k, best_score = 1, -1.0
        best_labels = np.zeros(len(indices), dtype=int)

        for k in range(2, max_k + 1):
            try:
                km = KMeans(n_clusters=k, random_state=42, n_init=10, max_iter=300)
                sub_labels = km.fit_predict(group_embeddings)
                score = silhouette_score(group_embeddings, sub_labels)
            except ValueError:
                # Too few samples or distinct labels for this K.
                continue
            if score > best_score:
                # Keep these labels; a refit with the same seed would reproduce them.
                best_k, best_score, best_labels = k, score, sub_labels

        # KMeans may leave an id unused (e.g. duplicate embeddings); only emit
        # ids that occur so no cluster ends up empty.
        present_ids = np.unique(best_labels)
        for sub_id, raw_id in enumerate(present_ids):
            member_positions = np.flatnonzero(best_labels == raw_id)
            all_clusters.append(DrumCluster(
                cluster_id=len(all_clusters),
                label=f"{label}_{sub_id}",
                hits=[hits[indices[j]] for j in member_positions],
            ))

        score_text = f"{best_score:.3f}" if best_k >= 2 else "n/a"
        print(f" {label}: {len(indices)} hits -> {len(present_ids)} sub-clusters "
              f"(silhouette={score_text})")

    print(f"\n -> Total clusters: {len(all_clusters)}")
    for c in all_clusters:
        print(f" {c.label}: {c.count} hits")

    return all_clusters
|
|
|
|
| |
| |
| |
|
|
def select_best_representatives(clusters: list[DrumCluster],
                                embeddings_dict: Optional[dict] = None):
    """Choose each cluster's representative hit, storing it in ``best_hit_idx``.

    Each member gets a feature vector of 13 MFCC means plus RMS energy,
    spectral centroid and duration, z-scored within the cluster. The score is
    ``0.6 * centroid_score + 0.4 * energy_score``, where:
    - ``centroid_score`` is 1 minus the distance to the feature mean, divided
      by the largest such distance;
    - ``energy_score`` is the RMS energy divided by the cluster's maximum.

    The highest-scoring hit wins. Clusters with zero or one hit get index 0
    without scoring.

    Args:
        clusters: Clusters to update in place.
        embeddings_dict: Unused; kept for backward compatibility.
    """
    print("\n" + "=" * 60)
    print("STAGE 5: Selecting best representatives")
    print("=" * 60)

    for cluster in clusters:
        # Singletons need no ranking; an empty cluster would otherwise crash
        # the normalisation below (it has no rows).
        if cluster.count <= 1:
            cluster.best_hit_idx = 0
            continue

        hit_features = np.array([
            np.concatenate([
                librosa.feature.mfcc(y=hit.audio, sr=hit.sr, n_mfcc=13).mean(axis=1),
                [hit.rms_energy, hit.spectral_centroid, hit.duration],
            ])
            for hit in cluster.hits
        ])

        # Standardise per feature so no single scale dominates the distance.
        mean = hit_features.mean(axis=0)
        std = hit_features.std(axis=0) + 1e-8
        hit_features_norm = (hit_features - mean) / std

        centroid = hit_features_norm.mean(axis=0)
        centroid_dists = np.linalg.norm(hit_features_norm - centroid, axis=1)
        centroid_scores = 1.0 - (centroid_dists / (centroid_dists.max() + 1e-8))

        energies = np.array([h.rms_energy for h in cluster.hits])
        energy_scores = energies / (energies.max() + 1e-8)

        scores = 0.6 * centroid_scores + 0.4 * energy_scores
        cluster.best_hit_idx = int(np.argmax(scores))

        print(f" {cluster.label}: selected hit {cluster.best_hit_idx} "
              f"(score={scores[cluster.best_hit_idx]:.3f}, "
              f"energy={cluster.hits[cluster.best_hit_idx].rms_energy:.4f})")
|
|
|
|
| |
| |
| |
|
|
def synthesize_from_cluster(cluster: DrumCluster) -> np.ndarray:
    """Blend all hits of ``cluster`` into one peak-aligned composite sample.

    The absolute-peak position of the first member is the alignment anchor.
    Every member is processed as follows:
    - shifted so its own absolute peak lands on the anchor, by zero-padding
      or trimming at the front;
    - fitted to the median member length, by zero-padding or trimming at
      the end;
    - scaled to unit peak.

    The aligned members are averaged with weight 2 for
    ``cluster.best_hit_idx`` and 1 for the others. The result is scaled to a
    peak of 0.95. No polarity matching is done, so members whose peaks have
    opposite sign partially cancel.

    A single-hit cluster returns a copy of that hit's audio unchanged.
    """
    if cluster.count == 1:
        return cluster.hits[0].audio.copy()

    members = cluster.hits
    anchor = np.argmax(np.abs(members[0].audio))
    length = int(np.median([len(m.audio) for m in members]))

    def _conform(signal):
        """Peak-align ``signal`` to ``anchor``, fit it to ``length``, scale to unit peak."""
        offset = anchor - np.argmax(np.abs(signal))
        if offset > 0:
            signal = np.pad(signal, (offset, 0))
        elif offset < 0:
            signal = signal[-offset:]

        if len(signal) < length:
            signal = np.pad(signal, (0, length - len(signal)))
        else:
            signal = signal[:length]

        top = np.abs(signal).max()
        return signal / top if top > 0 else signal

    stacked = np.array([_conform(m.audio) for m in members])
    raw_weights = np.array([2.0 if pos == cluster.best_hit_idx else 1.0
                            for pos in range(len(members))])
    blend = np.average(stacked, axis=0, weights=raw_weights / raw_weights.sum())

    top = np.abs(blend).max()
    if top > 0:
        blend = blend * (0.95 / top)
    return blend
|
|
|
|
| |
| |
| |
|
|
def _print_stage_banner(title: str) -> None:
    """Print ``title`` between two 60-character '=' rules, after a blank line."""
    print("\n" + "=" * 60)
    print(title)
    print("=" * 60)


def _export_clusters(clusters: list, output_dir: Path, synthesize: bool):
    """Save each cluster's best hit (and blend) and write ``manifest.json``.

    Best hits go to ``output_dir/samples``. Synthesized blends go to
    ``output_dir/synthesized`` when ``synthesize`` is true and a blend exists.

    Returns:
        ``(samples_dir, synth_dir, manifest_path)``; ``synth_dir`` is None
        when ``synthesize`` is false.
    """
    _print_stage_banner("EXPORT: Saving results")

    samples_dir = output_dir / "samples"
    samples_dir.mkdir(exist_ok=True)

    synth_dir = None
    if synthesize:
        synth_dir = output_dir / "synthesized"
        synth_dir.mkdir(exist_ok=True)

    manifest = []
    for cluster in clusters:
        best = cluster.best_hit
        sample_path = samples_dir / f"{cluster.label}__best.wav"
        best.save(str(sample_path))

        # float() guards against numpy scalar types json cannot serialise.
        entry = {
            "cluster_id": int(cluster.cluster_id),
            "label": cluster.label,
            "count": cluster.count,
            "best_sample": str(sample_path),
            "best_onset_time": float(best.onset_time),
            "best_duration": float(best.duration),
            "best_rms_energy": float(best.rms_energy),
            "best_spectral_centroid": float(best.spectral_centroid),
        }

        if synth_dir is not None and cluster.synthesized is not None:
            synth_path = synth_dir / f"{cluster.label}__synthesized.wav"
            sf.write(str(synth_path), cluster.synthesized, best.sr, subtype='PCM_24')
            entry["synthesized_sample"] = str(synth_path)

        manifest.append(entry)
        print(f" -> {cluster.label}: {cluster.count} hits -> {sample_path.name}")

    manifest_path = output_dir / "manifest.json"
    with open(manifest_path, "w", encoding="utf-8") as f:
        json.dump(manifest, f, indent=2)
    print(f"\n Manifest saved: {manifest_path}")

    return samples_dir, synth_dir, manifest_path


def run_pipeline(
    audio_path: str,
    output_dir: str = "./drum_samples",
    use_gpu: bool = True,
    use_clap: bool = False,
    separate_overlaps: bool = True,
    synthesize: bool = True,
    min_hit_dur: float = 0.03,
    max_hit_dur: float = 0.8,
    energy_threshold_db: float = -40.0,
    save_intermediates: bool = True,
):
    """Run the whole extraction pipeline on ``audio_path``.

    Stages run in order: Demucs drum stem, onset segmentation, spectral
    classification/separation, embeddings (CLAP or librosa), per-label
    clustering, representative selection, optional synthesis, then export.

    Args:
        audio_path: Input audio file.
        output_dir: Directory created (with parents) for all outputs.
        use_gpu: Use CUDA when available.
        use_clap: Use CLAP embeddings instead of librosa features.
        separate_overlaps: Split multi-band hits into per-band sub-hits.
        synthesize: Build and save a blended sample for clusters of 2+ hits.
        min_hit_dur: Minimum hit length in seconds.
        max_hit_dur: Maximum hit length in seconds.
        energy_threshold_db: RMS floor in dB for keeping a hit.
        save_intermediates: Also save the drum stem and every individual hit.

    Returns:
        The list of clusters, or None when no drum hits were detected.
    """
    device = "cuda" if (use_gpu and torch.cuda.is_available()) else "cpu"
    print(f"Device: {device}")
    print(f"Input: {audio_path}")
    print(f"Output: {output_dir}")

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    drums_audio, drums_sr = extract_drums_demucs(audio_path, device=device)

    drums_path = output_dir / "drums_stem.wav"
    if save_intermediates:
        sf.write(str(drums_path), drums_audio, drums_sr, subtype='PCM_24')
        print(f" Saved drum stem: {drums_path}")

    hits = detect_onsets(
        drums_audio, drums_sr,
        min_hit_dur=min_hit_dur,
        max_hit_dur=max_hit_dur,
        energy_threshold_db=energy_threshold_db,
    )
    if not hits:
        print("\nWARNING: No drum hits detected! Try lowering energy_threshold_db.")
        return None

    hits = classify_and_separate_hits(hits, separate_overlaps=separate_overlaps)

    if save_intermediates:
        hits_dir = output_dir / "all_hits"
        hits_dir.mkdir(exist_ok=True)
        for hit in hits:
            hit_path = hits_dir / f"hit_{hit.index:04d}_{hit.rough_label}_{hit.onset_time:.3f}s.wav"
            hit.save(str(hit_path))

    _print_stage_banner("STAGE 4a: Computing embeddings")
    if use_clap:
        embeddings = compute_clap_embeddings(hits, device=device)
        print(f" -> CLAP embeddings: {embeddings.shape}")
    else:
        embeddings = compute_librosa_embeddings(hits)
        print(f" -> Librosa embeddings: {embeddings.shape}")

    for hit, vector in zip(hits, embeddings):
        hit.embedding = vector

    clusters = cluster_hits(hits, embeddings)
    select_best_representatives(clusters)

    if synthesize:
        _print_stage_banner("STAGE 6: Synthesizing optimal samples")
        for cluster in clusters:
            if cluster.count >= 2:
                cluster.synthesized = synthesize_from_cluster(cluster)
                print(f" {cluster.label}: synthesized from {cluster.count} hits")

    samples_dir, synth_dir, manifest_path = _export_clusters(clusters, output_dir, synthesize)

    _print_stage_banner("SUMMARY")
    print(f" Input: {audio_path}")
    # Only report the stem path when it was actually written.
    if save_intermediates:
        print(f" Drum stem: {drums_path}")
    print(f" Total hits: {len(hits)}")
    print(f" Clusters: {len(clusters)}")
    print(f" Samples saved: {samples_dir}")
    if synth_dir is not None:
        print(f" Synthesized: {synth_dir}")
    print(f" Manifest: {manifest_path}")

    return clusters
|
|
|
|
| |
| |
| |
|
|
def main():
    """Command-line entry point: parse and validate arguments, then run the pipeline.

    Exit codes:
        2: usage errors from argparse, including a negative
           ``--min-hit-dur``, a non-positive ``--max-hit-dur``, or
           ``--min-hit-dur`` greater than ``--max-hit-dur``.
        1: the input path is not an existing regular file.
    """
    parser = argparse.ArgumentParser(
        description="Extract individual drum samples from an audio file",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
%(prog)s song.mp3 -o ./my_samples
%(prog)s drums.wav -o ./samples --no-gpu
%(prog)s song.wav -o ./samples --clap # Use CLAP for semantic clustering
%(prog)s song.wav -o ./samples --no-separate # Don't decompose overlaps
%(prog)s song.wav -o ./samples --no-synthesize # Skip synthesis step
"""
    )
    parser.add_argument("input", help="Input audio file (mp3, wav, flac, etc.)")
    parser.add_argument("-o", "--output-dir", default="./drum_samples",
                        help="Output directory (default: ./drum_samples)")
    parser.add_argument("--no-gpu", action="store_true",
                        help="Force CPU-only processing")
    parser.add_argument("--clap", action="store_true",
                        help="Use CLAP embeddings for clustering (slower, more semantic)")
    parser.add_argument("--no-separate", action="store_true",
                        help="Don't separate overlapping drum sounds")
    parser.add_argument("--no-synthesize", action="store_true",
                        help="Don't synthesize optimal samples from clusters")
    parser.add_argument("--no-intermediates", action="store_true",
                        help="Don't save intermediate files (drum stem, individual hits)")
    parser.add_argument("--min-hit-dur", type=float, default=0.03,
                        help="Minimum hit duration in seconds (default: 0.03)")
    parser.add_argument("--max-hit-dur", type=float, default=0.8,
                        help="Maximum hit duration in seconds (default: 0.8)")
    parser.add_argument("--energy-threshold", type=float, default=-40.0,
                        help="Energy threshold in dB for hit filtering (default: -40)")

    args = parser.parse_args()

    # Inconsistent duration limits would silently discard every hit.
    if args.min_hit_dur < 0 or args.max_hit_dur <= 0:
        parser.error("--min-hit-dur must be >= 0 and --max-hit-dur must be > 0")
    if args.min_hit_dur > args.max_hit_dur:
        parser.error("--min-hit-dur must not exceed --max-hit-dur")

    # isfile (not exists) so a directory is rejected before the heavy stages start.
    if not os.path.isfile(args.input):
        print(f"Error: Input file not found: {args.input}")
        sys.exit(1)

    run_pipeline(
        audio_path=args.input,
        output_dir=args.output_dir,
        use_gpu=not args.no_gpu,
        use_clap=args.clap,
        separate_overlaps=not args.no_separate,
        synthesize=not args.no_synthesize,
        min_hit_dur=args.min_hit_dur,
        max_hit_dur=args.max_hit_dur,
        energy_threshold_db=args.energy_threshold,
        save_intermediates=not args.no_intermediates,
    )
|
|
|
|
if __name__ == "__main__":
    # Script entry only; importing this module does not start the pipeline.
    sys.exit(main())
|
|