metadata
license: mit
language: en
pipeline_tag: audio-classification
tags:
  - audio
  - audio-classification
  - wav2vec2
  - drum
  - percussion
datasets:
  - airasoul/drum-kit
metrics:
  - accuracy
  - f1

Wav2Vec2 for drum-kit classification

Fine-tuned facebook/wav2vec2-base for audio classification of single drum/percussion sounds into 10 classes.

Classes

  • clap, conga, crash, cymbal, hat, kick, ride, rim, snare, tom

Usage

import torch
import librosa
from transformers import AutoFeatureExtractor, AutoModelForAudioClassification

model_id = "airasoul/wav2vec2-base-drum-kit"
feature_extractor = AutoFeatureExtractor.from_pretrained(model_id)
model = AutoModelForAudioClassification.from_pretrained(model_id)
model.eval()

# Load the audio file; librosa resamples to 16 kHz and downmixes to mono
audio, sr = librosa.load("path/to/audio.wav", sr=16000, mono=True)
inputs = feature_extractor(
    audio, sampling_rate=16000, max_length=48000,  # 3 s at 16 kHz
    truncation=True, return_tensors="pt", padding=True
)
with torch.no_grad():
    logits = model(**inputs).logits
pred_id = logits.argmax(dim=-1).item()
label = model.config.id2label.get(pred_id) or model.config.id2label.get(str(pred_id))
print(label)  # e.g. "kick"
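
When a single predicted label is not enough, the logits can be turned into ranked probabilities. The following is a minimal sketch in plain Python, not part of the released model code: `top_k_labels` is a hypothetical helper, and the logits and label map in the example are made up for illustration.

```python
import math


def top_k_labels(logits, id2label, k=3):
    """Return the k most probable (label, probability) pairs."""
    # Numerically stable softmax: subtract the max before exponentiating.
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]

    # Class indices sorted from most to least probable.
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)

    def lookup(i):
        # Accept int or str keys, mirroring the lookup in the usage snippet.
        return id2label.get(i) or id2label.get(str(i))

    return [(lookup(i), probs[i]) for i in ranked[:k]]


# Made-up logits for three classes, for illustration only.
example = top_k_labels([2.0, 0.5, -1.0], {0: "kick", 1: "snare", 2: "tom"}, k=2)
print(example)
```

With the model loaded as above, the call would look like `top_k_labels(logits[0].tolist(), model.config.id2label)`.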

Training

  • Base: facebook/wav2vec2-base
  • Task: Single-label classification over 10 drum classes
  • Data: Custom drum-kit dataset with augmentation (time stretch, noise, gain)
  • Input: 16 kHz mono, up to 3 s (truncated or padded)
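
The augmentation and input handling listed above can be sketched with NumPy. This is an illustration under stated assumptions, not the training code used for this model: all function names and parameter ranges are hypothetical, and the time stretch is a naive linear-interpolation resample that also shifts pitch (a phase-vocoder stretch such as `librosa.effects.time_stretch` would preserve it).

```python
import numpy as np

SAMPLE_RATE = 16_000
MAX_SAMPLES = 3 * SAMPLE_RATE  # 3 s at 16 kHz, matching the input spec


def apply_gain(audio, gain_db):
    """Scale the waveform by a gain given in decibels."""
    return audio * (10.0 ** (gain_db / 20.0))


def add_noise(audio, snr_db, rng):
    """Add white Gaussian noise at the requested signal-to-noise ratio."""
    signal_power = np.mean(audio ** 2) + 1e-12  # avoid zero power on silence
    noise_power = signal_power / (10.0 ** (snr_db / 10.0))
    noise = rng.normal(0.0, np.sqrt(noise_power), size=audio.shape)
    return audio + noise


def naive_time_stretch(audio, rate):
    """Resample by linear interpolation; rate > 1 shortens the clip.

    Assumption: a simple stand-in for a real time stretch. It changes pitch.
    """
    n_out = max(1, int(round(len(audio) / rate)))
    old_positions = np.arange(len(audio))
    new_positions = np.linspace(0, len(audio) - 1, n_out)
    return np.interp(new_positions, old_positions, audio)


def fix_length(audio, n_samples=MAX_SAMPLES):
    """Truncate or zero-pad to exactly n_samples."""
    if len(audio) >= n_samples:
        return audio[:n_samples]
    return np.pad(audio, (0, n_samples - len(audio)))


def augment(audio, rng):
    """Apply all three augmentations; the ranges below are hypothetical."""
    audio = naive_time_stretch(audio, rate=rng.uniform(0.9, 1.1))
    audio = apply_gain(audio, gain_db=rng.uniform(-6.0, 6.0))
    audio = add_noise(audio, snr_db=rng.uniform(20.0, 40.0), rng=rng)
    return fix_length(audio).astype(np.float32)


rng = np.random.default_rng(0)
clip = np.sin(np.linspace(0.0, 100.0, 8000))  # synthetic 0.5 s test signal
print(augment(clip, rng).shape)
```

Passing an explicit `rng` keeps the augmentation reproducible, which helps when comparing training runs.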

Results

  • Validation accuracy: 95.7% (epoch 10)
  • Test accuracy: 97.0% (300 samples, held-out)
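
| Class  | Precision | Recall | F1-score |
|--------|-----------|--------|----------|
| clap   | 1.00      | 1.00   | 1.00     |
| conga  | 0.96      | 0.93   | 0.95     |
| crash  | 0.97      | 0.97   | 0.97     |
| cymbal | 1.00      | 0.91   | 0.95     |
| hat    | 1.00      | 0.97   | 0.98     |
| kick   | 1.00      | 0.94   | 0.97     |
| ride   | 0.94      | 1.00   | 0.97     |
| rim    | 1.00      | 1.00   | 1.00     |
| snare  | 0.89      | 0.96   | 0.93     |
| tom    | 0.92      | 1.00   | 0.96     |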
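
The per-class scores above can be condensed into macro averages, i.e. the unweighted mean over the 10 classes. The sketch below works from the rounded two-decimal values as reported, so the result is approximate and may differ slightly from an average computed on the raw predictions.

```python
# Per-class (precision, recall, F1) copied from the reported scores.
scores = {
    "clap": (1.00, 1.00, 1.00),
    "conga": (0.96, 0.93, 0.95),
    "crash": (0.97, 0.97, 0.97),
    "cymbal": (1.00, 0.91, 0.95),
    "hat": (1.00, 0.97, 0.98),
    "kick": (1.00, 0.94, 0.97),
    "ride": (0.94, 1.00, 0.97),
    "rim": (1.00, 1.00, 1.00),
    "snare": (0.89, 0.96, 0.93),
    "tom": (0.92, 1.00, 0.96),
}


def macro_average(column):
    """Unweighted mean of one metric column across all classes."""
    values = [row[column] for row in scores.values()]
    return sum(values) / len(values)


macro = {
    "precision": macro_average(0),
    "recall": macro_average(1),
    "f1": macro_average(2),
}
for name, value in macro.items():
    print(f"macro {name}: {value:.3f}")  # each prints 0.968
```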

Limitations

  • Trained on short, single-hit drum sounds. Performance may drop on long mixes, multiple overlapping sounds, or very different recording conditions.