# Wav2Vec2 for drum-kit classification

Fine-tuned `facebook/wav2vec2-base` for audio classification of single drum/percussion hits into 10 classes.
## Classes
- clap, conga, crash, cymbal, hat, kick, ride, rim, snare, tom
## Usage
```python
import torch
import librosa
from transformers import AutoFeatureExtractor, AutoModelForAudioClassification

model_id = "airasoul/wav2vec2-base-drum-kit"  # e.g. username/wav2vec2-base-drum-kit

feature_extractor = AutoFeatureExtractor.from_pretrained(model_id)
model = AutoModelForAudioClassification.from_pretrained(model_id)
model.eval()

# Load a WAV as 16 kHz mono (the model's expected input format)
audio, sr = librosa.load("path/to/audio.wav", sr=16000, mono=True)

inputs = feature_extractor(
    audio,
    sampling_rate=16000,
    max_length=48000,  # 3 s at 16 kHz
    truncation=True,
    padding=True,
    return_tensors="pt",
)

with torch.no_grad():
    logits = model(**inputs).logits

pred_id = logits.argmax(dim=-1).item()
# id2label keys may be int or str depending on how the config was saved
label = model.config.id2label.get(pred_id) or model.config.id2label.get(str(pred_id))
print(label)  # e.g. "kick"
```
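Under the hood, `max_length=48000` with `truncation=True` and `padding=True` means every clip is forced to exactly 3 s of samples before it reaches the model. A minimal NumPy sketch of that fixed-length preparation (`fix_length` is a hypothetical helper for illustration, not part of the model's API):

```python
import numpy as np

TARGET_LEN = 48000  # 3 s at 16 kHz, matching max_length above


def fix_length(audio: np.ndarray, target_len: int = TARGET_LEN) -> np.ndarray:
    """Truncate or zero-pad a mono waveform to exactly target_len samples."""
    if len(audio) >= target_len:
        return audio[:target_len]
    return np.pad(audio, (0, target_len - len(audio)))


short = np.ones(16000, dtype=np.float32)  # 1 s clip -> zero-padded
long_ = np.ones(80000, dtype=np.float32)  # 5 s clip -> truncated
print(len(fix_length(short)), len(fix_length(long_)))  # 48000 48000
```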
## Training
- Base: facebook/wav2vec2-base
- Task: Single-label classification over 10 drum classes
- Data: Custom drum-kit dataset with augmentation (time stretch, noise, gain)
- Input: 16 kHz mono, up to 3 s (truncated or padded)
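The augmentations listed above (time stretch, noise, gain) might be sketched as follows. This is an illustrative NumPy-only approximation, not the actual training pipeline; in particular the interpolation-based stretch also shifts pitch, unlike `librosa.effects.time_stretch`:

```python
import numpy as np

rng = np.random.default_rng(0)


def augment(audio: np.ndarray) -> np.ndarray:
    """Illustrative augmentation: random gain, additive noise, crude time stretch."""
    # Random gain between -6 dB and +6 dB
    gain_db = rng.uniform(-6.0, 6.0)
    out = audio * 10 ** (gain_db / 20)
    # Additive Gaussian noise at a small fixed level
    out = out + rng.normal(0.0, 0.01, size=out.shape)
    # Time stretch by +/-10% via linear resampling (also shifts pitch;
    # a phase-vocoder stretch would preserve it)
    rate = rng.uniform(0.9, 1.1)
    new_len = int(len(out) / rate)
    out = np.interp(
        np.linspace(0, len(out) - 1, new_len),
        np.arange(len(out)),
        out,
    )
    return out.astype(np.float32)
```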
## Results
- Validation accuracy: 95.7% (epoch 10)
- Test accuracy: 97.0% (300 samples, held-out)
| Class | Precision | Recall | F1-score |
|---|---|---|---|
| clap | 1.00 | 1.00 | 1.00 |
| conga | 0.96 | 0.93 | 0.95 |
| crash | 0.97 | 0.97 | 0.97 |
| cymbal | 1.00 | 0.91 | 0.95 |
| hat | 1.00 | 0.97 | 0.98 |
| kick | 1.00 | 0.94 | 0.97 |
| ride | 0.94 | 1.00 | 0.97 |
| rim | 1.00 | 1.00 | 1.00 |
| snare | 0.89 | 0.96 | 0.93 |
| tom | 0.92 | 1.00 | 0.96 |
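Averaging the per-class F1 scores in the table gives the macro F1 (computed here from the rounded table values, not from raw predictions):

```python
import numpy as np

# Per-class F1 scores from the table above (clap .. tom)
f1 = np.array([1.00, 0.95, 0.97, 0.95, 0.98, 0.97, 0.97, 1.00, 0.93, 0.96])
macro_f1 = f1.mean()
print(f"macro F1: {macro_f1:.3f}")  # macro F1: 0.968
```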
## Limitations
- Trained on short, single-hit drum sounds. Performance may drop on long mixes, multiple overlapping sounds, or very different recording conditions.