---
language: en
license: mit
tags:
- audio
- audio-classification
- musical-instruments
- wav2vec2
- transformers
- pytorch
datasets:
- custom
metrics:
- accuracy
- roc_auc
model-index:
- name: epoch_musical_instruments_identification_2
  results:
  - task:
      type: audio-classification
      name: Musical Instrument Classification
    metrics:
    - type: accuracy
      value: 0.9333
      name: Accuracy
    - type: roc_auc
      value: 0.9859
      name: ROC AUC (Macro)
    - type: loss
      value: 1.0639
      name: Validation Loss
base_model:
- Bhaveen/Musical-Instrument-Classification
library_name: transformers.js
pipeline_tag: audio-classification
---
# Musical-Instrument-Classification (ONNX)
This is an ONNX version of [Bhaveen/Musical-Instrument-Classification](https://huggingface.co/Bhaveen/Musical-Instrument-Classification). It was automatically converted and uploaded using [this Hugging Face Space](https://huggingface.co/spaces/onnx-community/convert-to-onnx).
## Usage with Transformers.js
See the pipeline documentation for `audio-classification`: https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.AudioClassificationPipeline
---
# Musical Instrument Classification Model
This model is a fine-tuned version of [facebook/wav2vec2-base-960h](https://huggingface.co/facebook/wav2vec2-base-960h) for musical instrument classification. It can identify 9 different musical instruments from audio recordings with high accuracy.
## Model Description
- **Model type:** Audio Classification
- **Base model:** facebook/wav2vec2-base-960h
- **Language:** Audio (no specific language)
- **License:** MIT
- **Fine-tuned on:** Custom musical instrument dataset (200 samples for each class)
## Performance
After 5 epochs of training, the model achieves the following results on the evaluation set:
- **Final Accuracy:** 93.33%
- **Final ROC AUC (Macro):** 98.59%
- **Final Validation Loss:** 1.064
- **Evaluation Runtime:** 14.18 seconds
- **Evaluation Speed:** 25.39 samples/second
### Training Progress
| Epoch | Training Loss | Validation Loss | ROC AUC | Accuracy |
|-------|---------------|-----------------|---------|----------|
| 1 | 1.9872 | 1.8875 | 0.9248 | 0.6639 |
| 2 | 1.8652 | 1.4793 | 0.9799 | 0.8000 |
| 3 | 1.3868 | 1.2311 | 0.9861 | 0.8194 |
| 4 | 1.3242 | 1.1121 | 0.9827 | 0.9250 |
| 5 | 1.1869 | 1.0639 | 0.9859 | 0.9333 |
## Supported Instruments
The model can classify the following 9 musical instruments:
1. **Acoustic Guitar**
2. **Bass Guitar**
3. **Drum Set**
4. **Electric Guitar**
5. **Flute**
6. **Hi-Hats**
7. **Keyboard**
8. **Trumpet**
9. **Violin**
## Usage
### Quick Start with Pipeline
```python
from transformers import pipeline
import torchaudio

# Load the classification pipeline
classifier = pipeline("audio-classification", model="Bhaveen/epoch_musical_instruments_identification_2")

# Load and preprocess audio: mono, 16 kHz, first 3 seconds (48,000 samples)
audio, rate = torchaudio.load("your_audio_file.wav")
audio = audio.mean(dim=0)  # average channels to mono (flattening stereo would concatenate channels)
audio = torchaudio.transforms.Resample(rate, 16000)(audio)
audio = audio.numpy()[:48000]

# Classify the audio
result = classifier(audio)
print(result)
```
### Using Transformers Directly
```python
from transformers import AutoFeatureExtractor, AutoModelForAudioClassification
import torchaudio
import torch

# Load model and feature extractor
model_name = "Bhaveen/epoch_musical_instruments_identification_2"
feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
model = AutoModelForAudioClassification.from_pretrained(model_name)

# Load and preprocess audio: mono, 16 kHz, first 3 seconds (48,000 samples)
audio, rate = torchaudio.load("your_audio_file.wav")
audio = audio.mean(dim=0)  # average channels to mono
audio = torchaudio.transforms.Resample(rate, 16000)(audio)
audio = audio.numpy()[:48000]

# Extract features and make prediction
inputs = feature_extractor(audio, sampling_rate=16000, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
predicted_class = torch.argmax(predictions, dim=-1)
print(f"Predicted instrument: {model.config.id2label[predicted_class.item()]}")
```
## Training Details
### Dataset and Preprocessing
- **Custom dataset** with audio recordings of 9 musical instruments
- **Train/Test Split:** 80/20 using file numbering (files < 160 for training)
- **Data Balancing:** Random oversampling applied to minority classes
- **Audio Preprocessing:**
- Resampling to 16,000 Hz
- Fixed length of 48,000 samples (3 seconds)
- Truncation of longer audio files
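The fixed-length step can be sketched as follows. The card only mentions truncation; zero-padding shorter clips, shown here, is an assumption about how sub-3-second inputs would be handled.

```python
import numpy as np

TARGET_LEN = 48_000  # 3 seconds at 16 kHz

def fix_length(audio: np.ndarray, target_len: int = TARGET_LEN) -> np.ndarray:
    """Truncate longer clips and zero-pad shorter ones to a fixed length."""
    if len(audio) >= target_len:
        return audio[:target_len]
    return np.pad(audio, (0, target_len - len(audio)))

clip = fix_length(np.random.randn(60_000))   # truncated to 48,000 samples
short = fix_length(np.random.randn(10_000))  # zero-padded to 48,000 samples
```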
### Training Configuration
```python
# Training hyperparameters
batch_size = 1
gradient_accumulation_steps = 4
learning_rate = 5e-6
num_train_epochs = 5
warmup_steps = 50
weight_decay = 0.02
```
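Under the library versions listed below (transformers 4.28), these hyperparameters would map onto `TrainingArguments` roughly as follows; `output_dir` and the evaluation/save strategies are assumptions, not taken from the original run:

```python
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="epoch_musical_instruments_identification_2",  # assumed
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,   # effective batch size of 4
    learning_rate=5e-6,
    num_train_epochs=5,
    warmup_steps=50,
    weight_decay=0.02,
    evaluation_strategy="epoch",     # assumed; matches the per-epoch table above
    save_strategy="epoch",           # assumed
)
```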
### Model Architecture
- **Base Model:** facebook/wav2vec2-base-960h
- **Classification Head:** Added for 9-class classification
- **Parameters:** ~95M trainable parameters
- **Features:** Wav2Vec2 audio representations with fine-tuned classification layer
## Technical Specifications
- **Audio Format:** WAV files
- **Sample Rate:** 16,000 Hz
- **Input Length:** 3 seconds (48,000 samples)
- **Model Framework:** PyTorch + Transformers
- **Inference Device:** GPU recommended (CUDA)
## Evaluation Metrics
The model uses the following evaluation metrics:
- **Accuracy:** Standard classification accuracy
- **ROC AUC:** Macro-averaged ROC AUC with one-vs-rest approach
- **Multi-class Classification:** Softmax probabilities for all 9 instrument classes
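In practice the macro one-vs-rest ROC AUC is typically computed with `sklearn.metrics.roc_auc_score(y_true, probs, multi_class="ovr", average="macro")`; a small NumPy sketch of what that metric computes:

```python
import numpy as np

def binary_auc(scores, labels) -> float:
    """ROC AUC via the rank-sum (Mann-Whitney U) statistic."""
    scores, labels = np.asarray(scores), np.asarray(labels)
    pos, neg = scores[labels == 1], scores[labels == 0]
    # fraction of (positive, negative) pairs ranked correctly; ties count half
    greater = (pos[:, None] > neg[None, :]).sum()
    ties = (pos[:, None] == neg[None, :]).sum()
    return float((greater + 0.5 * ties) / (len(pos) * len(neg)))

def macro_ovr_auc(probs, y_true, n_classes) -> float:
    """Macro-averaged one-vs-rest ROC AUC over class probabilities."""
    aucs = [binary_auc(probs[:, c], (y_true == c).astype(int))
            for c in range(n_classes)]
    return float(np.mean(aucs))
```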
## Limitations and Considerations
1. **Audio Duration:** The model expects 3-second clips; longer audio is truncated, and shorter clips may not be handled well
2. **Single Instrument Focus:** Optimized for single-instrument recordings; mixes of instruments may produce uncertain results
3. **Audio Quality:** Performance depends on audio quality and recording conditions
4. **Sample Rate:** Input must be resampled to 16kHz for optimal performance
5. **Domain Specificity:** Trained on specific instrument recordings, may not generalize to all variants or playing styles
## Training Environment
- **Platform:** Google Colab
- **GPU:** CUDA-enabled device
- **Libraries:**
- transformers==4.28.1
- torchaudio==0.12
- datasets
- evaluate
- imblearn
## Model Files
The repository contains:
- Model weights and configuration
- Feature extractor configuration
- Training logs and metrics
- Label mappings (id2label, label2id)
---
*Model trained as part of a hackathon project*