Bhaveen commited on
Commit
5c9f529
·
verified ·
1 Parent(s): c3f285c

Added readme

Browse files
Files changed (1) hide show
  1. README.md +203 -16
README.md CHANGED
@@ -1,21 +1,208 @@
1
  ---
2
  language: en
3
- license: apache-2.0
4
  tags:
5
- - machine-learning
6
- - transformers
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  ---
8
 
9
- 'eval_loss': 1.063901662826538,
10
-
11
- 'eval_roc_auc': 0.9858799160024966,
12
-
13
- 'eval_accuracy': 0.9333333333333333,
14
-
15
- 'eval_runtime': 14.1816,
16
-
17
- 'eval_samples_per_second': 25.385,
18
-
19
- 'eval_steps_per_second': 25.385,
20
-
21
- 'epoch': 5.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
  language: en
3
+ license: mit
4
  tags:
5
+ - audio
6
+ - audio-classification
7
+ - musical-instruments
8
+ - wav2vec2
9
+ - transformers
10
+ - pytorch
11
+ datasets:
12
+ - custom
13
+ metrics:
14
+ - accuracy
15
+ - roc_auc
16
+ model-index:
17
+ - name: epoch_musical_instruments_identification_2
18
+ results:
19
+ - task:
20
+ type: audio-classification
21
+ name: Musical Instrument Classification
22
+ metrics:
23
+ - type: accuracy
24
+ value: 0.9333
25
+ name: Accuracy
26
+ - type: roc_auc
27
+ value: 0.9859
28
+ name: ROC AUC (Macro)
29
+ - type: loss
30
+ value: 1.0639
31
+ name: Validation Loss
32
+ base_model:
33
+ - facebook/wav2vec2-base-960h
34
  ---
35
 
