# Audio Spectrogram Transformer (AST) for Music Genre Classification

Fine-tuned [Audio Spectrogram Transformer](https://huggingface.co/MIT/ast-finetuned-audioset-10-10-0.4593) for classifying audio tracks into **10 music genres**. This model achieved the best performance among all approaches tried in this project, reaching a **macro F1 of 0.886 on validation** and **0.857 on the Kaggle leaderboard**.

---

## Table of Contents

- [Overview](#overview)
- [Model Architecture](#model-architecture)
- [Preprocessing Pipeline](#preprocessing-pipeline)
- [Training](#training)
- [Results](#results)
- [Usage](#usage)
- [File Structure](#file-structure)
- [Acknowledgements](#acknowledgements)

---

## Overview

The Audio Spectrogram Transformer (AST) is a convolution-free, purely attention-based model for audio classification. It was originally pretrained on [AudioSet](https://research.google.com/audioset/) and is fine-tuned here on a custom **messy_mashup** music genre dataset with 10 genres:

> blues, classical, country, disco, hiphop, jazz, metal, pop, reggae, rock

Each training sample is synthesized on the fly by mixing separated stems (drums, vocals, bass, other) from a random song and injecting environmental noise from the ESC-50 dataset.

---

## Model Architecture

```
Pretrained Checkpoint: MIT/ast-finetuned-audioset-10-10-0.4593

Input: Mel spectrogram (1024 frames × 128 mel bins)
  → Patch embedding (16×16 patches)
  → 12-layer Vision Transformer encoder
  → [CLS] token pooling
  → Linear classifier (527 → 10 classes, re-initialized)
```

The classification head is replaced with a 10-class output layer using `ignore_mismatched_sizes=True`. All layers are fine-tuned end-to-end.

```python
import torch.nn as nn
from transformers import ASTForAudioClassification

class MusicGenreAST(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.ast = ASTForAudioClassification.from_pretrained(
            "MIT/ast-finetuned-audioset-10-10-0.4593",
            num_labels=num_classes,
            ignore_mismatched_sizes=True,
        )

    def forward(self, x):
        # Returns a SequenceClassifierOutput; class scores are in .logits
        return self.ast(x)
```

---

## Preprocessing Pipeline

### Audio Construction (Training)

1. **Genre selection**: A random genre is chosen per sample
2. **Stem loading**: Each of the 4 stems (drums, vocals, bass, other) is loaded at 16 kHz from a random song, starting at a random offset within the track
3. **Stem dropout**: Each stem has a 15% chance of being excluded, which teaches the model to classify with incomplete information
4. **Random gain**: Each included stem is scaled by a random factor in `[0.4, 1.2]` to simulate varying mix balances
5. **Mixing**: All included stems are summed and peak-normalized
6. **Noise injection**: A random ESC-50 clip is added at a random SNR (noise divisor uniformly sampled from `[2.0, 8.0]`)

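The construction steps above can be sketched in NumPy. This is a simplified illustration of the recipe, not the pipeline's actual code; the function name `mix_stems` and its arguments are illustrative.

```python
import numpy as np

def mix_stems(stems, noise, rng, drop_p=0.15, gain_range=(0.4, 1.2),
              noise_divisor_range=(2.0, 8.0)):
    """Mix separated stems into one training clip with random augmentation.

    stems: list of 1-D float arrays (drums, vocals, bass, other), equal length.
    noise: 1-D float array (an ESC-50 clip), same length as the stems.
    """
    mix = np.zeros_like(stems[0])
    kept = 0
    for stem in stems:
        if rng.random() < drop_p:            # stem dropout: 15% chance to skip
            continue
        mix += rng.uniform(*gain_range) * stem   # random per-stem gain
        kept += 1
    if kept == 0:                            # guard: keep at least one stem
        mix = stems[rng.integers(len(stems))].copy()
    peak = np.max(np.abs(mix))
    if peak > 0:                             # peak-normalize the mix
        mix /= peak
    divisor = rng.uniform(*noise_divisor_range)
    mix += noise / divisor                   # inject noise at a random level
    return mix
```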
### Feature Extraction

| Parameter | Value |
|-----------|-------|
| Sample rate | 16,000 Hz |
| Duration | 10 seconds |
| Mel bands | 128 |
| FFT size | 400 |
| Hop length | 160 |
| Target frames | 1,024 |
| Normalization | `(mel_dB + 4.26) / 4.56` |

The mel spectrogram is transposed to shape `(1024, 128)` (1024 time frames × 128 mel bins), matching the AST's expected input format. Shorter clips are zero-padded; longer clips are truncated.
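A NumPy-only sketch of the normalization, transposition, and pad/truncate steps (the mel computation itself, e.g. via `librosa.feature.melspectrogram`, is omitted here; `prepare_features` is an illustrative name, not the pipeline's actual function):

```python
import numpy as np

TARGET_FRAMES = 1024
N_MELS = 128

def prepare_features(mel_db):
    """Turn a (n_mels, T) dB-scaled mel spectrogram into the (1024, 128)
    time-major layout AST expects."""
    mel = (mel_db + 4.26) / 4.56           # normalization from the table above
    mel = mel.T                            # (T, n_mels): time frames first
    t = mel.shape[0]
    if t < TARGET_FRAMES:                  # zero-pad shorter clips
        pad = np.zeros((TARGET_FRAMES - t, N_MELS), dtype=mel.dtype)
        mel = np.concatenate([mel, pad], axis=0)
    else:                                  # truncate longer clips
        mel = mel[:TARGET_FRAMES]
    return mel
```

With these parameters, 10 s at 16 kHz and hop length 160 yields roughly 16000 × 10 / 160 = 1,000 frames, which is then zero-padded up to the 1,024-frame target.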
85
+
86
+ ### Test-Time Processing
87
+
88
+ Test audio is loaded directly (10s at 16 kHz), peak-normalized, and converted to a mel spectrogram using the same parameters. No augmentation is applied at inference.
89
+
90
+ ---
91
+
## Training

### Hyperparameters

| Parameter | Value |
|-----------|-------|
| Optimizer | AdamW |
| Learning rate | 1 × 10⁻⁵ |
| Weight decay | 0.01 |
| Batch size | 4 |
| Gradient accumulation | 4 steps (effective batch size = 16) |
| Max epochs | 15 |
| Early stopping patience | 7 epochs |
| Loss function | CrossEntropyLoss |
| LR scheduler | ReduceLROnPlateau (factor=0.5, patience=2, min_lr=1e-7) |
| Training samples | 1,000 per epoch (generated on the fly) |
| Validation samples | 500 per epoch |

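Accumulating gradients over 4 micro-batches of 4 reproduces the gradient of a single batch of 16, because the mean of equal-sized micro-batch gradients equals the full-batch mean gradient. A NumPy check of that identity for a squared-error loss (all names are illustrative, not from the pipeline):

```python
import numpy as np

rng = np.random.default_rng(42)
X = rng.normal(size=(16, 3))      # one "effective batch" of 16 samples
y = rng.normal(size=16)
w = rng.normal(size=3)

def grad(Xb, yb, w):
    """Gradient of the mean squared error 0.5*mean((Xb @ w - yb)**2) w.r.t. w."""
    return Xb.T @ (Xb @ w - yb) / len(yb)

full = grad(X, y, w)                                    # full-batch gradient
micro = [grad(X[i:i + 4], y[i:i + 4], w) for i in range(0, 16, 4)]
accumulated = np.mean(micro, axis=0)                    # average of 4 micro-grads

assert np.allclose(full, accumulated)
```

This is also why, in practice, the per-micro-batch loss is divided by the number of accumulation steps before calling backward.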
### Training Strategy

- **Gradient accumulation** (4 steps) is used to simulate a larger effective batch size while fitting within GPU VRAM
- **ReduceLROnPlateau** monitors the macro F1 score and halves the learning rate after 2 epochs without improvement
- **Early stopping** triggers after 7 consecutive epochs without a new best F1 score
- Best model weights are saved to `best_ast_model.pth` whenever a new best F1 is achieved
- **WandB** logs all training metrics (train loss, val loss, F1 score, learning rate) per epoch

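The plateau-halving and early-stopping bookkeeping above can be sketched in plain Python. This is a simplified stand-in for `torch.optim.lr_scheduler.ReduceLROnPlateau` plus an early-stop counter, not the pipeline's actual training loop:

```python
def run_schedule(f1_per_epoch, lr=1e-5, factor=0.5, plateau_patience=2,
                 min_lr=1e-7, stop_patience=7):
    """Track LR halving and early stopping over a sequence of epoch F1 scores.

    Returns (final_lr, epochs_run, best_f1).
    """
    best_f1 = float("-inf")
    since_best = 0         # epochs since last improvement (early stopping)
    since_plateau = 0      # epochs since last improvement (LR scheduler)
    epochs = 0
    for f1 in f1_per_epoch:
        epochs += 1
        if f1 > best_f1:
            best_f1 = f1                  # checkpoint would be saved here
            since_best = since_plateau = 0
        else:
            since_best += 1
            since_plateau += 1
            if since_plateau > plateau_patience:   # halve LR after a plateau
                lr = max(lr * factor, min_lr)
                since_plateau = 0
            if since_best >= stop_patience:        # stop after 7 bad epochs
                break
    return lr, epochs, best_f1
```

For example, a run that peaks at epoch 2 and never improves again stops at epoch 9, after two learning-rate halvings along the way.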
### Seeds

| Seed | Value |
|------|-------|
| Data seed | 67 |
| Training seed | 1234 |
| Train/Val split seed | 42 |

---

## Results

| Metric | Score |
|--------|:-----:|
| **Max Validation F1 (macro)** | **0.8861** |
| **Kaggle Leaderboard Score** | **0.85708** |

### Comparison with Other Models

| Model | Val F1 | Leaderboard |
|-------|:------:|:-----------:|
| CRNN (scratch) | 0.5800 | 0.33103 |
| EfficientNet-B0 | 0.5258 | 0.31641 |
| **AST (this model)** | **0.8861** | **0.85708** |

### Why AST Outperforms

- **Large-scale pretraining**: The base checkpoint was pretrained on AudioSet (2M+ audio clips), providing robust audio representations
- **Longer input context**: The 10 s duration captures more musical structure than the 5 s used by the other models
- **Mel spectrogram input**: 128-bin mel spectrograms retain richer frequency detail than MFCCs
- **Self-attention**: Transformers can model long-range temporal dependencies that CNNs and even RNNs struggle with
- **Aggressive augmentation**: Stem dropout, variable gain, and variable-SNR noise injection improve generalization

---

## Usage

### Prerequisites

```bash
pip install torch transformers librosa numpy pandas scikit-learn wandb
```

### Training

```python
from AST_Pipeline import MusicGenreAST, train_ast

model = MusicGenreAST(num_classes=10)
train_ast(model)
# Best weights saved to best_ast_model.pth
```

### Inference

```python
from AST_Pipeline import MusicGenreAST, predict

results = predict(
    model_instance=MusicGenreAST(10),
    model_path='best_ast_model.pth'
)
# results: list of genre strings, e.g. ['rock', 'jazz', 'blues', ...]
```

### Generating a Submission

```python
import pandas as pd

submission_df = pd.read_csv('sample_submission.csv')
submission = pd.DataFrame({
    "id": submission_df['id'],
    "genre": results
})
submission.to_csv("submission.csv", index=False)
```

---

## File Structure

```
├── AST_Pipeline.py      # Full pipeline: dataset, model, training, prediction
├── best_ast_model.pth   # Saved model weights (best validation F1)
├── requirements.txt     # Python dependencies
└── AST_README.md        # This file
```

### Key Classes & Functions in AST_Pipeline.py

| Name | Type | Description |
|------|------|-------------|
| `ASTAudioDataset` | Dataset | Training/validation dataset with on-the-fly stem mixing and augmentation |
| `ASTTestDataset` | Dataset | Test dataset that loads audio and converts it to a mel spectrogram |
| `MusicGenreAST` | nn.Module | Wrapper around `ASTForAudioClassification` with a 10-class head |
| `build_dataset()` | Function | Builds train/val dictionaries with a stratified split |
| `train_ast()` | Function | Full training loop with gradient accumulation, scheduler, early stopping, and WandB logging |
| `predict()` | Function | Loads saved weights and runs inference on the test set |

---

## Acknowledgements

- [MIT AST](https://huggingface.co/MIT/ast-finetuned-audioset-10-10-0.4593) – Pretrained Audio Spectrogram Transformer by Yuan Gong et al.
- [ESC-50](https://github.com/karolpiczak/ESC-50) – Environmental Sound Classification dataset used for noise augmentation
- [Weights & Biases](https://wandb.ai/) – Experiment tracking
- [librosa](https://librosa.org/) – Audio analysis and feature extraction