# 🥁 Drum Sample Extractor
Extract individual drum samples (kick, snare, hi-hat, etc.) from any audio file. The pipeline isolates drums, detects individual hits, separates overlapping sounds, clusters identical samples, and exports the best representative from each group.
## Pipeline Architecture
```
song.mp3
   │
   ▼  [1] HTDemucs v4 (fine-tuned)         ─── stem separation
drums.wav
   │
   ▼  [2] Multi-band onset detection       ─── librosa (backtracking, 3-band)
hits/hit_001.wav, hit_002.wav, ...
   │
   ▼  [3] Spectral band decomposition      ─── separate overlapping kick+snare+hihat
hits_separated/{kick/, snare/, hihat/}
   │
   ▼  [4] Feature embeddings + clustering  ─── librosa (58-dim) or CLAP (512-dim)
   │      Auto-K via silhouette score
   │
   ▼  [5] Best representative selection    ─── 60% centroid-proximity + 40% energy
   │
   ▼  [6] Optional: weighted synthesis     ─── peak-aligned averaging across cluster
   │
   ▼  EXPORT
samples/kick_0__best.wav                 # best real sample per cluster
synthesized/kick_0__synthesized.wav      # synthetic "ideal" version
manifest.json                            # metadata for all clusters
```
## Quick Start
```bash
pip install demucs librosa soundfile scikit-learn numpy torch transformers
python drum_extractor.py song.mp3 -o ./my_samples
```
## Usage
```bash
# Basic - extract from any audio file
python drum_extractor.py song.mp3 -o ./samples
# CPU-only (no GPU required for any stage)
python drum_extractor.py song.wav -o ./samples --no-gpu
# Use CLAP embeddings for semantic clustering (slower but more accurate)
python drum_extractor.py song.wav -o ./samples --clap
# Skip overlap separation (faster, but simultaneous hits stay merged)
python drum_extractor.py song.wav -o ./samples --no-separate
# Skip synthesis (only export real samples, no averaging)
python drum_extractor.py song.wav -o ./samples --no-synthesize
# Tune detection sensitivity
python drum_extractor.py song.wav -o ./samples \
--min-hit-dur 0.05 \
--max-hit-dur 1.0 \
--energy-threshold -35
```
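When wrapping the tool in another script, the flags above can be mirrored with `argparse`. This is a sketch only, not the parser inside `drum_extractor.py`: the `dest` names are assumptions, and the tuning flags default to `None` because the README does not state the script's real defaults (the values in the tuning example are just one setting).

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical mirror of the documented CLI flags (not the script's own parser)."""
    p = argparse.ArgumentParser(prog="drum_extractor.py")
    p.add_argument("input", help="any audio file, e.g. song.mp3 or song.wav")
    p.add_argument("-o", dest="output_dir", default=None, help="output directory")
    # Stage toggles documented in the README
    p.add_argument("--no-gpu", action="store_true", help="run every stage on CPU")
    p.add_argument("--clap", action="store_true", help="CLAP embeddings instead of librosa features")
    p.add_argument("--no-separate", action="store_true", help="skip overlap separation")
    p.add_argument("--no-synthesize", action="store_true", help="export real samples only")
    # Detection tuning; None means "use the script's built-in default" (unknown here)
    p.add_argument("--min-hit-dur", type=float, default=None, help="seconds, e.g. 0.05")
    p.add_argument("--max-hit-dur", type=float, default=None, help="seconds, e.g. 1.0")
    p.add_argument("--energy-threshold", type=float, default=None, help="dB, e.g. -35")
    return p
```

Because no option looks like a negative number, `argparse` accepts `--energy-threshold -35` as a value rather than as a flag.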
## Output Structure
```
output_dir/
├── drums_stem.wav              # Isolated drum track from Demucs
├── all_hits/                   # Every detected hit (intermediate)
│   ├── hit_0000_kick_0.500s.wav
│   ├── hit_0001_snare_1.000s.wav
│   └── ...
├── samples/                    # Best representative per cluster
│   ├── kick_0__best.wav
│   ├── snare_0__best.wav
│   ├── hihat_closed_0__best.wav
│   └── ...
├── synthesized/                # Synthesized "ideal" samples
│   ├── kick_0__synthesized.wav
│   └── ...
└── manifest.json               # Full metadata
```
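The file names above follow a regular pattern (zero-padded hit index, rough label, onset time in seconds; cluster name plus a `__best` or `__synthesized` suffix), so they can be generated mechanically. A minimal sketch: the helper names are hypothetical, and so are the manifest fields, since the README does not document the `manifest.json` schema.

```python
import json


def hit_filename(index: int, label: str, onset_s: float) -> str:
    # e.g. hit_0000_kick_0.500s.wav
    return f"hit_{index:04d}_{label}_{onset_s:.3f}s.wav"


def sample_filenames(cluster: str) -> tuple[str, str]:
    # cluster names look like "kick_0" or "hihat_closed_0"
    return f"samples/{cluster}__best.wav", f"synthesized/{cluster}__synthesized.wav"


def manifest_entry(cluster: str, members: list[int]) -> dict:
    """Hypothetical per-cluster manifest record: output files plus member hit indices."""
    best, synth = sample_filenames(cluster)
    return {"cluster": cluster, "best": best, "synthesized": synth, "members": list(members)}


if __name__ == "__main__":
    print(json.dumps({"clusters": [manifest_entry("kick_0", [0, 4, 9])]}, indent=2))
```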
## How Each Stage Works
### Stage 1: Drum Stem Extraction
Uses [HTDemucs v4 fine-tuned](https://github.com/facebookresearch/demucs) (`htdemucs_ft`), a state-of-the-art music source separation model reported at **8.4 dB SDR** on drums (MUSDB18-HQ). Falls back to `htdemucs` if the fine-tuned variant is unavailable.
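One way to reproduce this stage outside the script is to call the Demucs command-line interface with the same fallback order. This is a sketch assuming Demucs 4 is installed in the current interpreter; it shells out to `python -m demucs` with `-n` (model name), `--two-stems drums`, `-o` (output root) and, for CPU runs, `-d cpu`. The helper names are hypothetical. By default Demucs writes stems to `<out>/<model>/<track name>/`, which `expected_stem_path` reconstructs.

```python
import subprocess
import sys
from pathlib import Path

# Fine-tuned model first, plain htdemucs as the fallback (order from the README)
MODELS = ("htdemucs_ft", "htdemucs")


def demucs_command(audio_path, model, out_dir, use_gpu=True):
    """Build a `python -m demucs` call that splits into drums + everything else."""
    cmd = [sys.executable, "-m", "demucs", "-n", model,
           "--two-stems", "drums", "-o", str(out_dir)]
    if not use_gpu:
        cmd += ["-d", "cpu"]
    cmd.append(str(audio_path))
    return cmd


def expected_stem_path(out_dir, model, audio_path):
    # Demucs' default layout: <out_dir>/<model>/<track name>/<stem>.wav
    return Path(out_dir) / model / Path(audio_path).stem / "drums.wav"


def extract_drum_stem(audio_path, out_dir, use_gpu=True):
    for model in MODELS:
        try:
            subprocess.run(demucs_command(audio_path, model, out_dir, use_gpu), check=True)
        except subprocess.CalledProcessError:
            continue  # e.g. weights could not be downloaded; try the next model
        return expected_stem_path(out_dir, model, audio_path)
    raise RuntimeError("no Demucs model could separate " + str(audio_path))
```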
### Stage 2: Onset Detection
Multi-band onset detection using librosa:
- **Low band** (20–250 Hz): catches kicks
- **Mid band** (250–4000 Hz): catches snares and toms
- **High band** (4000+ Hz): catches cymbals and hi-hats
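Each band's onset envelope is normalized independently, and the three envelopes are combined via element-wise max. Backtracking then moves each detected onset to the preceding local minimum of the energy curve, so the cut starts at the beginning of the attack rather than at its peak.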
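The multi-band scheme can be sketched in plain NumPy to make each step explicit. This is an approximation under stated assumptions, not the script's code: it uses a log-magnitude spectral flux per band instead of librosa's `onset_strength`, a simple threshold peak picker, and a backtracking rule that walks back to the preceding local minimum of the combined envelope (the idea behind `librosa.onset.onset_backtrack`). `n_fft=2048` and `hop=512` match librosa's defaults; `threshold` and `min_gap` are illustrative values.

```python
import numpy as np

# Band edges from the README: low 20-250 Hz, mid 250-4000 Hz, high 4000 Hz to Nyquist
BANDS_HZ = [(20.0, 250.0), (250.0, 4000.0), (4000.0, None)]


def stft_magnitude(y, n_fft=2048, hop=512):
    """Centered, Hann-windowed STFT magnitude with shape (frames, bins)."""
    y = np.pad(y, (n_fft // 2, n_fft // 2))
    n_frames = 1 + (len(y) - n_fft) // hop
    idx = np.arange(n_fft)[None, :] + hop * np.arange(n_frames)[:, None]
    return np.abs(np.fft.rfft(y[idx] * np.hanning(n_fft), axis=1))


def multiband_envelope(y, sr, n_fft=2048, hop=512):
    logmag = np.log1p(stft_magnitude(y, n_fft, hop))
    freqs = np.fft.rfftfreq(n_fft, 1.0 / sr)
    envs = []
    for lo, hi in BANDS_HZ:
        hi = sr / 2 if hi is None else hi
        band = logmag[:, (freqs >= lo) & (freqs < hi)]
        # Half-wave rectified spectral flux: only increases in energy count
        flux = np.concatenate([[0.0], np.maximum(0.0, np.diff(band, axis=0)).sum(axis=1)])
        peak = flux.max()
        envs.append(flux / peak if peak > 0 else flux)  # per-band normalization
    return np.max(envs, axis=0)  # element-wise max across bands


def detect_onsets(y, sr, threshold=0.3, min_gap=3, n_fft=2048, hop=512):
    """Onset times in seconds (backtracked)."""
    env = multiband_envelope(y, sr, n_fft, hop)
    peaks = []
    for i in range(1, len(env) - 1):
        is_peak = env[i] >= threshold and env[i] >= env[i - 1] and env[i] > env[i + 1]
        if is_peak and (not peaks or i - peaks[-1] >= min_gap):
            peaks.append(i)
    onsets = []
    for i in peaks:
        j = i
        while j > 0 and env[j - 1] < env[j]:  # walk back to the preceding local minimum
            j -= 1
        onsets.append(j * hop / sr)  # frame index -> seconds
    return onsets
```

Per-band normalization is what lets a quiet hi-hat register next to a loud kick: each band's envelope peaks at 1 regardless of how much absolute energy that band carries. Backtracked times land up to roughly one window length before the true attack, because the minimum is found at frame resolution.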
### Stage 3: Spectral Classification & Overlap Separation
Each hit is classified by a spectral decision tree:
- **Kick**: >50% low-band energy, centroid < 800 Hz
- **Snare**: >40% mid-band energy, high ZCR, centroid > 1000 Hz
- **Hi-hat (closed/open)**: >35% high-band energy, centroid > 4000 Hz
- **Cymbal/Tom/Percussion**: remaining combinations
When two bands both carry >15% of peak energy, the hit is split into separate sub-hits (one per band).
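The decision tree and the overlap check can be sketched as follows. Assumptions: band shares are measured on the power spectrum of the whole hit, the centroid is magnitude-weighted, and ">15% of peak energy" is read as a band's share of the hit's total energy. The README gives no number for "high ZCR", so `ZCR_HIGH` is a hypothetical placeholder, and the closed/open hi-hat split and the cymbal/tom/percussion rules are left out.

```python
import numpy as np

ZCR_HIGH = 0.1  # hypothetical cut-off for "high ZCR" (crossings per sample)


def band_fractions(y, sr):
    power = np.abs(np.fft.rfft(y)) ** 2
    freqs = np.fft.rfftfreq(len(y), 1.0 / sr)
    total = power.sum() + 1e-12
    return {
        "low": power[(freqs >= 20) & (freqs < 250)].sum() / total,
        "mid": power[(freqs >= 250) & (freqs < 4000)].sum() / total,
        "high": power[freqs >= 4000].sum() / total,
    }


def spectral_centroid(y, sr):
    mag = np.abs(np.fft.rfft(y))
    freqs = np.fft.rfftfreq(len(y), 1.0 / sr)
    return float((freqs * mag).sum() / (mag.sum() + 1e-12))


def zero_crossing_rate(y):
    return float(np.mean(np.signbit(y[1:]) != np.signbit(y[:-1])))


def classify_hit(y, sr):
    f, c = band_fractions(y, sr), spectral_centroid(y, sr)
    if f["low"] > 0.50 and c < 800:
        return "kick"
    if f["mid"] > 0.40 and zero_crossing_rate(y) > ZCR_HIGH and c > 1000:
        return "snare"
    if f["high"] > 0.35 and c > 4000:
        return "hihat"
    return "other"  # cymbal / tom / percussion need further rules


def overlapping_bands(y, sr, min_share=0.15):
    """Bands to split into separate sub-hits; empty when a single band dominates."""
    strong = [band for band, share in band_fractions(y, sr).items() if share > min_share]
    return strong if len(strong) >= 2 else []
```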
### Stage 4: Embedding & Clustering
**Default (librosa, 58-dim)**: MFCCs (mean+std), spectral centroid/bandwidth/rolloff/contrast/flatness, ZCR, RMS, onset envelope shape, duration. Z-score normalized.
**Optional (CLAP, 512-dim)**: `laion/larger_clap_general`, semantic audio embeddings from Contrastive Language-Audio Pretraining. Better at distinguishing subtly different drum types, but slower.
Clustering is hierarchical: first group by rough spectral label, then sub-cluster within each group using KMeans with auto-K selection via silhouette score.
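The two-level clustering can be sketched with scikit-learn's `KMeans` and `silhouette_score`. Assumptions: features are z-scored once over all hits, K is searched from 2 up to a hypothetical `k_max`, and groups with fewer than three hits stay a single cluster because the silhouette score needs at least two clusters and fewer clusters than samples. The function names are illustrative.

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score


def auto_kmeans(X, k_max=8, seed=0):
    """KMeans labels for the K in [2, k_max] with the highest silhouette score."""
    n = len(X)
    if n < 3:  # silhouette needs 2 <= K <= n - 1
        return np.zeros(n, dtype=int)
    best_labels, best_score = np.zeros(n, dtype=int), -1.0
    for k in range(2, min(k_max, n - 1) + 1):
        labels = KMeans(n_clusters=k, n_init=10, random_state=seed).fit_predict(X)
        if len(set(labels)) < 2:
            continue
        score = silhouette_score(X, labels)
        if score > best_score:
            best_labels, best_score = labels, score
    return best_labels


def cluster_hits(features, rough_labels, k_max=8):
    """Group by rough spectral label, then sub-cluster; returns names like 'kick_0'."""
    X = np.asarray(features, dtype=float)
    X = (X - X.mean(axis=0)) / (X.std(axis=0) + 1e-8)  # z-score each feature
    names = [""] * len(X)
    for label in sorted(set(rough_labels)):
        idx = [i for i, lab in enumerate(rough_labels) if lab == label]
        for i, sub in zip(idx, auto_kmeans(X[idx], k_max)):
            names[i] = f"{label}_{sub}"
    return names
```

Splitting by rough label first keeps KMeans from ever merging, say, a soft snare with a hard kick, and it keeps each silhouette search small.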
### Stage 5: Best Representative Selection
Each cluster's "best" hit is selected by a weighted score:
- **60% representativeness**: closest to cluster centroid in MFCC space
- **40% energy**: higher RMS = cleaner transient with less bleed
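The 60/40 score can be sketched as follows. The README does not say how the two terms are scaled, so this sketch assumes each is normalized within the cluster: proximity as `1 - d / d_max`, where `d` is the distance to the feature centroid, and energy as `rms / rms_max`. Function and parameter names are hypothetical.

```python
import numpy as np


def best_representative(features, rms, w_rep=0.6, w_energy=0.4):
    """Index of the cluster member with the highest weighted score."""
    X = np.asarray(features, dtype=float)
    rms = np.asarray(rms, dtype=float)
    dist = np.linalg.norm(X - X.mean(axis=0), axis=1)  # distance to the cluster centroid
    proximity = 1.0 - dist / dist.max() if dist.max() > 0 else np.ones(len(X))
    energy = rms / rms.max() if rms.max() > 0 else np.ones(len(X))
    return int(np.argmax(w_rep * proximity + w_energy * energy))
```

On a toy cluster with features `[[0, 0], [1, 0], [-1, 0], [0, 3]]` and RMS `[0.5, 0.5, 1.0, 1.0]`, the default weights pick index 2 (a little further from the centroid but twice as loud), while `w_rep=1.0, w_energy=0.0` picks index 0, the member nearest the centroid.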
### Stage 6: Synthesis (Optional)
Creates an "ideal" sample by peak-aligned weighted averaging:
1. Align all cluster members to their peak transient
2. Normalize amplitudes
3. Weighted average (best hit gets 2× weight)

Averaging reduces random noise and bleed while preserving the shared transient character.
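The three steps map onto a short NumPy routine. This sketch rests on choices the README leaves open: alignment uses the sample with the largest absolute amplitude, amplitude normalization divides by that signed peak value (so polarity-inverted hits are flipped), and hits are zero-padded to a common length before averaging. It assumes no hit is entirely silent.

```python
import numpy as np


def synthesize(hits, best_idx, best_weight=2.0):
    """Peak-aligned weighted average of one cluster's hits (1-D float arrays)."""
    peaks = [int(np.argmax(np.abs(h))) for h in hits]
    pre = max(peaks)                                     # samples needed before the aligned peak
    post = max(len(h) - p for h, p in zip(hits, peaks))  # samples from the peak onward
    acc = np.zeros(pre + post)
    total_weight = 0.0
    for i, (h, p) in enumerate(zip(hits, peaks)):
        h = h / h[p]                                     # every peak becomes +1
        w = best_weight if i == best_idx else 1.0
        start = pre - p                                  # this hit's peak lands at index `pre`
        acc[start:start + len(h)] += w * h
        total_weight += w
    return acc / total_weight
```

If every member is the same sound at a different level and offset, the result is that sound with its peak at 1; with real hits, uncorrelated bleed partially cancels while the aligned transient adds up.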
## Upgrade Paths
For higher quality on specific stages, these drop-in replacements are available:
| Stage | Current | Upgrade | Benefit |
|-------|---------|---------|---------|
| 3 (overlap separation) | Spectral bands | [AudioSep](https://huggingface.co/spaces/Audio-AGI/AudioSep) | Text-queried separation ("kick drum"), 10.5 dB SDRi |
| 3 (overlap separation) | Spectral bands | [SAM Audio](https://huggingface.co/facebook/sam-audio-large) | Diffusion-based + temporal span prompts (gated, Meta license) |
| 4 (clustering) | librosa features | CLAP embeddings (`--clap`) | Semantic similarity, better cross-genre generalization |
| 2 (onset detection) | librosa | [madmom](https://github.com/CPJKU/madmom) `RNNOnsetProcessor` | 0.89 F1 on ENST-drums (needs Python ≤3.10) |
## Requirements
```
demucs>=4.0
librosa>=0.10
soundfile
scikit-learn
numpy
torch
transformers # only needed with --clap
```
## License
MIT