36
+ # Musical Instrument Classification Model
37
+
38
+ This model is a fine-tuned version of [facebook/wav2vec2-base-960h](https://huggingface.co/facebook/wav2vec2-base-960h) for musical instrument classification. It can identify 9 different musical instruments from audio recordings with high accuracy.
39
+
40
+ ## Model Description
41
+
42
+ - **Model type:** Audio Classification
43
+ - **Base model:** facebook/wav2vec2-base-960h
44
+ - **Language:** Audio (no specific language)
45
+ - **License:** MIT
46
+ - **Fine-tuned on:** Custom musical instrument dataset (200 samples for each class)
47
+
48
+ ## Performance
49
+
50
+ The model achieves excellent performance on the evaluation set after 5 epochs of training:
51
+
52
+ - **Final Accuracy:** 93.33%
53
+ - **Final ROC AUC (Macro):** 98.59%
54
+ - **Final Validation Loss:** 1.064
55
+ - **Evaluation Runtime:** 14.18 seconds
56
+ - **Evaluation Speed:** 25.39 samples/second
57
+
58
+ ### Training Progress
59
+
60
+ | Epoch | Training Loss | Validation Loss | ROC AUC | Accuracy |
61
+ |-------|---------------|-----------------|---------|----------|
62
+ | 1 | 1.9872 | 1.8875 | 0.9248 | 0.6639 |
63
+ | 2 | 1.8652 | 1.4793 | 0.9799 | 0.8000 |
64
+ | 3 | 1.3868 | 1.2311 | 0.9861 | 0.8194 |
65
+ | 4 | 1.3242 | 1.1121 | 0.9827 | 0.9250 |
66
+ | 5 | 1.1869 | 1.0639 | 0.9859 | 0.9333 |
67
+
68
+ ## Supported Instruments
69
+
70
+ The model can classify the following 9 musical instruments:
71
+
72
+ 1. **Acoustic Guitar**
73
+ 2. **Bass Guitar**
74
+ 3. **Drum Set**
75
+ 4. **Electric Guitar**
76
+ 5. **Flute**
77
+ 6. **Hi-Hats**
78
+ 7. **Keyboard**
79
+ 8. **Trumpet**
80
+ 9. **Violin**
81
+
82
+ ## Usage
83
+
84
+ ### Quick Start with Pipeline
85
+
86
+ ```python
87
+ from transformers import pipeline
88
+ import torchaudio
89
+
90
+ # Load the classification pipeline
91
+ classifier = pipeline("audio-classification", model="Bhaveen/epoch_musical_instruments_identification_2")
92
+
93
+ # Load and preprocess audio
94
+ audio, rate = torchaudio.load("your_audio_file.wav")
95
+ transform = torchaudio.transforms.Resample(rate, 16000)
96
+ audio = transform(audio).numpy().reshape(-1)[:48000]
97
+
98
+ # Classify the audio
99
+ result = classifier(audio)
100
+ print(result)
101
+ ```
102
+
103
+ ### Using Transformers Directly
104
+
105
+ ```python
106
+ from transformers import AutoFeatureExtractor, AutoModelForAudioClassification
107
+ import torchaudio
108
+ import torch
109
+
110
+ # Load model and feature extractor
111
+ model_name = "Bhaveen/epoch_musical_instruments_identification_2"
112
+ feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
113
+ model = AutoModelForAudioClassification.from_pretrained(model_name)
114
+
115
+ # Load and preprocess audio
116
+ audio, rate = torchaudio.load("your_audio_file.wav")
117
+ transform = torchaudio.transforms.Resample(rate, 16000)
118
+ audio = transform(audio).numpy().reshape(-1)[:48000]
119
+
120
+ # Extract features and make prediction
121
+ inputs = feature_extractor(audio, sampling_rate=16000, return_tensors="pt")
122
+ with torch.no_grad():
123
+ outputs = model(**inputs)
124
+ predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
125
+ predicted_class = torch.argmax(predictions, dim=-1)
126
+
127
+ print(f"Predicted instrument: {model.config.id2label[predicted_class.item()]}")
128
+ ```
129
+
130
+ ## Training Details
131
+
132
+ ### Dataset and Preprocessing
133
+
134
+ - **Custom dataset** with audio recordings of 9 musical instruments
135
+ - **Train/Test Split:** 80/20 using file numbering (files < 160 for training)
136
+ - **Data Balancing:** Random oversampling applied to minority classes
137
+ - **Audio Preprocessing:**
138
+ - Resampling to 16,000 Hz
139
+ - Fixed length of 48,000 samples (3 seconds)
140
+ - Truncation of longer audio files
141
+
142
+ ### Training Configuration
143
+
144
+ ```python
145
+ # Training hyperparameters
146
+ batch_size = 1
147
+ gradient_accumulation_steps = 4
148
+ learning_rate = 5e-6
149
+ num_train_epochs = 5
150
+ warmup_steps = 50
151
+ weight_decay = 0.02
152
+ ```
153
+
154
+ ### Model Architecture
155
+
156
+ - **Base Model:** facebook/wav2vec2-base-960h
157
+ - **Classification Head:** Added for 9-class classification
158
+ - **Parameters:** ~95M trainable parameters
159
+ - **Features:** Wav2Vec2 audio representations with fine-tuned classification layer
160
+
161
+ ## Technical Specifications
162
+
163
+ - **Audio Format:** WAV files
164
+ - **Sample Rate:** 16,000 Hz
165
+ - **Input Length:** 3 seconds (48,000 samples)
166
+ - **Model Framework:** PyTorch + Transformers
167
+ - **Inference Device:** GPU recommended (CUDA)
168
+
169
+ ## Evaluation Metrics
170
+
171
+ The model uses the following evaluation metrics:
172
+
173
+ - **Accuracy:** Standard classification accuracy
174
+ - **ROC AUC:** Macro-averaged ROC AUC with one-vs-rest approach
175
+ - **Multi-class Classification:** Softmax probabilities for all 9 instrument classes
176
+
177
+
178
+
179
+ ## Limitations and Considerations
180
+
181
+ 1. **Audio Duration:** Model expects exactly 3-second audio clips (truncates longer, may not work well with shorter)
182
+ 2. **Single Instrument Focus:** Optimized for single instrument classification, mixed instruments may produce uncertain results
183
+ 3. **Audio Quality:** Performance depends on audio quality and recording conditions
184
+ 4. **Sample Rate:** Input must be resampled to 16kHz for optimal performance
185
+ 5. **Domain Specificity:** Trained on specific instrument recordings, may not generalize to all variants or playing styles
186
+
187
+ ## Training Environment
188
+
189
+ - **Platform:** Google Colab
190
+ - **GPU:** CUDA-enabled device
191
+ - **Libraries:**
192
+ - transformers==4.28.1
193
+ - torchaudio==0.12
194
+ - datasets
195
+ - evaluate
196
+ - imblearn
197
+
198
+ ## Model Files
199
+
200
+ The repository contains:
201
+ - Model weights and configuration
202
+ - Feature extractor configuration
203
+ - Training logs and metrics
204
+ - Label mappings (id2label, label2id)
205
+
206
+ ---
207
+
208
+ *Model trained as part of a hackathon project*