airasoul/drum-kit
Fine-tuned facebook/wav2vec2-base for audio classification of single drum/percussion sounds into 10 classes.
```python
import torch
import librosa
from transformers import AutoFeatureExtractor, AutoModelForAudioClassification

model_id = "airasoul/wav2vec2-base-drum-kit"  # e.g. username/wav2vec2-base-drum-kit
feature_extractor = AutoFeatureExtractor.from_pretrained(model_id)
model = AutoModelForAudioClassification.from_pretrained(model_id)
model.eval()

# Load a WAV (16 kHz mono)
audio, sr = librosa.load("path/to/audio.wav", sr=16000, mono=True)

inputs = feature_extractor(
    audio,
    sampling_rate=16000,
    max_length=48000,  # 3 s at 16 kHz
    truncation=True,
    return_tensors="pt",
    padding=True,
)

with torch.no_grad():
    logits = model(**inputs).logits

pred_id = logits.argmax(dim=-1).item()
label = model.config.id2label.get(pred_id) or model.config.id2label.get(str(pred_id))
print(label)  # e.g. "kick"
```
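
The snippet above prints only the single most likely class. When a sound is ambiguous (for example a rim shot that could pass for a snare), it can help to look at the top few classes with their probabilities. The following is a minimal sketch using only the standard library. The helper name `top_k_labels` and the example logits and label map are hypothetical and not part of the model's API. In practice you would pass `logits[0].tolist()` and `model.config.id2label` from the code above.

```python
import math

def top_k_labels(logit_row, id2label, k=3):
    """Return the k most probable (label, probability) pairs for one clip."""
    # Numerically stable softmax: subtract the max logit before exponentiating.
    peak = max(logit_row)
    exps = [math.exp(x - peak) for x in logit_row]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Rank class indices by probability, highest first, and keep the first k.
    ranked = sorted(range(len(probs)), key=probs.__getitem__, reverse=True)[:k]
    # id2label keys may be ints or strings depending on how the config was loaded.
    return [(id2label.get(i) or id2label.get(str(i)), probs[i]) for i in ranked]

# Hypothetical logits and label map, for illustration only.
example_id2label = {0: "clap", 1: "kick", 2: "snare"}
example_top = top_k_labels([0.2, 2.5, 1.1], example_id2label, k=2)
for name, prob in example_top:
    print(f"{name}: {prob:.3f}")
```

A low top-1 probability, or a small gap between the first two entries, is a reasonable signal to treat the prediction as uncertain.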
| Class | Precision | Recall | F1-score |
|---|---|---|---|
| clap | 1.00 | 1.00 | 1.00 |
| conga | 0.96 | 0.93 | 0.95 |
| crash | 0.97 | 0.97 | 0.97 |
| cymbal | 1.00 | 0.91 | 0.95 |
| hat | 1.00 | 0.97 | 0.98 |
| kick | 1.00 | 0.94 | 0.97 |
| ride | 0.94 | 1.00 | 0.97 |
| rim | 1.00 | 1.00 | 1.00 |
| snare | 0.89 | 0.96 | 0.93 |
| tom | 0.92 | 1.00 | 0.96 |
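
The table reports per-class scores only. As a rough summary, the unweighted (macro) average of each column can be computed directly from the rows above. This is a sketch under two assumptions: the values are the rounded figures shown in the table, and every class counts equally. Per-class sample counts are not given, so a support-weighted average or overall accuracy cannot be derived this way.

```python
# Per-class (precision, recall, F1) copied from the table above.
rows = [
    ("clap", 1.00, 1.00, 1.00),
    ("conga", 0.96, 0.93, 0.95),
    ("crash", 0.97, 0.97, 0.97),
    ("cymbal", 1.00, 0.91, 0.95),
    ("hat", 1.00, 0.97, 0.98),
    ("kick", 1.00, 0.94, 0.97),
    ("ride", 0.94, 1.00, 0.97),
    ("rim", 1.00, 1.00, 1.00),
    ("snare", 0.89, 0.96, 0.93),
    ("tom", 0.92, 1.00, 0.96),
]

def macro_average(table_rows):
    """Unweighted mean of the precision, recall, and F1 columns."""
    n = len(table_rows)
    return tuple(sum(row[col] for row in table_rows) / n for col in (1, 2, 3))

macro_p, macro_r, macro_f1 = macro_average(rows)
print(f"macro precision={macro_p:.3f} recall={macro_r:.3f} f1={macro_f1:.3f}")
```

All three macro averages come out at roughly 0.968 from these rounded values.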