Wav2Vec2 for drum-kit classification

Fine-tuned version of facebook/wav2vec2-base for audio classification. It assigns a single drum or percussion sound to one of 10 classes.

Classes

  • clap, conga, crash, cymbal, hat, kick, ride, rim, snare, tom

Usage

import torch
import librosa
from transformers import AutoFeatureExtractor, AutoModelForAudioClassification

model_id = "airasoul/wav2vec2-base-drum-kit"
feature_extractor = AutoFeatureExtractor.from_pretrained(model_id)
model = AutoModelForAudioClassification.from_pretrained(model_id)
model.eval()

# Load the audio file, resampled to 16 kHz mono (the rate the model expects)
audio, sr = librosa.load("path/to/audio.wav", sr=16000, mono=True)
inputs = feature_extractor(
    audio, sampling_rate=16000, max_length=48000,  # 3 s at 16 kHz
    truncation=True,  # clips longer than 3 s are cut to 48000 samples
    return_tensors="pt", padding=True,
)
with torch.no_grad():
    logits = model(**inputs).logits
pred_id = logits.argmax(dim=-1).item()
label = model.config.id2label.get(pred_id) or model.config.id2label.get(str(pred_id))
print(label)  # e.g. "kick"
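The snippet above prints only the single most likely class. To see how confident the model is, convert the logits to probabilities with a softmax and list the top few classes. In PyTorch this is `torch.softmax(logits, dim=-1)`. The standalone NumPy sketch below does the same thing. The function name `top_k_labels` and the example mapping and logits are hypothetical and are used for illustration only.

```python
import numpy as np


def top_k_labels(logits, id2label, k=3):
    """Return the k most likely (label, probability) pairs for one clip.

    logits: 1-D array of raw scores, one per class (for the real model,
    logits[0].numpy() from the usage snippet).
    id2label: mapping from class index to label name
    (for the real model, model.config.id2label).
    """
    logits = np.asarray(logits, dtype=np.float64)
    # Numerically stable softmax: subtract the max before exponentiating.
    exp = np.exp(logits - logits.max())
    probs = exp / exp.sum()
    # Indices of the k largest probabilities, highest first.
    top = np.argsort(probs)[::-1][:k]
    return [(id2label[int(i)], float(probs[i])) for i in top]


if __name__ == "__main__":
    # Hypothetical mapping and logits, not real model output.
    example_id2label = {0: "kick", 1: "snare", 2: "tom"}
    print(top_k_labels([2.0, 0.5, -1.0], example_id2label, k=2))
```

A low top probability, or two classes with similar scores, is a useful signal that a clip may fall outside the single-hit conditions the model was trained on.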

Training

  • Base: facebook/wav2vec2-base
  • Task: Single-label classification over 10 drum classes
  • Data: Custom drum-kit dataset with augmentation (time stretch, noise, gain)
  • Input: 16 kHz mono, up to 3 s (truncated or padded)
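The training pipeline itself is not published. The sketch below shows one plausible way to get fixed-length 3 s inputs and to apply two of the listed augmentations, gain and additive noise. Everything here is an assumption: the function names, the end-of-clip zero padding, and the gain and SNR ranges are hypothetical and are not taken from the actual training code. Time stretching is not included. If librosa is available, `librosa.effects.time_stretch(y, rate=r)` is one option for that step.

```python
import numpy as np

SAMPLE_RATE = 16_000
MAX_SAMPLES = 3 * SAMPLE_RATE  # 3 s at 16 kHz = 48000 samples, as in Usage


def pad_or_truncate(audio, max_samples=MAX_SAMPLES):
    """Cut clips longer than max_samples; zero-pad shorter ones at the end."""
    audio = np.asarray(audio, dtype=np.float32)
    if len(audio) >= max_samples:
        return audio[:max_samples]
    return np.pad(audio, (0, max_samples - len(audio)))


def apply_gain(audio, gain_db):
    """Scale amplitude by gain_db decibels (amplitude factor 10**(dB/20))."""
    return audio * np.float32(10.0 ** (gain_db / 20.0))


def add_noise(audio, snr_db, rng):
    """Add white Gaussian noise at the requested signal-to-noise ratio (dB)."""
    signal_power = np.mean(audio ** 2)
    noise_power = signal_power / (10.0 ** (snr_db / 10.0))
    noise = rng.normal(0.0, np.sqrt(noise_power), size=audio.shape)
    return (audio + noise).astype(np.float32)


def augment(audio, rng):
    """Hypothetical random augmentation; the ranges are illustrative only."""
    audio = apply_gain(np.asarray(audio, dtype=np.float32), rng.uniform(-6.0, 6.0))
    audio = add_noise(audio, rng.uniform(20.0, 40.0), rng)
    return pad_or_truncate(audio)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    # Synthetic 1 s decaying 60 Hz sine as a stand-in for a drum hit.
    t = np.arange(SAMPLE_RATE) / SAMPLE_RATE
    hit = np.sin(2 * np.pi * 60 * t) * np.exp(-8 * t)
    out = augment(hit, rng)
    print(out.shape, out.dtype)  # (48000,) float32
```

Padding is applied after augmentation so that the noise covers only the original hit and not the silent tail. The card does not state which order was used.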

Results

  • Validation accuracy: 95.7% (epoch 10)
  • Test accuracy: 97.0% on a held-out set of 300 samples

Class    Precision  Recall  F1-score
clap     1.00       1.00    1.00
conga    0.96       0.93    0.95
crash    0.97       0.97    0.97
cymbal   1.00       0.91    0.95
hat      1.00       0.97    0.98
kick     1.00       0.94    0.97
ride     0.94       1.00    0.97
rim      1.00       1.00    1.00
snare    0.89       0.96    0.93
tom      0.92       1.00    0.96
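The card does not say which tool produced the per-class table. A common choice is scikit-learn's `sklearn.metrics.classification_report`. The dependency-light sketch below computes the same per-class definitions from true and predicted labels, so you can reproduce this kind of table on your own data. The function name `per_class_report` and the toy labels are hypothetical.

```python
import numpy as np


def per_class_report(y_true, y_pred, labels):
    """Precision, recall and F1 for each class in a single-label task.

    y_true, y_pred: equal-length sequences of label names.
    labels: the class names to report, in display order.
    Returns {label: (precision, recall, f1)}. A ratio whose denominator
    is zero is reported as 0.0.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    report = {}
    for label in labels:
        tp = np.sum((y_pred == label) & (y_true == label))  # correct hits
        fp = np.sum((y_pred == label) & (y_true != label))  # false alarms
        fn = np.sum((y_pred != label) & (y_true == label))  # misses
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        # F1 is the harmonic mean of precision and recall.
        f1 = (2 * precision * recall / (precision + recall)
              if precision + recall else 0.0)
        report[label] = (float(precision), float(recall), float(f1))
    return report


if __name__ == "__main__":
    # Toy labels, not the model's actual test predictions.
    y_true = ["kick", "kick", "snare", "snare", "tom"]
    y_pred = ["kick", "snare", "snare", "snare", "tom"]
    for name, (p, r, f) in per_class_report(y_true, y_pred, ["kick", "snare", "tom"]).items():
        print(f"{name:<6} {p:.2f} {r:.2f} {f:.2f}")
```

As a spot check against the table, F1 = 2PR / (P + R). For cymbal, 2 × 1.00 × 0.91 / 1.91 ≈ 0.95, which matches the reported value.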

Limitations

  • Trained on short, single-hit drum sounds. Performance may drop on long mixes, multiple overlapping sounds, or very different recording conditions.

Model size

  • 94.6M parameters, F32 weights in Safetensors format